train.py (1,125 lines of code) (raw):

# the base model is the optimized version of the Sparse Transformer # presented at https://arxiv.org/abs/1904.10509 # if hacking on the model, be sure to use the mpi init functions # (random_or_zeros_init, constant_or_zeros_init, etc) or # else the models wont be synced across ranks from collections import namedtuple import itertools import os import pdb import sys import time import math import argparse import numpy as np import tensorflow as tf import blocksparse as bs from blocksparse.nccl import serialize_allreduce_ops import subprocess from utils import logger, save_params, load_variables_from_file from utils import maybe_download from utils import log_gradient_values, shape_list, go_over from hyperparams import Hyperparams, add_arguments from hyperparams import parse_args_and_update_hparams from mpi_utils import random_or_zeros_init, constant_or_zeros_init, zeros_init from mpi_utils import get_session, allreduce, group_allreduce, sync_variables from mpi_utils import mpi_size, mpi_rank, local_mpi_rank, mpi_allgather, mpi_barrier from optimizer import get_optimizer from datasets import get_dataset, iter_data_mpi, JankySubsampledDataset from autoaugment import distort_image_with_randaugment H = Hyperparams() AugmentationType = namedtuple("AugmentationType", ("sos_name", "description", "num_tokens", "is_used", "fn")) def f32_storage_getter(getter, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, *args, **kwargs): """Custom variable getter that forces trainable variables to be stored in float32 precision and then casts them to the training precision. https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/ index.html#mptrain """ var = H.var_cache.get(name) if var is None: with tf.control_dependencies(None): var = getter(name, shape, dtype=tf.float32, initializer=initializer, regularizer=regularizer, trainable=trainable, *args, **kwargs) H.var_cache[name] = var if H.ema is not None: var = H.ema.average(var) if dtype != var.dtype.base_dtype: var = bs.float_cast(var, dtype=dtype, dx_dtype=dtype, name=f"{name}/cast") return var def split_states(x, heads): """ reshape (batch, pixel, state) -> (batch, pixel, head, head_state) """ x_shape = shape_list(x) m = x_shape[-1] new_x_shape = x_shape[:-1] + [heads, m // heads] return tf.reshape(x, new_x_shape) def merge_states(x): """ reshape (batch, pixel, head, head_state) -> (batch, pixel, state) """ x_shape = shape_list(x) new_x_shape = x_shape[:-2] + [np.prod(x_shape[-2:])] return tf.reshape(x, new_x_shape) def split_heads(scope, x, heads): """ (batch, pixel, state) -> (batch, head, pixel, head_state) """ with tf.name_scope(scope): return bs.transpose_0213(split_states(x, heads)) def merge_heads(scope, x): """ (batch, head, pixel, head_state) -> (batch, pixel, state) """ with tf.name_scope(scope): return merge_states(bs.transpose_0213(x)) def get_dense_attn_mask(n, attn_mode): '''a is dense attention, b is local attention (previous k), bT is strided (every kth element), implemented as a transpose''' key = f'{n}-{attn_mode}' dense_mask = H.dense_mask_cache.get(key) if dense_mask is not None: return dense_mask if attn_mode == 'a_all': b = tf.ones([n, n], dtype=tf.float32) elif attn_mode == 'a': b = tf.matrix_band_part(tf.ones([n, n]), -1, 0) elif attn_mode == 'b': bandwidth = H.local_attn_ctx ctx = tf.minimum(n - 1, bandwidth - 1) b = tf.matrix_band_part(tf.ones([n, n]), ctx, 0) elif attn_mode in ['c', 'bT']: stride = H.local_attn_ctx x = tf.reshape(tf.range(n, dtype=tf.int32), [n, 1]) y = tf.transpose(x) z = tf.zeros([n, n], dtype=tf.int32) q = z + x k = z + y c1 = q >= k c2 = tf.equal(tf.floormod(q - k, stride), 0) c3 = tf.logical_and(c1, c2) b = tf.cast(c3, tf.float32) else: raise ValueError('Not yet implemented') b = tf.reshape(b, [1, 1, n, n]) H.dense_mask_cache[key] = b return b def get_callback(attn_mode): def cb(blk_shape, head_idx, qry_idx, key_idx, blk_idx): mask = np.ones(blk_shape, dtype=np.bool) qdim, kdim = blk_shape assert qdim == kdim if attn_mode in ['a_all', 'b_all', 'bT_all']: return mask if qry_idx == key_idx: for q in range(qdim): mask[q, q + 1:] = 0 if attn_mode in ['a', 'bT', 'b0']: return mask if attn_mode == 'b': bandwidth = H.local_attn_ctx # convert group indices to absolute indices and mask # according to that q_pos = blk_shape[0] * qry_idx k_pos = blk_shape[1] * key_idx for q in range(qdim): q_ = q + q_pos maxw = max(-1, q_ - k_pos - bandwidth) mask[q, :maxw + 1] = 0 if qry_idx == key_idx: mask[q, q + 1:] = 0 if H.print_attn_layout: for i in range(qdim): print(' '.join([str(x) for x in mask[i, 0:kdim].astype(np.int32)])) print(qry_idx, key_idx) pdb.set_trace() return mask raise ValueError return cb def get_blocksparse_obj(n_ctx, n_heads, attn_mode): '''a is dense attention, b is local attention (previous k), bT is strided (every kth element), implemented as a transpose''' key = f'{n_ctx}-{n_heads}-{attn_mode}' bst = H.bst_cache.get(key) if bst is not None: return bst blocksize = H.blocksize n_bctx = n_ctx // blocksize if attn_mode in ['b', 'bT', 'b0']: if attn_mode in ['b']: assert H.local_attn_ctx % blocksize == 0 extra_diagonals = H.local_attn_ctx // blocksize elif attn_mode in ['bT', 'b0']: bT_ctx = H.attn_ctx // H.local_attn_ctx assert bT_ctx % blocksize == 0 block_chunks = bT_ctx // blocksize layout = np.ones([n_bctx, n_bctx], dtype=np.bool) for q_idx in range(n_bctx): # Causal queries cannot attend to keys above them layout[q_idx, q_idx + 1:] = 0 if attn_mode == 'b': start = max(0, q_idx - extra_diagonals) layout[q_idx, :start] = 0 elif attn_mode in ['bT', 'b0']: offset = q_idx % block_chunks layout[q_idx, :q_idx - offset] = 0 elif attn_mode == 'a': # standard causal attention layout = np.ones([n_bctx, n_bctx], dtype=np.bool) for q_idx in range(n_bctx): layout[q_idx, q_idx + 1:] = 0 elif attn_mode == 'a_all': layout = np.ones([n_bctx, n_bctx], dtype=np.bool) if H.mem_block and H.block_memory: # Block attention over the memory block layout[:-1, -1] = 0 elif attn_mode in ['b_all', 'bT_all']: assert H.blocksize == 32 assert H.local_attn_ctx == 32 assert n_bctx == 32 layout = np.zeros([n_bctx, n_bctx], dtype=np.bool) for q_idx in range(n_bctx): layout[q_idx, q_idx] = 1.0 else: raise NotImplementedError if H.print_attn_layout: width = H.attn_cols_to_print for i in range(min(width, n_bctx)): print(' '.join([str(x) for x in layout[i, 0:width].astype(np.int32)])) pdb.set_trace() bst = bs.BlocksparseTransformer( layout, block_size=blocksize, mask_callback=get_callback(attn_mode), heads=n_heads) H.bst_cache[key] = bst return bst def linear(scope, x, nf, std, relu=False, fast_gelu=False): with tf.variable_scope(scope): nx = x.shape[-1].value # delay w casting operation just prior to use # This can save a lot of memory for large param models. with tf.control_dependencies([x]): w = tf.get_variable("w", [nx, nf], dtype=H.dtype, initializer=random_or_zeros_init(stddev=std)) b = tf.get_variable("b", [nf], dtype=tf.float32, initializer=zeros_init()) ndims = x.shape.ndims if ndims > 2: h_shape = tf.concat([tf.shape(x)[:ndims - 1], [nf]], axis=0) x = tf.reshape(x, [-1, nx]) h = tf.matmul(x, w) h = bs.bias_relu(h, b, relu=relu, fast_gelu=fast_gelu) if ndims > 2: h = tf.reshape(h, h_shape) return h def norm(scope, x, epsilon=1e-5): with tf.variable_scope(scope): nx = x.shape[-1].value g = tf.get_variable("g", [nx], dtype=tf.float32, initializer=constant_or_zeros_init(1.0)) b = tf.get_variable("b", [nx], dtype=tf.float32, initializer=zeros_init()) return bs.layer_norm(x, g, b, axis=-1, epsilon=epsilon, relu=False) def embedding_dropout(x, train): if train and H.embd_pdrop > 0.0: x, _ = bs.dropout(x, keep_prob=1.0 - H.embd_pdrop) return x def residual_dropout(x, train, key, pdrop=None): resid_pdrop = pdrop if pdrop else H.resid_pdrop if train and resid_pdrop > 0.0: mask_shape = x.shape.as_list() key += str(mask_shape) mask_shape = None x, H.dropout_cache[key] = bs.dropout( x, keep_prob=1.0 - resid_pdrop, mask=H.dropout_cache.get(key), mask_shape=mask_shape) return x @bs.recomputable def dense_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None): nx = x.shape[-1].value n_state = int(nx * H.qk_ratio) if n_state % n_heads != 0: raise ValueError('nx must be divisible by head state') h = norm("attn_input", x) qh = h[:, -1:, :] if use_cache else h q = linear('q_proj', qh, n_state, std=np.sqrt(H.qk_w / nx)) k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx)) v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx)) q = split_heads("q_split", q, n_heads) k = split_heads("k_split", k, n_heads) v = split_heads("v_split", v, n_heads) if use_cache: if attn_mode not in ['a', 'b', 'c', 'bT']: raise NotImplementedError mask = None if attn_mode == 'b': k = k[:, :, -H.local_attn_ctx:, :] v = v[:, :, -H.local_attn_ctx:, :] elif attn_mode in ['c', 'bT']: k = k[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :] v = v[:, :, ::-H.local_attn_ctx, :][:, :, ::-1, :] else: n_timesteps = k.shape[2].value mask = get_dense_attn_mask(n_timesteps, attn_mode) if H.float16: # These products can overflow, so we do it in float32. k = bs.float_cast(k, dtype=tf.float32) q = bs.float_cast(q, dtype=tf.float32) v = bs.float_cast(v, dtype=tf.float32) w = tf.matmul(q, k, transpose_b=True) w = bs.masked_softmax(w, mask=mask, scale=1.0 / np.sqrt(q.shape[-1].value)) a = tf.matmul(w, v) a = merge_heads("merge_attn", a) if H.float16: a = bs.float_cast(a, dtype=tf.float16) return post_attention(x, a, use_cache=use_cache, train=train, pdrop=pdrop) @bs.recomputable def sparse_attention(x, n_heads, attn_mode, use_cache=False, train=False, pdrop=None): if use_cache: raise NotImplementedError if not H.float16: raise ValueError("sparse_attention requires fp16") nx = x.shape[-1].value n_state = int(nx * H.qk_ratio) if n_state % n_heads != 0: raise ValueError('nx must be divisible by head state') h = norm("attn_input", x) if attn_mode in ['bT', 'bT_all']: ctx = H.local_attn_ctx bT_ctx = H.attn_ctx // ctx assert bT_ctx % H.blocksize == 0, f'{bT_ctx}, {H.blocksize}' n, t, embd = shape_list(h) h = tf.reshape(h, [n, bT_ctx, ctx, embd]) h = bs.transpose_0213(h) h = tf.reshape(h, [n, t, embd]) q = linear('q_proj', h, n_state, std=np.sqrt(H.qk_w / nx)) k = linear('k_proj', h, n_state, std=np.sqrt(H.qk_w / nx)) v = linear('v_proj', h, nx, std=np.sqrt(H.v_w / nx)) bst = get_blocksparse_obj(H.attn_ctx, n_heads, attn_mode) w = bst.query_key_op(q, k) w = bst.masked_softmax(w, scale=1.0 / np.sqrt(n_state // n_heads)) a = bst.weight_value_op(w, v) if attn_mode in ['bT', 'bT_all']: a = tf.reshape(a, [n, ctx, bT_ctx, embd]) a = bs.transpose_0213(a) a = tf.reshape(a, [n, t, embd]) return post_attention(x, a, train=train, pdrop=pdrop) def post_attention(x, a, use_cache=None, train=False, pdrop=None): nx = x.shape[-1].value a = linear('post_proj', a, nx, std=np.sqrt(H.post_w * 0.5 / nx / H.n_layer)) scopename = tf.get_variable_scope().name a = residual_dropout(a, train, key=f'{scopename}-a', pdrop=pdrop) x = x[:, -1:, :] if use_cache else x x = bs.add(x, a) inner_dim = int(nx * H.mlp_multiple) m = norm("mlp", x) m = linear('mlp_proj1', m, inner_dim, std=np.sqrt(H.mlp_w1 / nx), fast_gelu=True) m = linear('mlp_proj2', m, nx, std=np.sqrt(H.mlp_w2 / inner_dim / H.n_layer * 0.5)) m = residual_dropout(m, train, key=f'{scopename}-m', pdrop=pdrop) return bs.add(x, m) def add_position_embedding(x, x_emb, train, step): num_e = H.emb_number emb_std = H.pos_embd_std * np.sqrt(1.0 / num_e) for idx in range(H.emb_number): vsize = H.emb_vocabs[idx] name = f"pos_emb_{idx}" we = tf.get_variable( name, [vsize, H.n_embd], dtype=H.dtype, initializer=random_or_zeros_init(stddev=emb_std)) e = bs.embedding_lookup(we, x_emb[:, idx, :]) e = embedding_dropout(e, train) x += e return x def stack(X, X_emb, train, step=None, cache=None): with tf.name_scope('input_processing'): we = tf.get_variable( "we", [H.n_vocab, H.n_embd], dtype=H.dtype, initializer=random_or_zeros_init(stddev=H.w_embd_std)) h = bs.embedding_lookup(we, X) H.we = we H.we_x = h h = embedding_dropout(h, train) h = add_position_embedding(h, X_emb, train, step=step) if step is None: h = tf.reshape(h, [H.n_batch, H.attn_ctx, H.n_embd]) else: h = tf.reshape(h, [H.sample_batch, -1, H.n_embd]) with tf.variable_scope('sos_token'): if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation: y_gen_idx = 0 sos_tok = 0 for typ in H.self_gen_types: if not typ.is_used: if mpi_rank() == 0: print(f" [self-gen] not using {typ.description}") continue if mpi_rank() == 0: print(f" [self-gen] using {typ.description}") this_sos_var = tf.get_variable( typ.sos_name, [typ.num_tokens, H.n_embd], dtype=H.dtype, initializer=random_or_zeros_init(stddev=H.w_embd_std)) this_sos_tok = bs.embedding_lookup(this_sos_var, H.Y_gen_ph[:, y_gen_idx:y_gen_idx + 1]) assert this_sos_tok.shape[1:] == (1, H.n_embd) sos_tok += this_sos_tok y_gen_idx += 1 assert y_gen_idx == H.num_self_gen_in_use else: sos = tf.get_variable( 'sos', [1, 1, H.n_embd], dtype=H.dtype, initializer=random_or_zeros_init(stddev=H.w_embd_std)) batch_size = H.n_batch if step is None else H.sample_batch sos_tok = tf.ones(shape=[batch_size, 1, H.n_embd], dtype=H.dtype) * sos if step is None: h = tf.concat([sos_tok, h[:, :-1, :]], axis=1) if H.randomly_determined_order_use_lookahead: print("lookahead_embd") with tf.variable_scope("lookahead_embedding"): h = add_position_embedding(h, X_emb, train, step=step) else: h = tf.concat([sos_tok, h], axis=1)[:, -1:, :] new_cache = [] modes = H.attention_layers.split(',') assert H.n_layer % len(modes) == 0 for layer_idx in range(H.n_layer): mode = modes[layer_idx % len(modes)] name = f'h{layer_idx}' if cache is not None: # We only cache the pre qkv tensor, as it takes up # too much memory otherwise on long sequences. h = tf.concat([cache[layer_idx], h], axis=1) new_cache.append(h) use_cache = True else: use_cache = False with tf.variable_scope(name): recompute = H.recompute and train if H.float16 and H.blocksparse_op and not use_cache: h = sparse_attention(h, H.n_head, mode, use_cache=use_cache, train=train, recompute=recompute) else: h = dense_attention(h, H.n_head, mode, use_cache=use_cache, train=train, recompute=recompute) if cache is not None: return h, new_cache return h def get_logits(name, h, n_out, train=False): n, t, nx = shape_list(h) w_std = np.sqrt(H.logits_w / nx) with tf.variable_scope(name): w = tf.get_variable( "logits_proj", [nx, n_out], dtype=H.dtype, initializer=random_or_zeros_init(stddev=w_std)) w = embedding_dropout(w, train) h = tf.reshape(h, [-1, nx]) logits = tf.matmul(h, w) return tf.reshape(logits, [n, t, n_out]) def get_losses(logits, labels, mask=None): with tf.name_scope('loss'): n, t, nx = shape_list(logits) ln, lt = shape_list(labels) assert lt == t labels = tf.reshape(labels, [-1]) logits = tf.reshape(logits, [-1, nx]) if H.float16 and logits.shape[-1].value <= 65536 and logits.dtype == tf.float16: # much faster fused fp16 implementation that also saves memory losses = bs.softmax_cross_entropy(logits=logits, labels=labels) else: logits = tf.cast(logits, tf.float32) losses = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels) losses = tf.reshape(losses, [n, t]) if mask is not None: # X_mask can be either boolean or scalar (weighted) mask return (tf.reduce_sum(losses * mask) / tf.reduce_sum(mask)), losses return tf.reduce_mean(losses), losses def model(train=False): with tf.variable_scope('model', custom_getter=f32_storage_getter): network_input = H.X_ph network_target = H.X_ph if H.rand_augment and train: assert network_input.shape[-1] == 3072, 'TODO: support other image sizes' network_input = tf.reshape(tf.cast(network_input, tf.uint8), [-1, 32, 32, 3]) if H.rand_augment_conditioning: if H.use_unconditional_augmentation: raise NotImplementedError rand_augment_idx = [t.sos_name for t in H.self_gen_types if t.is_used].index('sos_aa') batch = [] with tf.device('/cpu:0'): for i in range(H.n_batch): example = network_input[i] with_randaug = distort_image_with_randaugment(example, H.rand_augment_n, H.rand_augment_m) without_randaug = example should_autoaugment = tf.cast(H.Y_gen_ph[i, rand_augment_idx], tf.bool) example = tf.cond(should_autoaugment, lambda: with_randaug, lambda: without_randaug) batch.append(example) network_input = batch else: with tf.device('/cpu:0'): network_input = [distort_image_with_randaugment(network_input[i], H.rand_augment_n, H.rand_augment_m) for i in range(H.n_batch)] network_input = tf.cast(tf.reshape(tf.concat(network_input, axis=0), [-1, 3072]), H.X_ph.dtype) network_target = network_input h = stack(network_input, H.X_emb_ph, train=train) h = norm('final_norm', h, epsilon=1e-6) targets = network_target gen_logits = get_logits('gen_logits', h, H.n_vocab, train=train) gen_loss, gen_losses = get_losses(gen_logits, targets) return gen_loss, gen_losses def sample_model(): X = tf.zeros(shape=[H.sample_batch, 0], dtype=tf.int32) current_step = tf.constant(0, dtype=tf.int64) accumulated_output = X[:, :current_step] # Everything up til now. current_input = X[:, current_step - 1:current_step] cache_vars = [tf.zeros(shape=[H.sample_batch, 0, H.n_embd], dtype=H.dtype) for _ in range(H.n_layer)] cacheshapes = [tf.TensorShape([H.sample_batch, None, H.n_embd]) for _ in range(H.n_layer)] embd_index = tf.constant([0] * H.sample_batch, dtype=tf.int32) first_embd = tf.zeros(shape=[H.sample_batch, H.emb_number, 0], dtype=tf.int32) loop_vars = [current_step, accumulated_output, current_input, first_embd, embd_index, cache_vars] shape_invariants = [current_step.get_shape(), tf.TensorShape([H.sample_batch, None]), tf.TensorShape([H.sample_batch, None]), tf.TensorShape([H.sample_batch, H.emb_number, None]), embd_index.get_shape(), cacheshapes] embd_shapes = tf.constant(H.emb_vocabs, dtype=tf.int32) def cond(step, acc, curr, curr_embd, embd_index, cache): return step < H.attn_ctx def body(step, acc, curr, curr_embd, embd_index, cache): with tf.variable_scope('model', custom_getter=f32_storage_getter): h, cache = stack(curr, curr_embd, train=False, step=step, cache=cache) h = norm('final_norm', h, epsilon=1e-6) h = h[:, -1:, :] logits = tf.cast(get_logits('gen_logits', h, H.n_vocab), tf.float32) logits = tf.reshape(logits, [H.sample_batch, H.n_vocab]) temp = H.temperature symbol = tf.cast(tf.multinomial(logits / temp, 1), tf.int32) with tf.device('/cpu:0'): next_embd = tf.unravel_index(embd_index, embd_shapes) # unravel_index yields a embd_size, n_batch tensor next_embd = tf.transpose(next_embd, [1, 0]) next_embd = tf.reshape(next_embd, [ H.sample_batch, H.emb_number, 1]) next_index = embd_index + 1 return (step + 1, tf.concat([acc, symbol], axis=1), symbol, next_embd, next_index, cache) _, output_seq, _, _, _, _ = tf.while_loop( cond=cond, body=body, loop_vars=loop_vars, back_prop=False, shape_invariants=shape_invariants, parallel_iterations=1) # Now, we want to gather the images across all ranks which have generated # them. We will just allreduce a sparse tensor. all_samples = [tf.zeros_like(output_seq) for _ in range(mpi_size())] all_samples[mpi_rank()] = output_seq all_samples = tf.cast(allreduce(tf.cast( tf.concat(all_samples, axis=0), tf.float32)), tf.int32) return all_samples def warmup_cosine(current_iter): current_iter = tf.cast(current_iter, tf.float32) + 1.0 warmup_iters = tf.cast(H.warmup_iters, tf.float32) s = tf.cast(tf.less(current_iter, warmup_iters), tf.float32) current_fraction = ((current_iter - warmup_iters) / (H.n_updates_total - warmup_iters)) return (s * (current_iter / warmup_iters) + (1 - s) * (0.5 * (1 + tf.cos(math.pi * current_fraction)))) def warmup_linear_decay(current_iter): current_iter = tf.cast(current_iter, tf.float32) + 1.0 warmup_iters = tf.cast(H.warmup_iters, tf.float32) s = tf.cast(tf.less(current_iter, warmup_iters), tf.float32) current_fraction = tf.minimum( ((current_iter - warmup_iters) / (H.n_updates_total - warmup_iters)), tf.cast(1, tf.float32)) return (s * (current_iter / warmup_iters) + (1 - s) * (1.0 - current_fraction)) def mpi_train(): with tf.device('/cpu:0'), tf.name_scope('optimizer'): if H.decay_lr_linearly: lr_at_time = H.lr * warmup_linear_decay(H.global_step - H.lr_offset) else: lr_at_time = H.lr * warmup_cosine(H.global_step - H.lr_offset) rcp_mpi_size = tf.constant(1.0 / mpi_size()) grad_scale = tf.reciprocal(H.curr_loss_scale) with tf.device("/gpu:0"): avg_loss_gen, _ = model(train=True) H.train_gen_loss = avg_loss_gen # n_updates_per_epoch H.global_step loss_to_optimize = avg_loss_gen params = tf.trainable_variables() grads = bs.gradients(bs.scale_tensor(loss_to_optimize, H.curr_loss_scale), params) if H.merge_layer_allreduce > 0: search_strings = list() stride = H.merge_layer_allreduce for l in range(H.n_layer - 1, -1, -stride): search_strings.append([f"model/h{j}" for j in range(l, l - stride, -1)]) else: logprint('Not interleaving allreduce with backprop! Is slow.') search_strings = None if mpi_size() > 1: H.train_gen_loss = allreduce(bs.scale_tensor(avg_loss_gen, rcp_mpi_size)) # Pre-scale the gradients to give all-reduce some room. # After gradients are computed on this device scaling here can be rather aggressive. # But 1/mpi_size should be enough. grads = [bs.filter_tensor(x, rcp_mpi_size) for x in grads] cast_all = tf.float16 if H.fp16_allreduce else None grads = group_allreduce(grads, params, search_strings=search_strings, cast_all=cast_all) serialize_allreduce_ops([H.train_gen_loss] + grads) if H.log_grad_stats and mpi_rank() == 0: grads = log_gradient_values(grads, params, H.global_step, model_dir=H.model_dir) train_op, global_norm = get_optimizer(H.optimizer)( grads, params, learning_rate=lr_at_time, grad_scale=grad_scale, fp16_mean_var=H.fp16_mean_var, max_grad_norm=H.max_grad_norm, static_loss_scaling=H.float16 and not H.dynamic_loss_scaling, beta2=H.beta2) if H.l2_loss > 0: # AdamW logprint('enabling l2 loss with value', H.l2_loss) updates = [train_op] l2_updates = [] for p in params: if len(shape_list(p)) > 1: l2_updates.append(p.assign(p - lr_at_time * H.l2_loss * p)) updates.extend(l2_updates) train_op = tf.group(*updates) if not H.disable_ema_vars: # Polyak average of params. Stores an extra copy. # NOTE: this assignment is stateful -- graphs created after this will use the EMA var, see # the variable getter, so the order of mpi_train and eval model creation cannot be swapped. # TODO: remove this constraint H.ema = bs.Ema(decay=H.weights_beta) with tf.control_dependencies([train_op]): train_op = H.ema.apply(params) return train_op, lr_at_time, global_norm def eval(test=False, epoch=None): if test: tx = dataset.teX else: tx = dataset.vaX losses = [] for data in iter_data_mpi(tx, n_batch=H.n_batch, log=logprint, split_by_rank=dataset.full_dataset_valid): feeds = {H.X_ph: data[0], H.X_emb_ph: H.x_emb} if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation: feeds[H.Y_gen_ph] = np.zeros((data[0].shape[0], H.num_self_gen_in_use), dtype=np.int32) losses.append(sess.run(H.avg_eval_loss_gen, feeds)) avg_loss = sum(losses) / len(losses) content = dict(epoch=epoch, series='eval_loss', loss=avg_loss, bits=avg_loss / np.log(2.)) logprint(**content) mpi_barrier() return avg_loss def get_data(partition): return { "train": (dataset.trX, dataset.trY), "valid": (dataset.vaX, dataset.vaY), "test": (dataset.teX, dataset.teY), }[partition] def aug_eval(partition, epoch): tx, ty = get_data(partition) if H.aug_eval_n_examples is not None: tx = tx[:H.aug_eval_n_examples] if ty is not None: ty = ty[:H.aug_eval_n_examples] gen_in_use = [gen for gen in H.self_gen_types if gen.is_used] if not gen_in_use: gen_in_use = [AugmentationType("sos", "identity", 1, True, identity)] aug_choices = [gen.num_tokens for gen in gen_in_use] for aug_types in go_over(aug_choices): fname = os.path.join( H.model_dir, f"{H.desc}_" + "_".join(map(str, aug_types)) + "_losses.npz") if os.path.exists(fname): if mpi_rank() == 0: print(f" Evaluated {fname}") continue if mpi_rank() == 0: print(f"Evaluating {fname}") losses = [] imgs = [] for data in iter_data_mpi(tx, n_batch=H.n_batch, log=logprint, split_by_rank=dataset.full_dataset_valid): feeds = {H.X_ph: data[0], H.X_emb_ph: H.x_emb} x_emb = np.concatenate([H.x_emb.copy() for _ in range(H.n_batch)], axis=0) d_in = data[0] if H.num_self_gen_in_use > 0: y_gen_list = [] for aug_type, gen in zip(aug_types, gen_in_use): if gen.sos_name == 'sos_data': raise NotImplementedError("sos_data is not supported in aug_eval") yy = np.full((H.n_batch, 1), aug_type, dtype=np.int32) d_in, x_emb, y_gen = gen.fn(d_in, x_emb, yy=yy) assert (y_gen == yy).all() y_gen_list.append(y_gen) feeds[H.X_ph] = d_in if H.permute_embeddings: feeds[H.X_emb_ph] = x_emb if not H.use_unconditional_augmentation: feeds[H.Y_gen_ph] = np.concatenate(y_gen_list, axis=1) assert (feeds[H.Y_gen_ph] == np.stack([aug_types] * H.n_batch)).all() imgs.append(d_in) cur_loss = sess.run(H.eval_gen_losses, feeds) assert cur_loss.shape[0] == H.n_batch losses.append(cur_loss) losses = np.concatenate(losses, axis=0).astype(np.float32) assert losses.shape[0] == tx.shape[0] // mpi_size() mpi_barrier() losses = mpi_allgather(losses) assert losses.shape[0] == tx.shape[0] loss = losses.mean() content = dict(epoch=epoch, aug_types=aug_types, loss=loss, bpd=loss / np.log(2.0)) logprint(**content) content["losses"] = losses if mpi_rank() == 0: np.savez(fname, **content) imgs = np.concatenate(imgs, axis=0) assert imgs.shape[0] == tx.shape[0] // mpi_size() mpi_barrier() imgs = mpi_allgather(imgs) assert imgs.shape == tx.shape if mpi_rank() == 0 and partition != "test": fname = os.path.join(H.model_dir, f"{H.desc}_" + "_".join(map(str, aug_types)) + "_imgs.npz") np.savez(fname, imgs=imgs.reshape(dataset.orig_shape)) mpi_barrier() def sample(name): sample_batches = [] assert H.samples_to_generate % (H.sample_batch * mpi_size()) == 0 for idx in range(H.samples_to_generate // (H.sample_batch * mpi_size())): feeds = {} if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation: feeds[H.Y_gen_ph] = np.zeros((H.sample_batch, H.num_self_gen_in_use), dtype=np.int32) samples = sess.run(sample_output, feeds) sample_batches.append(samples) logprint(f'generated {sum([a.shape[0] for a in sample_batches])} / {H.samples_to_generate} samples') if idx == 0 and H.samples_memorycheck: mem = sess.run(tf.contrib.memory_stats.MaxBytesInUse()) logprint('Runtime memory usage so far (bytes):', f'{mem:,}') logprint(memory_usage=mem) if mpi_rank() == 0: samples = np.concatenate(sample_batches, axis=0) nppath = os.path.join(H.model_dir, f'{H.desc}-samples-{H.samples_to_generate}-t{H.temperature}.npy') np.save(nppath, samples) def sample_augmentation_type(n, size=None, nprng=None): """ Sample one of `n` augmentation types. Index 0 is reserved for not augmenting. """ if nprng is None: nprng = np.random if H.unaugmented_data_rate is None: y = nprng.randint(n, size=size) else: # We draw multiple augmentation types independently, so the probability # of not using augmentation has to be discounted accordingly. n_types = max(H.num_self_gen_in_use, 1) p = H.unaugmented_data_rate ** (1.0 / n_types) pmf = [p] + [(1.0 - p) / (n - 1)] * (n - 1) y = nprng.choice(n, size=size, p=pmf) return y.astype(np.int32) def data_aug(xx, nprng=None, yy=None): """just hflip""" if nprng is None: nprng = np.random xx = xx.reshape(dataset.orig_shape) if yy is None: yy = sample_augmentation_type(2, size=xx.shape[0], nprng=nprng) assert yy.shape[0] == xx.shape[0] # n = len(xx) # xx = np.pad(xx, [[0, 0], [4, 4], [4, 4], [0, 0]], mode='reflect') xx = [np.fliplr(x) if y else x for x, y in zip(xx, yy)] # ii = nprng.randint(low=0, high=4 * 2 + 1, size=n) # jj = nprng.randint(low=0, high=4 * 2 + 1, size=n) # xx = [x[i:i + 32, j:j + 32] for x, i, j in zip(xx, ii, jj)] xx = np.asarray(xx).reshape(dataset.shape) return xx def identity(xx, x_emb, nprng=None, yy=None): return xx, x_emb, yy def rotate(xx, x_emb, nprng=None, yy=None): b = xx.shape[0] b_emb, n_emb, n_ctx = x_emb.shape assert b == b_emb assert n_ctx == np.prod(dataset.orig_shape[1:]) if yy is None: yy = sample_augmentation_type(4, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] xx = xx.reshape(dataset.orig_shape) xx = [np.rot90(x, k=yy[i, 0], axes=(1, 0)) for i, x in enumerate(xx)] xx = np.asarray(xx).reshape(dataset.shape) x_emb = x_emb.reshape((b_emb, n_emb, *dataset.orig_shape[1:])) x_emb = [np.rot90(x, k=yy[i, 0], axes=(2, 1)) for i, x in enumerate(x_emb)] x_emb = np.asarray(x_emb).reshape((b_emb, n_emb, n_ctx)) return xx, x_emb, yy def transpose(xx, x_emb, nprng=None, yy=None): b = xx.shape[0] b_emb, n_emb, n_ctx = x_emb.shape assert b == b_emb assert n_ctx == np.prod(dataset.orig_shape[1:]) if yy is None: yy = sample_augmentation_type(2, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] xx = xx.reshape(dataset.orig_shape) xx = [np.transpose(x, [1, 0, 2]) if yy[i, 0] == 1 else x for i, x in enumerate(xx)] xx = np.asarray(xx).reshape(dataset.shape) x_emb = x_emb.reshape((b_emb, n_emb, *dataset.orig_shape[1:])) x_emb = [np.transpose(x, [0, 2, 1, 3]) if yy[i, 0] == 1 else x for i, x in enumerate(x_emb)] x_emb = np.asarray(x_emb).reshape((b_emb, n_emb, n_ctx)) return xx, x_emb, yy def reverse(xx, x_emb, nprng=None, yy=None): b = xx.shape[0] b_emb, n_emb, n_ctx = x_emb.shape assert b == b_emb assert n_ctx == np.prod(dataset.orig_shape[1:]) if yy is None: yy = sample_augmentation_type(2, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] xx = xx.reshape(dataset.orig_shape) xx = [np.rot90(x, k=yy[i, 0] * 2, axes=(1, 0)) for i, x in enumerate(xx)] xx = np.asarray(xx).reshape(dataset.shape) x_emb = x_emb.reshape((b_emb, n_emb, *dataset.orig_shape[1:])) x_emb = [np.rot90(x, k=yy[i, 0] * 2, axes=(2, 1)) for i, x in enumerate(x_emb)] x_emb = np.asarray(x_emb).reshape((b_emb, n_emb, n_ctx)) return xx, x_emb, yy def autoaugment_conditioning(rate): def fn(xx, x_emb, nprng=None, yy=None): if nprng is None: nprng = np.random b = xx.shape[0] # 1 when augment is applied if yy is None: yy = (nprng.uniform(size=(b, 1)) < rate).astype(np.int32) assert yy.shape[0] == xx.shape[0] return xx, x_emb, yy return fn def permute_arbitrarily(random_perms): perms = [np.arange(dataset.ctx)] + random_perms n = len(perms) def fn(xx, x_emb, nprng=None, yy=None): b, n_ctx = xx.shape b_emb, n_emb, n_emb_ctx = x_emb.shape assert b == b_emb assert n_ctx == n_emb_ctx if yy is None: yy = sample_augmentation_type(n, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] xx_new = [] x_emb_new = [] for i, y in enumerate(yy): xx_new.append(xx[i][perms[y[0]]]) x_emb_new.append(x_emb[i][:, perms[y[0]]]) xx = np.concatenate(xx_new, axis=0).reshape(dataset.shape) x_emb = np.concatenate(x_emb_new, axis=0) x_emb = x_emb.reshape(b_emb, n_emb, n_emb_ctx) return xx, x_emb, yy return fn def remap_c(xx, order): new = np.zeros_like(xx) a, b, c = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0) ][order] new[:, :, 0] = xx[:, :, a] new[:, :, 1] = xx[:, :, b] new[:, :, 2] = xx[:, :, c] return new def color_swap(xx, x_emb, nprng=None, yy=None): b = xx.shape[0] b_emb, n_emb, n_ctx = x_emb.shape assert b == b_emb assert n_ctx == np.prod(dataset.orig_shape[1:]) if yy is None: yy = sample_augmentation_type(6, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] xx = xx.reshape(dataset.orig_shape) x_emb = x_emb.reshape((b, n_emb * dataset.orig_shape[1], *dataset.orig_shape[2:])) xx_new = [] x_emb_new = [] for i, order in enumerate(yy): xx_new.append(remap_c(xx[i], order[0])) x_emb_new.append(remap_c(x_emb[i], order[0])) xx = np.concatenate(xx_new, axis=0).reshape(dataset.shape) x_emb = np.concatenate(x_emb_new, axis=0).reshape((b_emb, n_emb, n_ctx)) return xx, x_emb, yy def remap_jigsaw(x, order): r, c, ch = x.shape g = H.jigsaw_grid_size gr, gc = r // g, c // g x = x.reshape((g, gr, g, gc, ch)) x = np.transpose(x, [0, 2, 1, 3, 4]) x = x.reshape([g * g, gr, gc, ch]) perm = H.jigsaw_perms[order] x = x[perm, :, :, :] x = x.reshape([g, g, gr, gc, ch]) x = np.transpose(x, [0, 2, 1, 3, 4]) x = x.reshape((r, c, ch)) return x def jigsaw(xx, x_emb, nprng=None, yy=None): b = xx.shape[0] b_emb, n_emb, n_ctx = x_emb.shape r, c, ch = dataset.orig_shape[1:] assert b == b_emb assert n_ctx == np.prod(dataset.orig_shape[1:]) xx = xx.reshape(dataset.orig_shape) if yy is None: yy = sample_augmentation_type(H.jigsaw_num_perms, size=(b, 1), nprng=nprng) assert yy.shape[0] == xx.shape[0] x_emb = x_emb.reshape(b, n_emb, r, c, ch) x_emb = np.transpose(x_emb, [0, 2, 1, 3, 4]) x_emb = x_emb.reshape((b, n_emb * r, c, ch)) xx_new = [] x_emb_new = [] for i, order in enumerate(yy): xx_new.append(remap_jigsaw(xx[i], order[0])) x_emb_new.append(remap_jigsaw(x_emb[i], order[0])) xx = np.concatenate(xx_new, axis=0).reshape(dataset.shape) x_emb = np.concatenate(x_emb_new, axis=0) x_emb = x_emb.reshape(b, r, n_emb, c, ch) x_emb = np.transpose(x_emb, [0, 2, 1, 3, 4]) x_emb = x_emb.reshape((b_emb, n_emb, n_ctx)) return xx, x_emb, yy if __name__ == '__main__': if 'CUDA_VISIBLE_DEVICES' not in os.environ: os.environ['CUDA_VISIBLE_DEVICES'] = str(local_mpi_rank()) parser = argparse.ArgumentParser() add_arguments(parser) parse_args_and_update_hparams(H, parser) H.model_dir = os.path.join(H.out_dir, H.desc) os.makedirs(H.model_dir, exist_ok=True) H.log_path = os.path.join(H.model_dir, 'log') logprint = logger(H.log_path) logprint(hyperparams=H, pprint=True) # Same numpy seed so we can shuffle the data across ranks similarly np.random.seed(H.seed) # Different seed for TF to randomize model sampling/dropout across ranks tf.set_random_seed(H.seed * mpi_rank()) # Augmentation nprng aug_nprng = np.random.RandomState(H.aug_seed + mpi_rank()) # Cache for objects/tensors that should persist through recompute, eval, and/or samples H.bst_cache = dict() H.dropout_cache = dict() H.dense_mask_cache = dict() H.var_cache = dict() H.ema = None H.reduced_targets = H.kmeans_targets or H.mse_targets H.attn_ctx = H.n_ctx H.dtype = tf.float16 if H.float16 else tf.float32 bs.set_entropy() # for bs.dropout if mpi_size() == 1: logprint("WARNING: Only one MPI rank, did you forget to run w/ MPI?") dataset = get_dataset(H.dataset)(H, logprint) if H.auxiliary_dataset is not None: if mpi_rank() == 0: logprint("") pmf = [1.0 - H.auxiliary_dataset_fraction, H.auxiliary_dataset_fraction] aux_dataset = get_dataset(H.auxiliary_dataset)(H, logprint) if H.auxiliary_dataset_subset_size is not None: n_train = H.auxiliary_dataset_subset_size aux_dataset.trX = aux_dataset.trX[:n_train] if mpi_rank() == 0: logprint(f"taking a subset of auxiliary dataset {len(aux_dataset.trX)}") aux_dataset.iters_per_epoch = n_train // (mpi_size() * aux_dataset.n_batch) datasets = (dataset, aux_dataset) dataset = JankySubsampledDataset(datasets, pmf, seed=H.auxiliary_dataset_seed) H.emb_number = dataset.num_embeddings H.emb_vocabs = dataset.embedding_sizes H.n_classes = dataset.n_classes H.X_emb_shape = [None] + [H.emb_number] + dataset.shape[1:] H.x_emb = dataset.x_emb # Making n_vocab the nearest multiple of 128 allows the usage of # tensor cores with fp16 on V100's, which speeds up large vocab problems if H.no_vocab_rounding: H.n_vocab = dataset.n_vocab else: H.n_vocab = (math.ceil(dataset.n_vocab / 128)) * 128 H.X_shape = [None] + dataset.shape[1:] with tf.device("/gpu:0"), tf.name_scope('placeholders'): H.X_ph = tf.placeholder(tf.int32, H.X_shape) H.X_emb_ph = tf.placeholder(tf.int32, H.X_emb_shape) H.jigsaw_perms = list(itertools.permutations(list(range(H.jigsaw_grid_size ** 2)))) H.jigsaw_num_perms = len(H.jigsaw_perms) nprng = np.random.RandomState(H.randomly_determined_order_seed) random_perms = [ nprng.permutation(dataset.ctx) for _ in range(H.randomly_determined_order_num_perms) ] H.self_gen_types = [ AugmentationType("sos_rot", "rotation", 4, H.use_rotation, rotate), AugmentationType("sos_c", "color swapping", 6, H.use_color, color_swap), AugmentationType("sos_tr", "transposition", 2, H.use_transposition, transpose), AugmentationType("sos_rev", "reverse", 2, H.use_reverse, reverse), AugmentationType("sos_js", f"jigsaw with grid size {H.jigsaw_grid_size}", H.jigsaw_num_perms, H.use_jigsaw, jigsaw), AugmentationType("sos_aa", "autoaugment", 2, H.rand_augment_conditioning, autoaugment_conditioning(H.rand_augment_rate)), AugmentationType("sos_rd", "randomly determined order", H.randomly_determined_order_num_perms + 1, H.use_randomly_determined_order, permute_arbitrarily(random_perms)), AugmentationType("sos_data", "dataset", 2, H.use_dataset_conditioning, None), ] H.num_self_gen_in_use = sum(typ.is_used for typ in H.self_gen_types) if mpi_rank() == 0: for typ in H.self_gen_types: if typ.is_used: logprint(f"Using [{typ.description}]") else: logprint(f"Not using [{typ.description}]") if H.use_unconditional_augmentation: logprint(f"Training without augmentation prompting") else: logprint(f"Training with augmentation prompting") if H.permute_embeddings: logprint("Permuting embeddings") else: logprint("Not permuting embeddings") if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation: H.Y_gen_ph = tf.placeholder(tf.int32, [None, H.num_self_gen_in_use]) with tf.device("/cpu:0"): loss_scale_ph = tf.placeholder( tf.float32, shape=[], name="loss_scale") H.global_step = tf.get_variable( 'global_step', initializer=zeros_init(), shape=tuple(), trainable=False, dtype=tf.int64) num_epochs = tf.get_variable( 'num_epochs', initializer=zeros_init(), shape=tuple(), trainable=False, dtype=tf.int64) num_examples_processed = tf.get_variable( 'num_examples_processed', initializer=zeros_init(), shape=tuple(), trainable=False, dtype=tf.int64) curr_loss_scale = tf.get_variable( 'curr_loss_scale', initializer=constant_or_zeros_init(H.fp16_loss_scale), shape=tuple(), trainable=False, dtype=tf.float32) H.curr_loss_scale = curr_loss_scale best_val_loss = tf.get_variable( 'best_val_loss', initializer=constant_or_zeros_init(99999), shape=tuple(), trainable=False, dtype=tf.float32) val_loss = tf.placeholder(tf.float32, shape=[], name="val_loss") update_val_loss = tf.assign(best_val_loss, val_loss) update_loss_scale = tf.assign(curr_loss_scale, loss_scale_ph) increment_epochs = tf.assign_add(num_epochs, 1) increment_examples = tf.assign_add(num_examples_processed, H.n_batch * mpi_size()) increment_step = tf.assign_add(H.global_step, 1) n_updates_per_epoch = dataset.iters_per_epoch n_updates_total = H.total_epochs * n_updates_per_epoch H.n_updates_total = n_updates_total train_op, lr_at_step, global_norm = mpi_train() num_params = 0 for p in tf.trainable_variables(): num_params += np.prod(p.shape.as_list()) if H.print_params: logprint(f'{p.name}, {p.shape.as_list()}, {np.prod(p.shape.as_list()):,}') with tf.name_scope('eval_model'), tf.device('/gpu:0'): avg_eval_loss_gen, eval_gen_losses = model(train=False) H.eval_gen_losses = eval_gen_losses H.avg_eval_loss_gen = allreduce(avg_eval_loss_gen) * (1.0 / mpi_size()) if H.sample_and_exit or H.sample_during_eval: logprint('Creating sampling graph.') with tf.name_scope('sample_model'), tf.device('/gpu:0'): sample_output = sample_model() logprint('Done with sampling graph creation.') sess = get_session(mpi=True, disable_swapping=True, log=logprint) sess.run(tf.global_variables_initializer()) logprint(f'Total number trainable parameters: {num_params:,}') logprint(num_params=num_params, n_vocab=H.n_vocab, n_batch=H.n_batch, n_ctx=H.n_ctx, effective_minibatch=mpi_size() * H.n_batch, n_updates_total=n_updates_total, n_updates_per_epoch=n_updates_per_epoch, pprint=True) if H.restore_path: if mpi_rank() == 0: localpath = maybe_download(H.restore_path) logprint("loading from " + localpath) load_variables_from_file(sess, localpath, ema=False) logprint("Done loading from " + localpath) with tf.name_scope('sync_variables'): if mpi_size() > 1: logprint('Syncing initial variables across gpus') sync_variables(sess) logprint('Finishing syncing variables') ema_loss = None steps_since_starting = 0 save_dir = os.path.join(H.out_dir, H.desc) os.makedirs(save_dir, exist_ok=True) n_updates, n_epochs, curr_val_loss, loss_scale_t, examples_processed_t = sess.run([ H.global_step, num_epochs, best_val_loss, curr_loss_scale, num_examples_processed]) logprint(f"Starting at {n_updates} updates, {n_epochs} epochs, " + f"{curr_val_loss} best val loss, examples {examples_processed_t}") if H.sample_and_exit or H.sample_during_eval: sample('onload') if H.sample_and_exit: sys.exit(0) if H.eval_test or not H.skip_initial_evals or H.eval_and_exit: eval(test=H.eval_test, epoch=n_epochs) if H.eval_test or H.eval_and_exit: sys.exit(0) if H.aug_eval is not None: aug_eval(partition=H.aug_eval, epoch=n_epochs) sys.exit(0) # Free up some python memory H.bst_cache = None H.dropout_cache = None H.dense_mask_cache = None H.var_cache = None bs.clear_bst_constants() avg_t = 9999.0 loss_count = 0 if H.eval_after_n_examples: last_set_processed = examples_processed_t // H.eval_after_n_examples loss_scale_t = H.fp16_loss_scale times = [] losses = [] gns = [] for i in range(n_epochs, H.total_epochs): t0 = time.time() args = [dataset.trX] if H.use_dataset_conditioning: args.append(dataset.auxX) for data in iter_data_mpi(*args, n_batch=H.n_batch, log=logprint, iters=n_updates_per_epoch, shuffle=True, seed=i, split_by_rank=dataset.full_dataset_train): outputs = [train_op, H.train_gen_loss, lr_at_step, global_norm] d_in = data_aug(data[0], nprng=aug_nprng) if H.aug else data[0] feeds = {H.X_ph: d_in, H.X_emb_ph: H.x_emb} if H.num_self_gen_in_use > 0: y_gen_list = [] x_emb = np.concatenate([H.x_emb.copy() for _ in range(H.n_batch)], axis=0) d_gen = d_in.copy() for gen in H.self_gen_types: if not gen.is_used: continue if gen.fn is None and gen.sos_name == 'sos_data': y_gen = data[-1] else: d_gen, x_emb, y_gen = gen.fn(d_gen, x_emb, nprng=aug_nprng) assert d_gen.shape == d_in.shape assert y_gen.shape == (d_in.shape[0], 1) y_gen_list.append(y_gen) feeds[H.X_ph] = d_gen if H.permute_embeddings: feeds[H.X_emb_ph] = x_emb if not H.use_unconditional_augmentation: feeds[H.Y_gen_ph] = np.concatenate(y_gen_list, axis=1) is_rank0 = mpi_rank() == 0 if steps_since_starting == 2 or steps_since_starting == 65: mem = sess.run(tf.contrib.memory_stats.MaxBytesInUse()) logprint('Runtime memory usage so far (bytes):', f'{mem:,}') logprint(memory_usage=mem) t1 = time.time() _, loss_t, lr_t, gn_t = sess.run(outputs, feeds) t2 = time.time() if H.dynamic_loss_scaling and H.float16: # slowly increase loss scale but quickly drop it when inf or nan is detected in the gradients # global_norm will be nan/inf when this happens if np.isfinite(loss_t) and np.isfinite(gn_t): # Case: No infs or nans, roughly double the loss scale every 2k iters loss_scale_t = sess.run(update_loss_scale, {loss_scale_ph: loss_scale_t * 1.0003466337}) elif not np.isfinite(loss_t): # Incurred some nans on the forward pass, don't do anything. pass else: # gn_t is nan/inf and loss_t is non-nan, meaning the grad scaling was too high # Reduce by half and move to the next minibatch if loss_scale_t > H.min_loss_scale: loss_scale_t = sess.run(update_loss_scale, {loss_scale_ph: loss_scale_t * 0.5}) step_t = sess.run(increment_step) examples_processed = sess.run(increment_examples) n_updates += 1 gns.append(gn_t) times.append(t2 - t0) losses.append(loss_t) steps_since_starting += 1 if (steps_since_starting in [2**n for n in range(9)] or n_updates % H.iters_per_log == 0): loss_to_avg = [x for x in losses if np.isfinite(x)] if len(loss_to_avg) > 0: avg_loss = sum(loss_to_avg) / len(loss_to_avg) else: avg_loss = None avg_t = sum(times) / len(times) gns_so_far = [x for x in gns if np.isfinite(x)] if len(gns_so_far) > 0: max_gn_so_far = max([x for x in gns if np.isfinite(x)]) else: max_gn_so_far = -1 logprint(step=step_t, lr=lr_t, loss=loss_t, loss_avg=avg_loss, t_iter=t2 - t1, t_iter_avg=avg_t, t_data=t1 - t0, gn=gn_t, nans=len(losses) - len(loss_to_avg), loss_scale="2^%.0f" % np.log2(loss_scale_t), max_gn=max_gn_so_far, series='train_loss', examples=examples_processed) times = [] losses = [] gns = [] t0 = time.time() if H.eval_after_n_examples: sets_processed = examples_processed // H.eval_after_n_examples if sets_processed > last_set_processed: vl = eval(epoch=sets_processed) if H.sample_during_eval: sample(f'epoch-{sets_processed}') if vl < curr_val_loss: curr_val_loss = vl sess.run(update_val_loss, {val_loss: vl}) logprint(f'Saving model with val loss of {vl} at epoch {sets_processed}') save_params(sess, os.path.join(save_dir, 'model_best')) save_params(sess, os.path.join(save_dir, 'model_latest')) n = 12 if sets_processed in [2**i for i in range(n)] + [2**(n - 1) + 2 ** i for i in range(n)]: save_params(sess, os.path.join(save_dir, f'model_epoch{sets_processed}')) last_set_processed = sets_processed n_epochs = sess.run(increment_epochs) if not H.eval_after_n_examples: if n_epochs % H.epochs_per_eval == 0: vl = eval(epoch=n_epochs) if H.sample_during_eval: sample(f'epoch-{n_epochs}') if vl < curr_val_loss: curr_val_loss = vl sess.run(update_val_loss, {val_loss: vl}) logprint(f'Saving model with val loss of {vl} at epoch {n_epochs}') save_params(sess, os.path.join(save_dir, 'model_best')) if n_epochs % H.epochs_per_save == 0 and n_epochs > 0: save_params(sess, os.path.join(save_dir, 'model_latest')) if n_epochs in [2**i for i in range(12)]: save_params(sess, os.path.join(save_dir, f'model_epoch{n_epochs}')) if H.exit_after_n_epochs: if n_epochs >= H.exit_after_n_epochs: time.sleep(20) logprint(f'Exiting now, epoch={n_epochs}') sys.exit(0) save_params(sess, os.path.join(save_dir, 'model_latest')) logprint('Finished training.')