def train_fn()

in sagemaker_notebook_instance/containers/model/entry_point.py [0:0]


def _fit_and_save(transformer, features, name, model_dir):
    """Fit ``transformer`` on ``features`` and persist it to ``<model_dir>/<name>.joblib``.

    Deriving both the progress messages and the artifact file name from a
    single ``name`` keeps the log output consistent with the transformer
    actually being processed.

    Returns the fitted transformer (the same object that was passed in).
    """
    print(f'fitting {name}')
    transformer.fit(features)
    print(f'saving {name}')
    joblib.dump(transformer, Path(model_dir, f"{name}.joblib"))
    return transformer


def train_fn(args):
    """Train the feature preprocessors and a random-forest classifier, saving all artifacts.

    Attributes read from ``args``:
        data_train: location passed to ``load_data`` together with ``'train.json'``.
        numerical_feature_names, categorical_feature_names, textual_feature_names:
            comma-separated column names. They are split on ``','`` only;
            surrounding whitespace is NOT stripped.
        label_name: name of the label field passed to ``extract_labels``.
        model_dir: directory that receives every saved artifact.
        n_estimators: number of trees for ``RandomForestClassifier``.

    Artifacts written to ``args.model_dir``:
        feature_names.json, numerical_transformer.joblib,
        categorical_transformer.joblib, classifier.joblib

    NOTE(review): the textual transformer (``BertEncoder``) is used for
    ``transform`` only and is never fit or saved here. Presumably it is a
    pretrained encoder that is rebuilt at inference time; confirm against the
    inference code.
    """
    # load data
    print('loading data')
    data = load_data(args.data_train, 'train.json')

    # parse feature names and persist them so inference can reuse the same columns
    print('parsing feature names')
    numerical_feature_names = args.numerical_feature_names.split(',')
    categorical_feature_names = args.categorical_feature_names.split(',')
    textual_feature_names = args.textual_feature_names.split(',')
    print('saving feature names')
    save_feature_names(
        numerical_feature_names,
        categorical_feature_names,
        textual_feature_names,
        Path(args.model_dir, "feature_names.json")
    )

    # extract label
    print('extracting label')
    labels = extract_labels(
        data,
        args.label_name
    )

    # extract features
    print('extracting features')
    numerical_features, categorical_features, textual_features = extract_features(
        data,
        numerical_feature_names,
        categorical_feature_names,
        textual_feature_names
    )

    # define preprocessors
    print('defining preprocessors')
    # Mean-impute NaNs; add_indicator appends missing-value indicator columns.
    numerical_transformer = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True)
    # Categories unseen during fit are ignored at transform time instead of raising.
    categorical_transformer = OneHotEncoder(handle_unknown="ignore")
    textual_transformer = BertEncoder()

    # fit and save preprocessors (order: numerical first, then categorical)
    _fit_and_save(numerical_transformer, numerical_features, "numerical_transformer", args.model_dir)
    _fit_and_save(categorical_transformer, categorical_features, "categorical_transformer", args.model_dir)

    # transform features
    print('transforming numerical_features')
    numerical_features = numerical_transformer.transform(numerical_features)
    print('transforming categorical_features')
    categorical_features = categorical_transformer.transform(categorical_features)
    print('transforming textual_features')
    textual_features = textual_transformer.transform(textual_features)

    # concat features
    print('concatenating features')
    # OneHotEncoder returns a sparse matrix by default; densify it only when it
    # actually is sparse so a dense encoder output does not raise AttributeError.
    if hasattr(categorical_features, "toarray"):
        categorical_features = categorical_features.toarray()
    # Flatten every per-row text embedding into a single row vector so it can
    # be concatenated column-wise with the other feature blocks.
    textual_features = np.array(textual_features)
    textual_features = textual_features.reshape(textual_features.shape[0], -1)
    features = np.concatenate([
        numerical_features,
        categorical_features,
        textual_features
    ], axis=1)

    # define model
    print('instantiating model')
    classifier = RandomForestClassifier(
        n_estimators=args.n_estimators
    )

    # fit and save model
    print('fitting model')
    classifier = classifier.fit(features, labels)
    print('saving model')
    joblib.dump(classifier, Path(args.model_dir, "classifier.joblib"))