train.py (290 lines of code) (raw):

#!/usr/bin/env python # Modified Horovod MNIST example import os import sys import time import horovod.tensorflow as hvd import numpy as np import tensorflow as tf import graphics from utils import ResultLogger learn = tf.contrib.learn # Surpress verbose warnings os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' def _print(*args, **kwargs): if hvd.rank() == 0: print(*args, **kwargs) def init_visualizations(hps, model, logdir): def sample_batch(y, eps): n_batch = hps.local_batch_train xs = [] for i in range(int(np.ceil(len(eps) / n_batch))): xs.append(model.sample( y[i*n_batch:i*n_batch + n_batch], eps[i*n_batch:i*n_batch + n_batch])) return np.concatenate(xs) def draw_samples(epoch): if hvd.rank() != 0: return rows = 10 if hps.image_size <= 64 else 4 cols = rows n_batch = rows*cols y = np.asarray([_y % hps.n_y for _y in ( list(range(cols)) * rows)], dtype='int32') # temperatures = [0., .25, .5, .626, .75, .875, 1.] #previously temperatures = [0., .25, .5, .6, .7, .8, .9, 1.] x_samples = [] x_samples.append(sample_batch(y, [.0]*n_batch)) x_samples.append(sample_batch(y, [.25]*n_batch)) x_samples.append(sample_batch(y, [.5]*n_batch)) x_samples.append(sample_batch(y, [.6]*n_batch)) x_samples.append(sample_batch(y, [.7]*n_batch)) x_samples.append(sample_batch(y, [.8]*n_batch)) x_samples.append(sample_batch(y, [.9] * n_batch)) x_samples.append(sample_batch(y, [1.]*n_batch)) # previously: 0, .25, .5, .625, .75, .875, 1. for i in range(len(x_samples)): x_sample = np.reshape( x_samples[i], (n_batch, hps.image_size, hps.image_size, 3)) graphics.save_raster(x_sample, logdir + 'epoch_{}_sample_{}.png'.format(epoch, i)) return draw_samples # === # Code for getting data # === def get_data(hps, sess): if hps.image_size == -1: hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64, 'imagenet': 256, 'celeba': 256, 'lsun_realnvp': 64, 'lsun': 256}[hps.problem] if hps.n_test == -1: hps.n_test = {'mnist': 10000, 'cifar10': 10000, 'imagenet-oord': 50000, 'imagenet': 50000, 'celeba': 3000, 'lsun_realnvp': 300*hvd.size(), 'lsun': 300*hvd.size()}[hps.problem] hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000, 'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1, 'lsun': 1}[hps.problem] if hps.data_dir == "": hps.data_dir = {'mnist': None, 'cifar10': None, 'imagenet-oord': '/mnt/host/imagenet-oord-tfr', 'imagenet': '/mnt/host/imagenet-tfr', 'celeba': '/mnt/host/celeba-reshard-tfr', 'lsun_realnvp': '/mnt/host/lsun_realnvp', 'lsun': '/mnt/host/lsun'}[hps.problem] if hps.problem == 'lsun_realnvp': hps.rnd_crop = True else: hps.rnd_crop = False if hps.category: hps.data_dir += ('/%s' % hps.category) # Use anchor_size to rescale batch size based on image_size s = hps.anchor_size hps.local_batch_train = hps.n_batch_train * \ s * s // (hps.image_size * hps.image_size) hps.local_batch_test = {64: 50, 32: 25, 16: 10, 8: 5, 4: 2, 2: 2, 1: 1}[ hps.local_batch_train] # round down to closest divisor of 50 hps.local_batch_init = hps.n_batch_init * \ s * s // (hps.image_size * hps.image_size) print("Rank {} Batch sizes Train {} Test {} Init {}".format( hvd.rank(), hps.local_batch_train, hps.local_batch_test, hps.local_batch_init)) if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']: hps.direct_iterator = True import data_loaders.get_data as v train_iterator, test_iterator, data_init = \ v.get_data(sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap, hps.local_batch_train, hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop) elif hps.problem in ['mnist', 'cifar10']: hps.direct_iterator = False import data_loaders.get_mnist_cifar as v train_iterator, test_iterator, data_init = \ v.get_data(hps.problem, hvd.size(), hvd.rank(), hps.dal, hps.local_batch_train, hps.local_batch_test, hps.local_batch_init, hps.image_size) else: raise Exception() return train_iterator, test_iterator, data_init def process_results(results): stats = ['loss', 'bits_x', 'bits_y', 'pred_loss'] assert len(stats) == results.shape[0] res_dict = {} for i in range(len(stats)): res_dict[stats[i]] = "{:.4f}".format(results[i]) return res_dict def main(hps): # Initialize Horovod. hvd.init() # Create tensorflow session sess = tensorflow_session() # Download and load dataset. tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed) np.random.seed(hvd.rank() + hvd.size() * hps.seed) # Get data and set train_its and valid_its train_iterator, test_iterator, data_init = get_data(hps, sess) hps.train_its, hps.test_its, hps.full_test_its = get_its(hps) # Create log dir logdir = os.path.abspath(hps.logdir) + "/" if not os.path.exists(logdir): os.mkdir(logdir) # Create model import model model = model.model(sess, hps, train_iterator, test_iterator, data_init) # Initialize visualization functions visualise = init_visualizations(hps, model, logdir) if not hps.inference: # Perform training train(sess, model, hps, logdir, visualise) else: infer(sess, model, hps, test_iterator) def infer(sess, model, hps, iterator): # Example of using model in inference mode. Load saved model using hps.restore_path # Can provide x, y from files instead of dataset iterator # If model is uncondtional, always pass y = np.zeros([bs], dtype=np.int32) if hps.direct_iterator: iterator = iterator.get_next() xs = [] zs = [] for it in range(hps.full_test_its): if hps.direct_iterator: # replace with x, y, attr if you're getting CelebA attributes, also modify get_data x, y = sess.run(iterator) else: x, y = iterator() z = model.encode(x, y) x = model.decode(y, z) xs.append(x) zs.append(z) x = np.concatenate(xs, axis=0) z = np.concatenate(zs, axis=0) np.save('logs/x.npy', x) np.save('logs/z.npy', z) return zs def train(sess, model, hps, logdir, visualise): _print(hps) _print('Starting training. Logging to', logdir) _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg') # Train sess.graph.finalize() n_processed = 0 n_images = 0 train_time = 0.0 test_loss_best = 999999 if hvd.rank() == 0: train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__) test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__) tcurr = time.time() for epoch in range(1, hps.epochs): t = time.time() train_results = [] for it in range(hps.train_its): # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs. lr = hps.lr * min(1., n_processed / (hps.n_train * hps.epochs_warmup)) # Run a training step synchronously. _t = time.time() train_results += [model.train(lr)] if hps.verbose and hvd.rank() == 0: _print(n_processed, time.time()-_t, train_results[-1]) sys.stdout.flush() # Images seen wrt anchor resolution n_processed += hvd.size() * hps.n_batch_train # Actual images seen at current resolution n_images += hvd.size() * hps.local_batch_train train_results = np.mean(np.asarray(train_results), axis=0) dtrain = time.time() - t ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain train_time += dtrain if hvd.rank() == 0: train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int( train_time), **process_results(train_results)) if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0: test_results = [] msg = '' t = time.time() # model.polyak_swap() if epoch % hps.epochs_full_valid == 0: # Full validation run for it in range(hps.full_test_its): test_results += [model.test()] test_results = np.mean(np.asarray(test_results), axis=0) if hvd.rank() == 0: test_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, **process_results(test_results)) # Save checkpoint if test_results[0] < test_loss_best: test_loss_best = test_results[0] model.save(logdir+"model_best_loss.ckpt") msg += ' *' dtest = time.time() - t # Sample t = time.time() if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0: visualise(epoch) dsample = time.time() - t if hvd.rank() == 0: dcurr = time.time() - tcurr tcurr = time.time() _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format( ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg) # model.polyak_swap() if hvd.rank() == 0: _print("Finished!") # Get number of training and validation iterations def get_its(hps): # These run for a fixed amount of time. As anchored batch is smaller, we've actually seen fewer examples train_its = int(np.ceil(hps.n_train / (hps.n_batch_train * hvd.size()))) test_its = int(np.ceil(hps.n_test / (hps.n_batch_train * hvd.size()))) train_epoch = train_its * hps.n_batch_train * hvd.size() # Do a full validation run if hvd.rank() == 0: print(hps.n_test, hps.local_batch_test, hvd.size()) assert hps.n_test % (hps.local_batch_test * hvd.size()) == 0 full_test_its = hps.n_test // (hps.local_batch_test * hvd.size()) if hvd.rank() == 0: print("Train epoch size: " + str(train_epoch)) return train_its, test_its, full_test_its ''' Create tensorflow session with horovod ''' def tensorflow_session(): # Init session and params config = tf.ConfigProto() config.gpu_options.allow_growth = True # Pin GPU to local rank (one GPU per process) config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = tf.Session(config=config) return sess if __name__ == "__main__": # This enables a ctr-C without triggering errors import signal signal.signal(signal.SIGINT, lambda x, y: sys.exit(0)) import argparse parser = argparse.ArgumentParser() parser.add_argument("--verbose", action='store_true', help="Verbose mode") parser.add_argument("--restore_path", type=str, default='', help="Location of checkpoint to restore") parser.add_argument("--inference", action="store_true", help="Use in inference mode") parser.add_argument("--logdir", type=str, default='./logs', help="Location to save logs") # Dataset hyperparams: parser.add_argument("--problem", type=str, default='cifar10', help="Problem (mnist/cifar10/imagenet") parser.add_argument("--category", type=str, default='', help="LSUN category") parser.add_argument("--data_dir", type=str, default='', help="Location of data") parser.add_argument("--dal", type=int, default=1, help="Data augmentation level: 0=None, 1=Standard, 2=Extra") # New dataloader params parser.add_argument("--fmap", type=int, default=1, help="# Threads for parallel file reading") parser.add_argument("--pmap", type=int, default=16, help="# Threads for parallel map") # Optimization hyperparams: parser.add_argument("--n_train", type=int, default=50000, help="Train epoch size") parser.add_argument("--n_test", type=int, default=- 1, help="Valid epoch size") parser.add_argument("--n_batch_train", type=int, default=64, help="Minibatch size") parser.add_argument("--n_batch_test", type=int, default=50, help="Minibatch size") parser.add_argument("--n_batch_init", type=int, default=256, help="Minibatch size for data-dependent init") parser.add_argument("--optimizer", type=str, default="adamax", help="adam or adamax") parser.add_argument("--lr", type=float, default=0.001, help="Base learning rate") parser.add_argument("--beta1", type=float, default=.9, help="Adam beta1") parser.add_argument("--polyak_epochs", type=float, default=1, help="Nr of averaging epochs for Polyak and beta2") parser.add_argument("--weight_decay", type=float, default=1., help="Weight decay. Switched off by default.") parser.add_argument("--epochs", type=int, default=1000000, help="Total number of training epochs") parser.add_argument("--epochs_warmup", type=int, default=10, help="Warmup epochs") parser.add_argument("--epochs_full_valid", type=int, default=50, help="Epochs between valid") parser.add_argument("--gradient_checkpointing", type=int, default=1, help="Use memory saving gradients") # Model hyperparams: parser.add_argument("--image_size", type=int, default=-1, help="Image size") parser.add_argument("--anchor_size", type=int, default=32, help="Anchor size for deciding batch size") parser.add_argument("--width", type=int, default=512, help="Width of hidden layers") parser.add_argument("--depth", type=int, default=32, help="Depth of network") parser.add_argument("--weight_y", type=float, default=0.00, help="Weight of log p(y|x) in weighted loss") parser.add_argument("--n_bits_x", type=int, default=8, help="Number of bits of x") parser.add_argument("--n_levels", type=int, default=3, help="Number of levels") # Synthesis/Sampling hyperparameters: parser.add_argument("--n_sample", type=int, default=1, help="minibatch size for sample") parser.add_argument("--epochs_full_sample", type=int, default=50, help="Epochs between full scale sample") # Ablation parser.add_argument("--learntop", action="store_true", help="Learn spatial prior") parser.add_argument("--ycond", action="store_true", help="Use y conditioning") parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument("--flow_permutation", type=int, default=2, help="Type of flow. 0=reverse (realnvp), 1=shuffle, 2=invconv (ours)") parser.add_argument("--flow_coupling", type=int, default=0, help="Coupling type: 0=additive, 1=affine") hps = parser.parse_args() # So error if typo main(hps)