in src/SimpleReplay/extract.py [0:0]
def save_logs(logs, last_connections, output_directory, connections, start_time, end_time):
    """Export extracted transactions, connections, and COPY replacement
    candidates to *output_directory* (a local path or an ``s3://`` URL).

    :param logs: mapping of output filename -> list of query objects (each
        with ``text``, ``record_time``, ``start_time``, ``end_time``,
        ``database_name``, ``username``, ``pid``, ``xid`` attributes)
    :param last_connections: collection of ``hash((database, user, pid))``
        values for connections already seen in the audit log
    :param output_directory: ``s3://bucket/prefix`` or a local directory
    :param connections: dict of pk -> ConnectionLog; mutated in place to add
        synthetic connections for queries missing an audit-log connection
    :param start_time: extraction-range start (used for synthetic connections)
    :param end_time: extraction-range end (used for synthetic connections)
    :raises FileExistsError: locally, if ``connections.json`` already exists
        (``"x"`` mode is deliberate -- never overwrite an earlier extraction)
    """
    num_queries = sum(len(transaction) for transaction in logs.values())
    logger.info(
        f"Exporting {len(logs)} transactions ({num_queries} queries) to {output_directory}"
    )

    is_s3 = output_directory.startswith("s3://")
    if is_s3:
        bucket_name, _, output_prefix = output_directory[len("s3://"):].partition("/")
        s3_client = boto3.client("s3")
    else:
        # exist_ok so rerunning into the same directory doesn't crash here
        os.makedirs(output_directory + "/SQLs/", exist_ok=True)

    def write_output(body, relative_path, mode="w"):
        # Write `body` to S3 or the local output directory, closing the
        # local file even if the write raises.
        if is_s3:
            s3_client.put_object(
                Body=body,
                Bucket=bucket_name,
                Key=output_prefix + "/" + relative_path,
            )
        else:
            with open(output_directory + "/" + relative_path, mode) as out_file:
                out_file.write(body)

    # Strip embedded IAM roles so replayed COPY/UNLOAD statements don't
    # carry the source cluster's credentials; compiled once, used per query.
    iam_role_pattern = re.compile(
        r"IAM_ROLE 'arn:aws:iam::\d+:role/\S+'", re.IGNORECASE
    )

    missing_audit_log_connections = set()
    replacements = set()  # S3 source locations needing a replacement mapping

    # Save the main logs and find replacements
    for filename, queries in logs.items():
        file_parts = ["--Time interval: true\n\n"]
        for query in queries:
            query.text = remove_line_comments(query.text).strip()

            header_info = "--Record time: " + query.record_time.isoformat() + "\n"
            if query.start_time:
                header_info += "--Start time: " + query.start_time.isoformat() + "\n"
            if query.end_time:
                header_info += "--End time: " + query.end_time.isoformat() + "\n"
            try:
                header_info += f"--Database: {query.database_name}\n"
                header_info += f"--Username: {query.username}\n"
                header_info += f"--Pid: {query.pid}\n"
                header_info += f"--Xid: {query.xid}\n"
            except AttributeError:
                logger.error(f'Query is missing header info, skipping (unknown): {query}')
                continue

            lowered = query.text.lower()
            if "copy " in lowered and "from 's3:" in lowered:
                # Record the COPY source location so the user can map it to a
                # replacement in copy_replacements.csv.  Guarded: the regex
                # requires "s3://" while the cheap check above does not, so
                # search() can legitimately find nothing.
                source = re.search(r"from 's3:\/\/[^']*", query.text, re.IGNORECASE)
                if source:
                    # [6:] strips the leading "from '" of the match.
                    replacements.add(source.group()[6:])
                query.text = iam_role_pattern.sub(" IAM_ROLE ''", query.text)
            if "unload" in lowered and "to 's3:" in lowered:
                query.text = iam_role_pattern.sub(" IAM_ROLE ''", query.text)

            if query.text:
                query.text = f"/* Replay source file: (unknown) */ {query.text.strip()}"
                if not query.text.endswith(";"):
                    query.text += ";"

            file_parts.append(header_info + query.text + "\n")

            # Queries whose connection never appeared in the audit log get a
            # synthetic connection spanning the whole extraction range.
            connection_key = (query.database_name, query.username, query.pid)
            if hash(connection_key) not in last_connections:
                missing_audit_log_connections.add(connection_key)

        write_output("".join(file_parts).strip(), "SQLs/" + filename)

    logger.info(f"Generating {len(missing_audit_log_connections)} missing connections.")
    for database_name, username, pid in missing_audit_log_connections:
        connection = ConnectionLog(
            start_time,
            end_time,  # for missing connections set start_time and end_time to our extraction range
            database_name,
            username,
            pid,
        )
        connections[connection.get_pk()] = connection

    logger.info(
        f"Exporting a total of {len(connections.values())} connections to {output_directory}"
    )
    # Save the connections logs.  connection_time_replacement() is retained
    # for its presumed in-place adjustment of the connection dicts before
    # serialization -- NOTE(review): confirm it mutates rather than returning
    # a transformed copy; its return value was unused in the original too.
    sorted_connections = connections.values()
    connection_time_replacement([connection.__dict__ for connection in sorted_connections])
    connections_string = json.dumps(
        [connection.__dict__ for connection in sorted_connections],
        indent=4,
        default=str,
    )
    write_output(connections_string, "connections.json", mode="x")

    # Save the replacements template (user fills in the empty columns).
    logger.info(f"Exporting copy replacements to {output_directory}")
    replacements_string = "Original location,Replacement location,Replacement IAM role\n"
    for bucket in replacements:
        replacements_string += bucket + ",,\n"
    write_output(replacements_string, "copy_replacements.csv")