recommended-item-search/softmax_main.py

#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trains and evaluates a softmax recommendation model with tf.estimator.

Builds an Estimator from softmax_model.model_fn and runs
tf.estimator.train_and_evaluate with input pipelines from input_pipeline.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf

import input_pipeline
import softmax_model
# pylint: enable=g-bad-import-order

FLAGS = flags.FLAGS

flags.DEFINE_string(
    name='metadata_path', default='metadata.pickle',
    help='Set a path to metadata created by preprocess_movielens.py')
flags.DEFINE_list(
    name='hidden_dims', default=['64', '32'],
    help='The sizes of hidden layers for MLP. e.g. --hidden_dims=32,16,8,4')
flags.DEFINE_enum(
    name='activation', default='relu',
    enum_values=['relu', 'None'], case_sensitive=False,
    help='Specify an activation function used in hidden layers.')
flags.DEFINE_string(
    name='model_dir', default='./model',
    help='Set a model directory where model and checkpoint files are stored.')
flags.DEFINE_string(
    name='export_dir', default='Servo',
    help='Set a subdirectory where SavedModels are saved.')
flags.DEFINE_boolean(
    name='resume_training', default=False,
    help='Resume training from the latest checkpoint.')
flags.DEFINE_string(
    name='train_filename', default='train*.tfrecord',
    help='Set a file pattern of training inputs.')
flags.DEFINE_integer(
    name='train_batch_size', default=200,
    help='Set a batch size for the training process.')
flags.DEFINE_integer(
    name='train_max_steps', default=1000000,
    help='Set the max number of training steps per execution.')
flags.DEFINE_string(
    name='eval_filename', default='eval*.tfrecord',
    help='Set a file pattern of evaluation inputs.')
flags.DEFINE_integer(
    name='eval_batch_size', default=10000,
    help='Set a batch size for the evaluation process.')
flags.DEFINE_integer(
    name='eval_steps', default=10,
    help='Set the number of steps per evaluation.')
flags.DEFINE_integer(
    name='eval_throttle_secs', default=10,
    help='Set the minimum interval in seconds between evaluations.')
flags.DEFINE_float(
    name='learning_rate', default=0.01,
    help='Set a learning rate for the optimizer.')
flags.DEFINE_integer(
    name='lr_decay_steps', default=100000,
    help='Set the learning rate decay steps.')
flags.DEFINE_float(
    name='lr_decay_rate', default=0.96,
    help='Set the learning rate decay rate.')
flags.DEFINE_integer(
    name='save_checkpoints_steps', default=10000,
    help='Set the frequency of saving checkpoints.')
flags.DEFINE_integer(
    name='keep_checkpoint_max', default=3,
    help='Set the maximum number of saved checkpoints.')
flags.DEFINE_integer(
    name='log_step_count_steps', default=1000,
    help='Set the frequency of loss logging.')
flags.DEFINE_integer(
    name='tf_random_seed', default=20190501,
    help='Set a random seed for TensorFlow.')

tf.logging.set_verbosity(tf.logging.INFO)


def get_run_config():
  """Get the run configuration for the Estimator."""
  return tf.estimator.RunConfig(
      model_dir=FLAGS.model_dir,
      tf_random_seed=FLAGS.tf_random_seed,
      log_step_count_steps=FLAGS.log_step_count_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      train_distribute=None,
      session_config=tf.ConfigProto(allow_soft_placement=True))


def get_hyperparams():
  """Get hyperparameters used in the model function."""
  return tf.contrib.training.HParams(
      metadata_path=FLAGS.metadata_path,
      hidden_dims=[int(dim) for dim in FLAGS.hidden_dims],
      activation_name=FLAGS.activation,
      learning_rate=FLAGS.learning_rate,
      lr_decay_steps=FLAGS.lr_decay_steps,
      lr_decay_rate=FLAGS.lr_decay_rate,
  )


def get_train_spec():
  """Get the train spec for the Estimator."""
  # Profile compute/memory usage at the same cadence as checkpointing.
  profile_hook = tf.train.ProfilerHook(
      save_steps=FLAGS.save_checkpoints_steps,
      output_dir=FLAGS.model_dir,
      show_memory=True)
  train_input_fn = input_pipeline.generate_input_fn(
      file_pattern=FLAGS.train_filename,
      batch_size=FLAGS.train_batch_size,
      mode=tf.estimator.ModeKeys.TRAIN)
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn,
      max_steps=FLAGS.train_max_steps,
      hooks=[profile_hook])
  return train_spec


def get_eval_spec():
  """Get the eval spec for the Estimator."""
  # Export a SavedModel for serving after each evaluation of the latest
  # checkpoint.
  exporter = tf.estimator.LatestExporter(
      name=FLAGS.export_dir,
      exports_to_keep=FLAGS.keep_checkpoint_max,
      serving_input_receiver_fn=softmax_model.serving_input_fn)
  eval_input_fn = input_pipeline.generate_input_fn(
      file_pattern=FLAGS.eval_filename,
      batch_size=FLAGS.eval_batch_size,
      mode=tf.estimator.ModeKeys.EVAL)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=eval_input_fn,
      steps=FLAGS.eval_steps,
      throttle_secs=FLAGS.eval_throttle_secs,
      exporters=exporter)
  return eval_spec


def remove_artifacts():
  """Remove artifacts of previous runs unless resuming training."""
  if not FLAGS.resume_training:
    if tf.gfile.Exists(FLAGS.model_dir):
      tf.logging.info('Removing {} ...'.format(FLAGS.model_dir))
      tf.gfile.DeleteRecursively(FLAGS.model_dir)
    tf.summary.FileWriterCache.clear()


def main(_):
  remove_artifacts()
  estimator = tf.estimator.Estimator(
      model_fn=softmax_model.model_fn,
      params=get_hyperparams(),
      config=get_run_config())
  tf.estimator.train_and_evaluate(
      estimator=estimator,
      train_spec=get_train_spec(),
      eval_spec=get_eval_spec())


if __name__ == '__main__':
  absl_app.run(main)
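
# A minimal invocation sketch (not part of the original file; the data files
# named here are assumptions): 'train*.tfrecord', 'eval*.tfrecord', and
# 'metadata.pickle' are expected to exist beforehand, e.g. as produced by
# the preprocess_movielens.py script referenced in the metadata_path help.
#
#   python softmax_main.py \
#     --metadata_path=metadata.pickle \
#     --train_filename='train*.tfrecord' \
#     --eval_filename='eval*.tfrecord' \
#     --hidden_dims=64,32 \
#     --model_dir=./model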