def split_train_test()

in fclib/fclib/dataset/ojdata.py [0:0]


def split_train_test(data_dir, n_splits=1, horizon=2, gap=2, first_week=40, last_week=156, write_csv=False):
    """Generate training, testing, and auxiliary datasets for each forecast split.

    Training data includes the historical sales and external features; testing data contains
    the future sales and external features; auxiliary data includes the future price, deal, and
    advertisement information which can be used for making predictions (we assume such auxiliary
    information is available at the time when we generate the forecasts). Use this function to
    generate the train, test, aux data for each forecast period on the fly, or use the write_csv
    flag to also write the data to files.

    The sales data is read from ``yx.csv`` inside ``data_dir``. For each split:

    * train: rows with ``first_week <= week <= train end week``, all columns.
    * test: rows with ``test start week <= week <= test end week``, all columns.
    * aux: rows with ``first_week <= week <= test end week``, with the 'logmove', 'constant',
      and 'profit' columns removed.

    When ``write_csv`` is True, ``train*.csv`` and ``auxi*.csv`` are written to
    ``<data_dir>/train`` and ``test*.csv`` to ``<data_dir>/test`` (both directories are created
    if missing). File names carry a ``_<split number>`` suffix (1-based, e.g. ``train_1.csv``)
    only when ``n_splits > 1``; with a single split they are ``train.csv``, ``auxi.csv`` and
    ``test.csv``. Both the train and auxi files can be used for generating forecasts in each
    split. However, the test files should only be used for model performance evaluation.

    Example:
        data_dir = "/home/ojdata"

        train, test, aux = split_train_test(data_dir=data_dir, n_splits=5, horizon=3, write_csv=True)

        print(len(train))
        print(len(test))
        print(len(aux))

    Args:
        data_dir (str): location of the download directory (must contain ``yx.csv``)
        n_splits (int, optional): number of splits (folds) to generate (default: 1)
        horizon (int, optional): forecasting horizon, number of weeks to forecast (default: 2)
        gap (int, optional): gap between training and testing, number of weeks between last training
            week and first test week (default: 2)
        first_week (int, optional): first available week (default: 40)
        last_week (int, optional): last available week (default: 156)
        write_csv (Boolean, optional): Whether to write out the data files or not (default: False)

    Returns:
        tuple: a 3-tuple ``(train_df_list, test_df_list, aux_df_list)`` where

            * train_df_list (list[pandas.DataFrame]): train data frames, one per split
            * test_df_list (list[pandas.DataFrame]): test data frames, one per split
            * aux_df_list (list[pandas.DataFrame]): aux data frames, one per split

    """
    # Read sales data into dataframe
    sales = pd.read_csv(os.path.join(data_dir, "yx.csv"), index_col=0)

    if write_csv:
        train_dir = os.path.join(data_dir, "train")
        test_dir = os.path.join(data_dir, "test")
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)

    train_df_list = []
    test_df_list = []
    aux_df_list = []

    test_start_week_list, test_end_week_list, train_end_week_list = _gen_split_indices(
        n_splits, horizon, gap, first_week, last_week
    )

    week = sales["week"]
    # Lower bound shared by the train and aux windows; it does not depend on the split.
    from_first_week = week >= first_week

    for i in range(n_splits):
        train_end = train_end_week_list[i]
        test_start = test_start_week_list[i]
        test_end = test_end_week_list[i]

        train_df = sales[from_first_week & (week <= train_end)].copy()
        test_df = sales[(week >= test_start) & (week <= test_end)].copy()
        # Aux spans the training AND forecast weeks, minus the columns that are
        # not supposed to be available at forecast time.
        aux_df = sales[from_first_week & (week <= test_end)].drop(columns=["logmove", "constant", "profit"])

        if write_csv:
            # Only number the files when there is more than one split.
            roundstr = "_" + str(i + 1) if n_splits > 1 else ""
            train_df.to_csv(os.path.join(train_dir, "train" + roundstr + ".csv"))
            test_df.to_csv(os.path.join(test_dir, "test" + roundstr + ".csv"))
            aux_df.to_csv(os.path.join(train_dir, "auxi" + roundstr + ".csv"))

        train_df_list.append(train_df)
        test_df_list.append(test_df)
        aux_df_list.append(aux_df)

    return train_df_list, test_df_list, aux_df_list