def main()

in archived/bandits_recsys_movielens_testbed/src/train.py [0:0]


def main():
    """Train a Vowpal Wabbit (VW) model through C++ process."""

    channel_names = json.loads(os.environ["SM_CHANNELS"])
    hyperparameters = json.loads(os.environ["SM_HPS"])

    # Fetch algorithm hyperparameters
    num_arms = int(hyperparameters.get("num_arms", 0))  # Used if arm features are not present
    num_policies = int(hyperparameters.get("num_policies", 3))
    exploration_policy = hyperparameters.get("exploration_policy", "egreedy").lower()
    epsilon = float(hyperparameters.get("epsilon", 0))
    mellowness = float(hyperparameters.get("mellowness", 0.01))
    arm_features_present = bool(hyperparameters.get("arm_features", True))

    # Fetch environment parameters
    item_pool_size = int(hyperparameters.get("item_pool_size", 0))
    top_k = int(hyperparameters.get("top_k", 5))
    max_users = int(hyperparameters.get("max_users", 3))
    total_interactions = int(hyperparameters.get("total_interactions", 1000))

    if not arm_features_present and num_arms is 0:
        raise ValueError("Customer Error: Please provide a non-zero value for 'num_arms'")
    logging.info("channels %s" % channel_names)
    logging.info("hps: %s" % hyperparameters)

    # Different exploration policies in VW
    # https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-algorithms
    valid_policies = ["egreedy", "bag", "regcbopt", "regcb"]
    if exploration_policy not in valid_policies:
        raise ValueError(f"Customer Error: exploration_policy must be one of {valid_policies}.")

    if exploration_policy == "egreedy":
        vw_args_base = f"--cb_explore_adf --cb_type mtr --epsilon {epsilon}"
    elif exploration_policy in ["regcbopt", "regcb"]:
        vw_args_base = (
            f"--cb_explore_adf --cb_type mtr --{exploration_policy} --mellowness {mellowness}"
        )
    else:
        vw_args_base = f"--cb_explore_adf --cb_type mtr --{exploration_policy} {num_policies}"

    # If pre-trained model is present
    if MODEL_CHANNEL not in channel_names:
        logging.info(
            f"No pre-trained model has been specified in channel {MODEL_CHANNEL}."
            f"Training will start from scratch."
        )
        vw_agent = VWAgent(
            cli_args=vw_args_base,
            output_dir=MODEL_OUTPUT_DIR,
            model_path=None,
            test_only=False,
            quiet_mode=False,
            adf_mode=arm_features_present,
            num_actions=num_arms,
        )
    else:
        # Load the pre-trained model for training.
        model_folder = os.environ[f"SM_CHANNEL_{MODEL_CHANNEL.upper()}"]
        metadata_path, weights_path = extract_model(model_folder)
        logging.info(f"Loading model from {weights_path}")
        vw_agent = VWAgent.load_model(
            metadata_loc=metadata_path,
            weights_loc=weights_path,
            test_only=False,
            quiet_mode=False,
            output_dir=MODEL_OUTPUT_DIR,
        )

    # Start the VW C++ process. This python program will communicate with the C++ process using PIPES
    vw_agent.start()

    if "movielens" not in channel_names:
        raise ValueError(
            "Cannot find `movielens` channel. Please make sure to provide the data as `movielens` channel."
        )

    # Initialize MovieLens environment
    env = MovieLens100KEnv(
        data_dir=os.environ["SM_CHANNEL_MOVIELENS"],
        item_pool_size=item_pool_size,
        top_k=top_k,
        max_users=max_users,
    )

    regrets = []
    random_regrets = []

    obs = env.reset()

    # Learn by interacting with the environment
    for i in range(total_interactions):
        user_features, items_features = obs
        actions, probs = vw_agent.choose_actions(
            shared_features=user_features,
            candidate_arms_features=items_features,
            user_id=env.current_user_id,
            candidate_ids=env.current_item_pool,
            top_k=5,
        )

        clicks, regret, random_regret = env.get_feedback(actions)
        regrets.append(regret)
        random_regrets.append(random_regret)

        for index, reward in enumerate(clicks):
            vw_agent.learn(
                shared_features=user_features,
                candidate_arms_features=items_features,
                action_index=actions[index],
                reward=reward,
                user_id=env.current_user_id,
                candidate_ids=env.current_item_pool,
                action_prob=probs[index],
                cost_fn=lambda x: -x,
            )

        # Step the environment to pick next user and new list of candidate items
        obs, rewards, done, info = env.step(actions)
        if i % 500 == 0:
            logging.info(f"Processed {i} interactions")

    stdout = vw_agent.save_model(close=True)
    print(stdout.decode())
    logging.info(f"Model learned using {total_interactions} training experiences.")

    all_regrets = {"agent": regrets, "random": random_regrets}
    with open(os.path.join(DATA_OUTPUT_DIR, "output.json"), "w") as file:
        json.dump(all_regrets, file)