in apex/apex/amp/wrap.py [0:0]
def rnn_cast(backend, fn, handle, verbose=False):
    """Wrap the RNN factory ``fn`` on ``backend`` so its forward gets fp16 args.

    The original factory is looked up with ``utils.get_func(backend, fn)``.
    A wrapper around it is then registered with
    ``utils.set_func_save(handle, backend, fn, rnn_wrapper)``. Presumably this
    swaps ``backend.<fn>`` and records the original on ``handle``; confirm
    against ``utils``.

    The registered wrapper does two things:

    * If a ``flat_weight`` keyword argument is supplied, it asserts that it
      is a ``'FloatTensor'``. It then replaces it with a newly allocated fp16
      tensor of the same shape before calling the original factory.
    * It wraps the ``forward`` callable returned by the factory. Inputs,
      per-layer weights and hidden state are passed through
      ``utils.maybe_half`` (adapted by ``utils.verbosify``). When a fp16
      flat buffer was created, the weights are instead built by
      ``utils.synthesize_flattened_rnn_weights``.

    Args:
        backend: object that owns the factory named ``fn``.
        fn: name of the factory to wrap. It is also the label handed to
            ``utils.verbosify`` and ``utils.synthesize_flattened_rnn_weights``.
        handle: passed unchanged to ``utils.set_func_save``.
        verbose: forwarded to ``utils.verbosify`` and
            ``utils.synthesize_flattened_rnn_weights``.
    """
    original_factory = utils.get_func(backend, fn)

    @functools.wraps(original_factory)
    def rnn_wrapper(*args, **kwargs):
        fp32_flat = kwargs.get('flat_weight')
        half_flat = None
        if fp32_flat is not None:
            # Hand the factory an uninitialized fp16 buffer instead of the
            # fp32 `flat_weight`. Per the original design note, the real
            # weight tensors passed to `forward` are later laid out as views
            # into this buffer, with their fp32 values copied in. That copy
            # keeps the autograd graph connected: weight grads computed by
            # cuDNN on the fp16 copies flow back into the fp32 weights.
            assert utils.type_string(fp32_flat) == 'FloatTensor'
            legacy_tensor_api = (compat.tensor_is_float_tensor()
                                 or compat.tensor_is_variable())
            if legacy_tensor_api:
                # Pre-0.4 path (per the original note), built with
                # new()/half()/resize_(). The original note also calls this
                # slightly slower.
                half_flat = fp32_flat.new().half().resize_(fp32_flat.shape)
            else:
                half_flat = torch.empty_like(fp32_flat,
                                             dtype=torch.float16)
            kwargs['flat_weight'] = half_flat

        inner_forward = original_factory(*args, **kwargs)

        @functools.wraps(inner_forward)
        def fwd_wrapper(*fargs, **fkwargs):
            # Expected positional layout: (inputs, weights, hiddens[, batch_sizes]).
            assert len(fargs) in (3, 4)
            inputs, weights, hiddens = fargs[:3]
            assert utils.is_fp_tensor(inputs)
            assert isinstance(weights, list)
            to_half = utils.verbosify(utils.maybe_half, fn, verbose)

            # 0) Inputs.
            half_inputs = to_half(inputs)

            # 1) Weights: build them from the shared fp16 flat buffer when one
            #    was created; otherwise cast each layer's tensors individually.
            if half_flat is None:
                half_weights = [list(map(to_half, layer)) for layer in weights]
            else:
                half_weights = utils.synthesize_flattened_rnn_weights(
                    weights, half_flat, fn, verbose)

            # 2) Hidden state: a tuple (e.g. LSTM (h, c)), a single fp tensor,
            #    or anything else (such as `None`), which passes through as-is.
            if isinstance(hiddens, tuple):
                half_hiddens = tuple(map(to_half, hiddens))
            elif utils.is_fp_tensor(hiddens):
                half_hiddens = to_half(hiddens)
            else:
                half_hiddens = hiddens

            call_args = [half_inputs, half_weights, half_hiddens]
            # 3) Batch sizes: only present as a 4th argument on 0.4 or later.
            if len(fargs) == 4:
                call_args.append(fargs[3])
            return inner_forward(*call_args, **fkwargs)

        return fwd_wrapper

    utils.set_func_save(handle, backend, fn, rnn_wrapper)