in src/SimpleReplay/replay.py [0:0]
def main():
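    """Entry point for SimpleReplay: load the replay config, reconstruct the
    original workload (connections, transactions, queries) from the extracted
    logs, replay it against the target cluster, and report summary stats."""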
    global logger
    logger = init_logging(logging.INFO)

    global g_config

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "config_file",
        type=argparse.FileType("r"),
        help="Location of replay config file.",
    )
    args = parser.parse_args()
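    # load the YAML config; everything downstream reads from the global g_config dict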
    with args.config_file as stream:
        try:
            g_config = yaml.safe_load(stream)
        except yaml.YAMLError as exception:
            logger.error(exception)
            sys.exit(-1)  # can't continue without a parseable config

    validate_config(g_config)
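    # console log level comes from log_level; file logging is enabled unless
    # logfile_level is explicitly set to "none"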
    level = logging.getLevelName(g_config.get("log_level", "INFO").upper())
    set_log_level(level)

    if g_config.get("logfile_level") != "none":
        level = logging.getLevelName(g_config.get("logfile_level", "DEBUG").upper())
        log_file = "replay.log"
        add_logfile(
            log_file,
            level=level,
            preamble=yaml.dump(g_config),
            backup_count=g_config.get("backup_count", 2),
        )

    # print the version
    log_version()
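    # unique name for this replay run, derived from the target cluster endpoint and
    # the current UTC timestamp; used to tag exported errors and output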
    replay_name = f'Replay_{g_config["target_cluster_endpoint"].split(".")[0]}_{datetime.datetime.now(tz=datetime.timezone.utc).isoformat()}'

    # use a manager to share the stats dict to all processes
    # manager = multiprocessing.Manager()
    manager = SyncManager()
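    # have the manager process ignore SIGINT so a CTRL-C only interrupts the main
    # process and the shared stats remain readable afterwards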
    def init_manager():
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    manager.start(init_manager)
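    # normalize an empty replay_output to None so later checks can rely on truthiness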
    if not g_config["replay_output"]:
        g_config["replay_output"] = None

    (connection_logs, total_connections) = parse_connections(
        g_config["workload_location"],
        g_config["time_interval_between_transactions"],
        g_config["time_interval_between_queries"],
    )
    logger.info(
        f"Found {total_connections} total connections, "
        f"{total_connections - len(connection_logs)} are excluded by filters. "
        f"Replaying {len(connection_logs)}."
    )

    # Associate transactions with connections
    logger.info(
        f"Loading transactions from {g_config['workload_location']}, this might take some time."
    )

    # group all connections by connection key
    connection_idx_by_key = {}
    for idx, c in enumerate(connection_logs):
        connection_key = get_connection_key(c.database_name, c.username, c.pid)
        connection_idx_by_key.setdefault(connection_key, []).append(idx)

    all_transactions = parse_transactions(g_config["workload_location"])
    transaction_count = len(all_transactions)
    query_count = 0

    # assign the correct connection to each transaction by looking at the most
    # recent connection prior to the transaction start. This relies on connections
    # being sorted.
    for t in all_transactions:
        connection_key = get_connection_key(t.database_name, t.username, t.pid)
        # a transaction may have no recorded connection at all; fall back to an
        # empty list so it's skipped with a warning below rather than a KeyError
        possible_connections = connection_idx_by_key.get(connection_key, [])
        best_match_idx = None
        for c_idx in possible_connections:
            # truncate session start time, since query/transaction time is truncated to seconds
            if connection_logs[c_idx].session_initiation_time.replace(microsecond=0) > t.start_time():
                break
            best_match_idx = c_idx
        if best_match_idx is None:
            logger.warning(
                f"Couldn't find matching connection in {len(possible_connections)} "
                f"connections for transaction {t}, skipping"
            )
            continue
        connection_logs[best_match_idx].transactions.append(t)
        query_count += len(t.queries)

    logger.info(f"Found {transaction_count} transactions, {query_count} queries")

    connection_logs = [_ for _ in connection_logs if len(_.transactions) > 0]
    logger.info(f"{len(connection_logs)} connections contain transactions and will be replayed")

    global g_total_connections
    g_total_connections = len(connection_logs)
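    # determine the time window spanned by the original workload, using both the
    # connection lifetimes and the first/last query timestamps on each connection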
    first_event_time = datetime.datetime.now(tz=datetime.timezone.utc)
    last_event_time = datetime.datetime.utcfromtimestamp(0).replace(
        tzinfo=datetime.timezone.utc
    )
    for connection in connection_logs:
        if (
            connection.session_initiation_time
            and connection.session_initiation_time < first_event_time
        ):
            first_event_time = connection.session_initiation_time
        if (
            connection.disconnection_time
            and connection.disconnection_time > last_event_time
        ):
            last_event_time = connection.disconnection_time
        if (
            connection.transactions[0].queries[0].start_time
            and connection.transactions[0].queries[0].start_time < first_event_time
        ):
            first_event_time = connection.transactions[0].queries[0].start_time
        if (
            connection.transactions[-1].queries[-1].end_time
            and connection.transactions[-1].queries[-1].end_time > last_event_time
        ):
            last_event_time = connection.transactions[-1].queries[-1].end_time

    logger.info(
        "Estimated original workload execution time: "
        + str(last_event_time - first_event_time)
    )
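    # COPY and UNLOAD statements are optionally rewritten so they use locations and
    # IAM roles that the target cluster can actually access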
if g_config["execute_copy_statements"] == "true":
logger.debug("Configuring COPY replacements")
replacements = parse_copy_replacements(g_config["workload_location"])
assign_copy_replacements(connection_logs, replacements)
if g_config["execute_unload_statements"] == "true":
if g_config["unload_iam_role"]:
if g_config["replay_output"].startswith("s3://"):
logger.debug("Configuring UNLOADs")
assign_unloads(
connection_logs,
g_config["replay_output"],
replay_name,
g_config["unload_iam_role"],
)
else:
logger.debug(
'UNLOADs not configured since "replay_output" is not an S3 location.'
)
logger.debug("Configuring time intervals")
assign_time_intervals(connection_logs)
logger.debug("Configuring CREATE USER PASSWORD random replacements")
assign_create_user_password(connection_logs)
replay_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
    if len(connection_logs) == 0:
        logger.info("No logs to replay, nothing to do.")
        sys.exit()

    # test connection; do this after the empty-logs check above since it indexes
    # connection_logs[0]
    try:
        # use the first user as a test
        get_connection_string(
            connection_logs[0].username,
            database=connection_logs[0].database_name,
            max_attempts=1,
        )
    except CredentialsException as e:
        logger.error(
            f"Unable to retrieve credentials using GetClusterCredentials ({str(e)}). "
            f"Please verify that an IAM policy exists granting access. "
            f"See the README for more details."
        )
        sys.exit(-1)
    # Actual replay
    logger.debug("Starting replay")
    per_process_stats = {}
    try:
        start_replay(
            connection_logs,
            g_config["default_interface"],
            g_config["odbc_driver"],
            first_event_time,
            last_event_time,
            g_config.get("num_workers"),
            manager,
            per_process_stats,
            transaction_count,
            query_count,
        )
    except KeyboardInterrupt:
        logger.warning("Got CTRL-C, exiting...")
logger.debug("Aggregating stats")
aggregated_stats = init_stats({})
for idx, stat in per_process_stats.items():
collect_stats(aggregated_stats, stat)
logger.info("Replay summary:")
logger.info(
f"Attempted to replay {query_count} queries, {transaction_count} transactions, {len(connection_logs)} connections."
)
try:
logger.info(
f"Successfully replayed {aggregated_stats.get('transaction_success', 0)} out of {transaction_count} ({round((aggregated_stats.get('transaction_success', 0)/transaction_count)*100)}%) transactions."
)
logger.info(
f"Successfully replayed {aggregated_stats.get('query_success', 0)} out of {query_count} ({round((aggregated_stats.get('query_success', 0)/query_count)*100)}%) queries."
)
except ZeroDivisionError:
pass
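    # errors are exported next to the replay output, or the workload location if no
    # replay output was configured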
if g_config["replay_output"]:
error_location = g_config["replay_output"]
else:
error_location = g_config["workload_location"]
logger.info(f"Encountered {len(aggregated_stats['connection_error_log'])} connection errors and {len(aggregated_stats['transaction_error_log'])} transaction errors")
# and save them
export_errors(
aggregated_stats['connection_error_log'],
aggregated_stats['transaction_error_log'],
error_location,
replay_name,
)
logger.info(f"Replay finished in {datetime.datetime.now(tz=datetime.timezone.utc) - replay_start_time}.")
    if (
        g_config["replay_output"]
        and g_config["unload_system_table_queries"]
        and g_config["target_cluster_system_table_unload_iam_role"]
    ):
        logger.info(f'Exporting system tables to {g_config["replay_output"]}')

        unload_system_table(
            g_config["default_interface"],
            g_config["unload_system_table_queries"],
            g_config["replay_output"] + "/" + replay_name,
            g_config["target_cluster_system_table_unload_iam_role"],
        )

        logger.info(f'Exported system tables to {g_config["replay_output"]}')

    print_stats(per_process_stats)

    manager.shutdown()