def train_fn()

in archived/churn_prediction_multimodality_of_text_and_tabular/containers/huggingface_transformer_randomforest/entry_point.py [0:0]
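
The function below references several names defined elsewhere in entry_point.py (the helpers load_jsonl, find_filepath, extract_labels, extract_features, concatenate_features, save_feature_names_sentence_transformer, and the BertEncoder class). A plausible set of imports the body relies on, assumed rather than shown in the excerpt:

# assumed imports; the originals live at the top of entry_point.py
from pathlib import Path

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import class_weight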


def train_fn(args):
    # load data
    print('loading training data')
    train_data = load_jsonl(
        find_filepath(args.data_train)
    )
    print(f'length of training data: {len(train_data)}')

    print('loading validation data')
    validation_data = load_jsonl(
        find_filepath(args.data_validation)
    )
    print(f'length of validation data: {len(validation_data)}')
    
    # parse feature names
    print('parsing feature names')
    numerical_feature_names = args.numerical_feature_names.split(',')
    categorical_feature_names = args.categorical_feature_names.split(',')
    textual_feature_names = args.textual_feature_names.split(',')
    
    print('saving feature names')
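    # feature_names.json records the feature ordering and the sentence-transformer
    # name, presumably so the inference container can rebuild the same featurization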
    save_feature_names_sentence_transformer(
        numerical_feature_names,
        categorical_feature_names,
        textual_feature_names,
        args.sentence_transformer,
        Path(args.model_dir, "feature_names.json")
    )
    
    # extract label
    print('extracting label')
    train_labels = extract_labels(train_data, args.label_name)
    print(f'length of training labels: {len(train_labels)}')
    validation_labels = extract_labels(validation_data, args.label_name)
    print(f'length of validation labels: {len(validation_labels)}') 
    
    
    if args.balanced_data:
        print('computing class weights')
        # weight classes inversely to their frequency in the combined
        # train + validation labels
        train_val_labels = np.concatenate((train_labels, validation_labels), axis=0)
        classes = np.unique(train_val_labels)
        class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=train_val_labels,
        )
        # map each class label (not just its position) to its weight
        class_weight_dict = dict(zip(classes, class_weights))
    else:
        class_weight_dict = None
    

    # extract features
    print('extracting features for training and validation data')
    train_numerical_features, train_categorical_features, train_textual_features = extract_features(
        train_data,
        numerical_feature_names,
        categorical_feature_names,
        textual_feature_names
    )
    
    validation_numerical_features, validation_categorical_features, validation_textual_features = extract_features(
        validation_data,
        numerical_feature_names,
        categorical_feature_names,
        textual_feature_names
    )

    # define preprocessors
    print('defining preprocessors')
    numerical_transformer = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True)
    categorical_transformer = OneHotEncoder(handle_unknown="ignore")
    textual_transformer = BertEncoder(model_name=args.sentence_transformer)
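    # BertEncoder is assumed to be a local wrapper around the named
    # sentence-transformers model, exposing a sklearn-style transform() that
    # embeds each text field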

    # fit and save preprocessors
    print('fitting numerical_transformer')
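    # note: `+` assumes extract_features returns plain Python lists, so this
    # concatenates rows; both preprocessors are fit on train + validation combined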
    numerical_transformer.fit(train_numerical_features + validation_numerical_features)
    print('saving numerical_transformer')
    joblib.dump(numerical_transformer, Path(args.model_dir, "numerical_transformer.joblib"))
    print('fitting categorical_transformer')
    categorical_transformer.fit(train_categorical_features + validation_categorical_features)
    print('saving categorical_transformer')
    joblib.dump(categorical_transformer, Path(args.model_dir, "categorical_transformer.joblib"))

    # transform features
    print('transforming numerical_features for training and validation data')
    train_numerical_features = numerical_transformer.transform(train_numerical_features)
    validation_numerical_features = numerical_transformer.transform(validation_numerical_features)
    print('transforming categorical_features for training and validation data')
    train_categorical_features = categorical_transformer.transform(train_categorical_features)
    validation_categorical_features = categorical_transformer.transform(validation_categorical_features)
    print('transforming textual_features for training and validation data')
    train_textual_features = textual_transformer.transform(train_textual_features)
    validation_textual_features = textual_transformer.transform(validation_textual_features)

    # concat features
    print('concatenating features')
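    # concatenate_features presumably stacks the imputed numerical block, the
    # one-hot block, and the text-embedding block column-wise into one matrix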
    train_features = concatenate_features(
        train_numerical_features,
        train_categorical_features,
        train_textual_features
    )
    validation_features = concatenate_features(
        validation_numerical_features,
        validation_categorical_features,
        validation_textual_features
    )
    
    # -1 is the sentinel for "no depth limit"; bootstrap arrives as a string flag
    max_depth = None if args.max_depth == -1 else args.max_depth
    bootstrap = args.bootstrap == "True"

    # define model
    print('instantiating model')
    classifier = RandomForestClassifier(
        n_estimators=args.n_estimators,
        criterion=args.criterion,
        max_depth=max_depth,
        min_impurity_decrease=args.min_impurity_decrease,
        ccp_alpha=args.ccp_alpha,
        bootstrap=bootstrap,
        min_samples_split=args.min_samples_split,
        min_samples_leaf=args.min_samples_leaf,
        class_weight=class_weight_dict,
    )

    # fit and save model
    print('fitting model')
    classifier = classifier.fit(train_features, train_labels)
    
    print('evaluating the model on the validation data')
    prediction_prob = classifier.predict_proba(validation_features)
    # argmax over the class-probability columns recovers the predicted label ids
    prediction_labels = np.argmax(prediction_prob, axis=1)
    
    f1_score_val = f1_score(validation_labels, prediction_labels)
    accuracy_val = accuracy_score(validation_labels, prediction_labels)
    roc_auc_val = roc_auc_score(validation_labels, prediction_prob[:, 1])
    print(f'f1 score on validation data: {f1_score_val}')
    print(f'accuracy score on validation data: {accuracy_val}')
    print(f'roc auc score on validation data: {roc_auc_val}')
    
    print('saving model')
    joblib.dump(classifier, Path(args.model_dir, "classifier.joblib"))
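
For context, a minimal sketch of the command-line wiring that might drive this entry point. The argument names are inferred from the args.* attributes the function reads; the defaults shown for the forest hyperparameters follow scikit-learn's RandomForestClassifier defaults, and the real __main__ block in entry_point.py may differ (for example, by reading SageMaker environment variables):

if __name__ == '__main__':
    import argparse

    # hypothetical CLI wiring, inferred from the args.* fields used above
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_train', type=str)
    parser.add_argument('--data_validation', type=str)
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--label_name', type=str)
    parser.add_argument('--numerical_feature_names', type=str)
    parser.add_argument('--categorical_feature_names', type=str)
    parser.add_argument('--textual_feature_names', type=str)
    parser.add_argument('--sentence_transformer', type=str)
    parser.add_argument('--balanced_data', type=lambda s: s == 'True', default=False)
    parser.add_argument('--n_estimators', type=int, default=100)
    parser.add_argument('--criterion', type=str, default='gini')
    parser.add_argument('--max_depth', type=int, default=-1)  # -1 means "no limit"
    parser.add_argument('--min_impurity_decrease', type=float, default=0.0)
    parser.add_argument('--ccp_alpha', type=float, default=0.0)
    parser.add_argument('--bootstrap', type=str, default='True')
    parser.add_argument('--min_samples_split', type=int, default=2)
    parser.add_argument('--min_samples_leaf', type=int, default=1)

    train_fn(parser.parse_args())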