in code/scikit_learn_iris.py [0:0]
def train(train=None, test=None):
    """Fit an SVM classifier on the CSV data stored in the ``train`` directory.

    The features are read from ``train_x.csv`` and the labels from
    ``train_y.csv`` inside ``train``. The ``test`` directory, when given, is
    only checked for being non-empty; its contents are not read here.

    Hyperparameters (``kernel``, ``c``, ``gamma``) come from
    ``jobinfo.json`` in the directory named by the ``SM_CHANNEL_JOBINFO``
    environment variable when that variable is set. Otherwise they come from
    the module-level ``args`` object, which is defined outside this function
    (presumably the parsed command-line arguments -- confirm against the
    rest of the file).

    Args:
        train: location on the filesystem for the training dataset.
        test: optional location on the filesystem for the test dataset.

    Returns:
        The fitted ``svm.SVC`` model object.

    Raises:
        ValueError: if ``train`` is not given, or if the ``train`` directory
            (or the ``test`` directory, when given) contains no files.
    """
    if train is None:
        # os.listdir(None) would silently list the current working directory,
        # so fail early with an explicit message instead.
        raise ValueError('The train channel location was not specified.')

    # List both channels before validating so that a missing directory still
    # surfaces as the OSError raised by os.listdir.
    channels = [('train', train, os.listdir(train))]
    if test:
        channels.append(('test', test, os.listdir(test)))

    for channel, location, entries in channels:
        if not entries:
            # Name the channel that is actually empty, not always "train".
            raise ValueError(
                f'There are no files in {location}.\n'
                f'This usually indicates that the channel {channel} was incorrectly specified,\n'
                'the data specification in S3 was incorrectly specified or the role specified\n'
                'does not have permission to access the data.'
            )

    X_train = genfromtxt(os.path.join(train, 'train_x.csv'), delimiter=',')
    y_train = genfromtxt(os.path.join(train, 'train_y.csv'), delimiter=',')

    # Resolve the hyperparameters: the job-info channel takes precedence over
    # the module-level ``args``.
    jobinfo_path = os.environ.get('SM_CHANNEL_JOBINFO')
    if jobinfo_path is not None:
        with open(os.path.join(jobinfo_path, 'jobinfo.json'), 'r', encoding='utf-8') as f:
            hyperparams = json.load(f)['hyperparams']
        kernel = hyperparams['kernel']
        # The numeric values in the JSON file are converted explicitly, as
        # they may be stored as strings.
        c = float(hyperparams['c'])
        gamma = float(hyperparams['gamma'])
    else:
        kernel, c, gamma = args.kernel, args.c, args.gamma

    # Now use scikit-learn's support vector classifier to train the model.
    clf = svm.SVC(kernel=kernel, C=c, gamma=gamma, verbose=1).fit(X_train, y_train)
    return clf