courses/machine_learning/cloudmle/taxifare/trainer/task.py
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example implementation of code to run on the Cloud ML service.
"""
import argparse
import json
import os
from . import model
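# Example local invocation (bucket path and step count are illustrative placeholders):
#   python -m trainer.task \
#     --train_data_paths=gs://YOUR_BUCKET/taxifare/train*.csv \
#     --eval_data_paths=gs://YOUR_BUCKET/taxifare/valid.csv \
#     --output_dir=./taxi_trained \
#     --train_steps=1000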
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input Arguments
    parser.add_argument(
        '--train_data_paths',
        help = 'GCS or local path to training data',
        required = True
    )
    parser.add_argument(
        '--train_batch_size',
        help = 'Batch size for training steps',
        type = int,
        default = 512
    )
    parser.add_argument(
        '--train_steps',
        help = 'Steps to run the training job for',
        type = int
    )
    parser.add_argument(
        '--eval_steps',
        help = 'Number of steps to run evaluation for at each checkpoint',
        default = 10,
        type = int
    )
    parser.add_argument(
        '--eval_data_paths',
        help = 'GCS or local path to evaluation data',
        required = True
    )
    # Training arguments
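    # nargs='+' means this flag takes a space-separated list on the command line,
    # e.g. --hidden_units 128 64 32 yields [128, 64, 32]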
    parser.add_argument(
        '--hidden_units',
        help = 'List of hidden layer sizes to use for DNN feature columns',
        nargs = '+',
        type = int,
        default = [128, 32, 4]
    )
    parser.add_argument(
        '--output_dir',
        help = 'GCS location to write checkpoints and export models',
        required = True
    )
    parser.add_argument(
        '--job-dir',
        help = 'This model ignores this field, but it is required by gcloud',
        default = 'junk'
    )
    # Eval arguments
    parser.add_argument(
        '--eval_delay_secs',
        help = 'How long to wait before running first evaluation',
        default = 10,
        type = int
    )
    parser.add_argument(
        '--min_eval_frequency',
        help = 'Seconds between evaluations',
        default = 300,
        type = int
    )
    args = parser.parse_args()
    arguments = args.__dict__
    # Unused args provided by service
    arguments.pop('job_dir', None)
    arguments.pop('job-dir', None)
    output_dir = arguments['output_dir']
    # Append trial_id to path if we are doing hptuning
    # This code can be removed if you are not using hyperparameter tuning
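    # When hyperparameter tuning, the service sets TF_CONFIG to JSON that includes
    # a task entry such as (illustrative) {"task": {"type": "worker", "index": 0, "trial": "1"}},
    # so each trial writes to its own subdirectory of output_dir.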
    output_dir = os.path.join(
        output_dir,
        json.loads(
            os.environ.get('TF_CONFIG', '{}')
        ).get('task', {}).get('trial', '')
    )
    # Run the training job
    model.train_and_evaluate(arguments)