def train()

in code/scikit_learn_iris.py [0:0]


def train(train=None, test=None):
    """Train an SVM classifier on the CSV data in the training channel.

    Hyperparameters are taken from ``$SM_CHANNEL_JOBINFO/jobinfo.json`` (its
    ``hyperparams`` object: ``kernel``, ``c``, ``gamma``) when the
    ``SM_CHANNEL_JOBINFO`` environment variable is set. Otherwise they come
    from the module-level ``args`` object (``args.kernel``, ``args.c``,
    ``args.gamma``), which must be defined elsewhere in this module.

    Args:
        train: directory containing ``train_x.csv`` (features) and
            ``train_y.csv`` (labels).
        test: optional test-data directory. It is only checked for being
            non-empty; its contents are not used here.

    Returns:
        The fitted ``svm.SVC`` model.

    Raises:
        ValueError: if ``train`` is not given, or if the training directory
            (or the supplied test directory) contains no files.
    """
    if train is None:
        # os.listdir(None) would silently list the current directory.
        raise ValueError('A training data directory must be specified.')

    # Fail early with an actionable message when a data channel is empty.
    _require_files(train, 'train')
    if test:
        _require_files(test, 'test')

    X_train = genfromtxt(os.path.join(train, 'train_x.csv'), delimiter=',')
    y_train = genfromtxt(os.path.join(train, 'train_y.csv'), delimiter=',')

    kernel, c, gamma = _load_hyperparameters()

    # Fit a support vector classifier with the resolved hyperparameters.
    return svm.SVC(kernel=kernel,
                   C=c,
                   gamma=gamma,
                   verbose=1).fit(X_train, y_train)


def _require_files(path, channel):
    """Raise ValueError if directory ``path`` for data channel ``channel`` is empty."""
    if not os.listdir(path):
        raise ValueError(
            f'There are no files in {path}.\n'
            f'This usually indicates that the channel {channel} was incorrectly specified,\n'
            'the data specification in S3 was incorrectly specified or the role specified\n'
            'does not have permission to access the data.')


def _load_hyperparameters():
    """Return ``(kernel, C, gamma)`` for the SVM.

    Reads ``jobinfo.json`` from the ``SM_CHANNEL_JOBINFO`` directory when that
    environment variable is set. Otherwise it falls back to the module-level
    ``args`` object. NOTE(review): ``args`` is presumably the parsed
    command-line arguments defined elsewhere in this file; confirm.
    """
    jobinfo_path = os.environ.get('SM_CHANNEL_JOBINFO')
    if jobinfo_path is None:
        return args.kernel, args.c, args.gamma

    # Read the config and close the file before the (potentially long) fit.
    with open(os.path.join(jobinfo_path, 'jobinfo.json'), 'r', encoding='utf-8') as f:
        hyperparams = json.load(f)['hyperparams']
    # The numeric values may be stored as strings, so coerce them explicitly.
    return hyperparams['kernel'], float(hyperparams['c']), float(hyperparams['gamma'])