def main()

in src/SimpleReplay/replay.py [0:0]


def _ignore_sigint():
    """Manager-process initializer: ignore SIGINT in the manager's server process.

    Defined at module level (not nested in ``main``) so it can be pickled when
    multiprocessing uses the ``spawn`` start method.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def _config_flag_enabled(value):
    """Return True if a config flag is switched on.

    Accepts the quoted string ``"true"`` (case-insensitive, surrounding
    whitespace ignored) as well as the YAML boolean produced by an unquoted
    ``true``; everything else counts as off.
    """
    if value is True:
        return True
    return isinstance(value, str) and value.strip().lower() == "true"


def _load_config(stream):
    """Parse the replay config YAML from an open file object and close it.

    :param stream: readable file object (as produced by ``argparse.FileType``).
    :returns: the parsed config mapping.
    Logs the problem and exits with status -1 when the YAML is malformed or its
    top level is not a mapping (e.g. an empty file, which loads as ``None``).
    """
    with stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exception:
            logger.error(exception)
            sys.exit(-1)
    if not isinstance(config, dict):
        logger.error("Replay config file is empty or is not a YAML mapping.")
        sys.exit(-1)
    return config


def _assign_transactions_to_connections(connection_logs, all_transactions):
    """Append each transaction to the connection it most likely ran on.

    Connections are grouped by ``get_connection_key(database, user, pid)``; a
    transaction is matched to the latest connection in its group whose session
    started no later than the transaction. This relies on each group's
    connections being in session-start order. Transactions with no matching
    connection are logged and skipped.

    :returns: the number of queries in the transactions that were assigned.
    """
    connection_idx_by_key = {}
    for idx, conn in enumerate(connection_logs):
        key = get_connection_key(conn.database_name, conn.username, conn.pid)
        connection_idx_by_key.setdefault(key, []).append(idx)

    query_count = 0
    for t in all_transactions:
        key = get_connection_key(t.database_name, t.username, t.pid)
        # .get: the connection for this key may have been excluded by filters
        possible_connections = connection_idx_by_key.get(key, [])
        best_match_idx = None
        for c_idx in possible_connections:
            # truncate session start time, since query/transaction time is truncated to seconds
            if connection_logs[c_idx].session_initiation_time.replace(microsecond=0) > t.start_time():
                break
            best_match_idx = c_idx
        if best_match_idx is None:
            logger.warning(f"Couldn't find matching connection in {len(possible_connections)} connections for transaction {t}, skipping")
            continue
        connection_logs[best_match_idx].transactions.append(t)
        query_count += len(t.queries)
    return query_count


def _workload_time_bounds(connection_logs):
    """Return ``(first_event_time, last_event_time)`` spanned by the connections.

    Considers each connection's session start / disconnection time and the
    start of its first query / end of its last query, ignoring falsy (missing)
    timestamps. Every connection must hold at least one transaction with at
    least one query.
    """
    first_event_time = datetime.datetime.now(tz=datetime.timezone.utc)
    last_event_time = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc)

    for connection in connection_logs:
        first_query = connection.transactions[0].queries[0]
        last_query = connection.transactions[-1].queries[-1]
        for start in (connection.session_initiation_time, first_query.start_time):
            if start and start < first_event_time:
                first_event_time = start
        for end in (connection.disconnection_time, last_query.end_time):
            if end and end > last_event_time:
                last_event_time = end
    return first_event_time, last_event_time


def _log_replay_summary(aggregated_stats, query_count, transaction_count, connection_count):
    """Log attempted/succeeded counts and error totals for the finished replay.

    Success percentages are only logged when their denominator is non-zero.
    """
    logger.info("Replay summary:")
    logger.info(
        f"Attempted to replay {query_count} queries, {transaction_count} transactions, {connection_count} connections."
    )
    transaction_success = aggregated_stats.get('transaction_success', 0)
    query_success = aggregated_stats.get('query_success', 0)
    if transaction_count:
        logger.info(
            f"Successfully replayed {transaction_success} out of {transaction_count} ({round((transaction_success/transaction_count)*100)}%) transactions."
        )
    if query_count:
        logger.info(
            f"Successfully replayed {query_success} out of {query_count} ({round((query_success/query_count)*100)}%) queries."
        )
    logger.info(f"Encountered {len(aggregated_stats['connection_error_log'])} connection errors and {len(aggregated_stats['transaction_error_log'])} transaction errors")


def main():
    """Replay a captured workload against the target cluster.

    Reads the YAML replay config named on the command line, loads connections
    and transactions from ``workload_location``, pairs each transaction with
    its connection, replays them via ``start_replay``, then logs and exports
    the aggregated results (and optionally unloads system tables).

    Sets the module globals ``logger``, ``g_config`` and ``g_total_connections``.
    Exits with status -1 if the config cannot be parsed or credentials cannot
    be retrieved, and with status 0 if there is nothing to replay.
    """
    global logger, g_config, g_total_connections
    logger = init_logging(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", type=argparse.FileType("r"), help="Location of replay config file.",)
    args = parser.parse_args()

    g_config = _load_config(args.config_file)
    validate_config(g_config)

    # `or` also covers keys that are present but empty in the YAML (loaded as None)
    set_log_level(logging.getLevelName((g_config.get('log_level') or 'INFO').upper()))

    logfile_level = g_config.get("logfile_level") or 'DEBUG'
    if str(logfile_level).lower() != "none":
        add_logfile(
            'replay.log',
            level=logging.getLevelName(str(logfile_level).upper()),
            preamble=yaml.dump(g_config),
            backup_count=g_config.get("backup_count", 2),
        )

    # print the version
    log_version()

    replay_name = f'Replay_{g_config["target_cluster_endpoint"].split(".")[0]}_{datetime.datetime.now(tz=datetime.timezone.utc).isoformat()}'

    # A manager shares the per-process stats with the worker processes; its
    # server process ignores SIGINT (presumably so CTRL-C, handled around
    # start_replay below, does not take the shared stats down with it).
    manager = SyncManager()
    manager.start(_ignore_sigint)

    if not g_config["replay_output"]:
        g_config["replay_output"] = None

    (connection_logs, total_connections) = parse_connections(
        g_config["workload_location"],
        g_config["time_interval_between_transactions"],
        g_config["time_interval_between_queries"],
    )
    logger.info(f"Found {total_connections} total connections, {total_connections - len(connection_logs)} are excluded by filters. Replaying {len(connection_logs)}.")

    # Associate transactions with connections
    logger.info(
        f"Loading transactions from {g_config['workload_location']}, this might take some time."
    )
    all_transactions = parse_transactions(g_config["workload_location"])
    transaction_count = len(all_transactions)
    query_count = _assign_transactions_to_connections(connection_logs, all_transactions)

    logger.info(f"Found {transaction_count} transactions, {query_count} queries")
    connection_logs = [conn for conn in connection_logs if conn.transactions]
    logger.info(f"{len(connection_logs)} connections contain transactions and will be replayed ")

    g_total_connections = len(connection_logs)

    # Bail out before anything indexes connection_logs[0] or computes bounds.
    if not connection_logs:
        logger.info("No logs to replay, nothing to do.")
        manager.shutdown()
        sys.exit()

    first_event_time, last_event_time = _workload_time_bounds(connection_logs)
    logger.info(
        "Estimated original workload execution time: "
        + str((last_event_time - first_event_time))
    )

    if _config_flag_enabled(g_config["execute_copy_statements"]):
        logger.debug("Configuring COPY replacements")
        replacements = parse_copy_replacements(g_config["workload_location"])
        assign_copy_replacements(connection_logs, replacements)

    if _config_flag_enabled(g_config["execute_unload_statements"]) and g_config["unload_iam_role"]:
        replay_output = g_config["replay_output"]
        # replay_output may have been normalised to None above
        if replay_output and replay_output.startswith("s3://"):
            logger.debug("Configuring UNLOADs")
            assign_unloads(
                connection_logs,
                replay_output,
                replay_name,
                g_config["unload_iam_role"],
            )
        else:
            logger.debug(
                'UNLOADs not configured since "replay_output" is not an S3 location.'
            )

    logger.debug("Configuring time intervals")
    assign_time_intervals(connection_logs)

    logger.debug("Configuring CREATE USER PASSWORD random replacements")
    assign_create_user_password(connection_logs)

    replay_start_time = datetime.datetime.now(tz=datetime.timezone.utc)

    # test connection, using the first user
    try:
        get_connection_string(connection_logs[0].username, database=connection_logs[0].database_name, max_attempts=1)
    except CredentialsException as e:
        logger.error(f"Unable to retrieve credentials using GetClusterCredentials ({str(e)}).  Please verify that an IAM policy exists granting access.  See the README for more details.")
        manager.shutdown()
        sys.exit(-1)

    # Actual replay
    logger.debug("Starting replay")
    per_process_stats = {}
    try:
        start_replay(connection_logs,
                     g_config["default_interface"],
                     g_config["odbc_driver"],
                     first_event_time,
                     last_event_time,
                     g_config.get("num_workers"),
                     manager,
                     per_process_stats,
                     transaction_count,
                     query_count)
    except KeyboardInterrupt:
        logger.warning("Got CTRL-C, exiting...")

    logger.debug("Aggregating stats")
    aggregated_stats = init_stats({})
    for stat in per_process_stats.values():
        collect_stats(aggregated_stats, stat)

    _log_replay_summary(aggregated_stats, query_count, transaction_count, len(connection_logs))

    # and save them
    error_location = g_config["replay_output"] or g_config["workload_location"]
    export_errors(
        aggregated_stats['connection_error_log'],
        aggregated_stats['transaction_error_log'],
        error_location,
        replay_name,
    )

    logger.info(f"Replay finished in {datetime.datetime.now(tz=datetime.timezone.utc) - replay_start_time}.")

    if (
        g_config["replay_output"]
        and g_config["unload_system_table_queries"]
        and g_config["target_cluster_system_table_unload_iam_role"]
    ):
        logger.info(f'Exporting system tables to {g_config["replay_output"]}')

        unload_system_table(
            g_config["default_interface"],
            g_config["unload_system_table_queries"],
            g_config["replay_output"] + "/" + replay_name,
            g_config["target_cluster_system_table_unload_iam_role"],
        )

        logger.info(f'Exported system tables to {g_config["replay_output"]}')

    # print_stats runs before shutdown in case the stats hold manager proxies
    print_stats(per_process_stats)
    manager.shutdown()