in apex/apex/amp/wrap.py [0:0]
def new_rnn_cast(fn, handle, verbose=False):
    # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
    # For rnn backend calls that route through _rnn_impls, we must patch the ref
    # that _rnn_impls stashed.  For rnn backend calls that directly invoke
    # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
    # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
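    # Net effect: the chosen backend entry point is replaced with a wrapper
    # that casts floating-point inputs, and the flattened RNN weights, to fp16
    # before dispatching to the original function.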
    if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
        mod = torch.nn.modules.rnn._rnn_impls
    else:
        mod = torch.nn.modules.rnn._VF
        assert isinstance(mod, rnn_compat.VariableFunctionsShim)
        # _VF entries use lowercase names (e.g. "lstm"), unlike the upper-case
        # backend names used as _rnn_impls keys.
        fn = fn.lower()
    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        # Exact call signature from modules/rnn.py
        assert len(args) == 9
        assert len(kwargs) == 0
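        # At the modules/rnn.py call sites this shim targets, the padded-input
        # form passes (input, hx, flat_weights, bias, num_layers, dropout,
        # training, bidirectional, batch_first), while the PackedSequence form
        # passes (input, batch_sizes, hx, flat_weights, bias, num_layers,
        # dropout, training, bidirectional).  args[6] is therefore a bool
        # (training) in the first case and a float (dropout) in the second,
        # which is what the check below keys on.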
        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)
        if isinstance(args[6], bool):
            params_idx = 2  # Not PackedSequence case
        else:
            params_idx = 3  # PackedSequence case
        new_args = []
        for i, arg in enumerate(args):
            if i == params_idx:
                # Cast the weight list by materializing every weight/bias in a
                # single flat fp16 buffer and handing that to the backend.
                num_params = sum([x.numel() for x in arg])
                fp16_weight_buf = args[0].new_empty((num_params,),
                                                    dtype=torch.half)
                casted_weights = utils.new_synthesize_flattened_rnn_weights(
                    arg, fp16_weight_buf, fn, verbose)
                new_args.append(casted_weights)
            elif utils.is_fp_tensor(arg):
                new_args.append(cast_fn(arg))
            else:
                new_args.append(arg)

        return orig_fn(*new_args)
    utils.set_func_save(handle, mod, fn, wrapper)
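
For context, here is a minimal sketch of how a caster like this might be installed. The helper name `install_rnn_casts`, the backend-name list, and the loop are assumptions for illustration (apex's actual registration happens inside `amp.init()`); only `new_rnn_cast`'s signature, `rnn_compat.VariableFunctionsShim`, and the patched `"_VF"` reference come from the listing above. `handle` stands for whatever AMP handle the caller already holds.

import torch
from apex.amp import rnn_compat, wrap

def install_rnn_casts(handle, verbose=False):
    # Hypothetical helper: swap the "_VF" reference in torch.nn.modules.rnn
    # for the patchable shim, then wrap each RNN backend entry point so its
    # inputs and weights are cast to fp16 on the way in.
    torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
    for backend_name in ['RNN_RELU', 'RNN_TANH', 'GRU', 'LSTM']:
        wrap.new_rnn_cast(backend_name, handle, verbose)

As the comment at the top of the function explains, the patch targets the exact references that modules/rnn.py calls through: either the functions stashed in `_rnn_impls`, or the module-local `"_VF"` name, which the shim replaces.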