def rnn_cast()

in apex/apex/amp/wrap.py

# Imports this excerpt relies on (package-relative, per the apex.amp layout):
import functools

import torch

from . import compat, utils

def rnn_cast(backend, fn, handle, verbose=False):
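    # Patch the backend RNN entry point `fn` so that inputs, hiddens, and
    # weights are cast to fp16 before the cuDNN kernels run; `handle` keeps
    # the original function so it can be restored later.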
    orig_rnn = utils.get_func(backend, fn)
    @functools.wraps(orig_rnn)
    def rnn_wrapper(*args, **kwargs):
        flat_weight = kwargs.get('flat_weight')
        if flat_weight is not None:
            # We replace `flat_weight` with an uninitialized fp16
            # Tensor. The "actual" weight tensors (provided in `forward`)
            # will then be set up as pointers into the buffer and have the
            # corresponding fp32 values copied in.
            # We need to call `copy_` on the "actual" weights so that the
            # autograd graph correctly backprops from the wgrads computed
            # inside cuDNN (on fp16 weights) into the fp32 weights.
            # (A minimal sketch of this views-plus-copy_ trick follows the
            # function.)
            assert utils.type_string(flat_weight) == 'FloatTensor'
            if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
                # Pre-0.4. A little slower, since it zeros out memory.
                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
            else:
                flat_weight_fp16 = torch.empty_like(flat_weight,
                                                    dtype=torch.float16)
            kwargs['flat_weight'] = flat_weight_fp16
        else:
            flat_weight_fp16 = None

        # `orig_rnn` is a factory: calling it returns the actual forward
        # callable, which we wrap below so its arguments are cast per call.
        forward = orig_rnn(*args, **kwargs)
        @functools.wraps(forward)
        def fwd_wrapper(*fargs, **fkwargs):
            assert len(fargs) == 3 or len(fargs) == 4
            inputs, weights, hiddens = fargs[:3]
            assert utils.is_fp_tensor(inputs)
            assert isinstance(weights, list)
            cast_fn = utils.verbosify(utils.maybe_half,
                                      fn,
                                      verbose)
            new_args = []

            # 0) Inputs
            new_args.append(cast_fn(inputs))

            # 1) Weights
            if flat_weight_fp16 is not None:
                fp16_weights = utils.synthesize_flattened_rnn_weights(
                    weights, flat_weight_fp16, fn, verbose)
            else:
                fp16_weights = [[cast_fn(w) for w in layer]
                                for layer in weights]
            new_args.append(fp16_weights)

            # 2) Hiddens: either a tuple (for LSTM) or a single tensor
            if isinstance(hiddens, tuple):
                new_args.append(tuple(cast_fn(x) for x in hiddens))
            elif utils.is_fp_tensor(hiddens):
                new_args.append(cast_fn(hiddens))
            else:
                # Hiddens can, in principle, be `None` -- pass through
                new_args.append(hiddens)

            # 3) Batch sizes (0.4 or later only)
            if len(fargs) == 4:
                new_args.append(fargs[3])

            return forward(*new_args, **fkwargs)
        return fwd_wrapper
    # Install the wrapper on the backend; `set_func_save` also records the
    # original function on `handle` so it can be restored.
    utils.set_func_save(handle, backend, fn, rnn_wrapper)
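

A minimal, self-contained sketch (not apex code; the tensor names and shapes
are made up) of the mechanism the comment above describes: per-layer fp16
weights are views into one flat fp16 buffer, and an in-place `copy_` from the
fp32 masters keeps the autograd graph connected, so wgrads computed against
the fp16 views flow back into the fp32 parameters.

import torch

fp32_w = torch.randn(4, 4, requires_grad=True)                # fp32 "master" weight
flat_fp16 = torch.empty(fp32_w.numel(), dtype=torch.float16)  # flat fp16 buffer

# Carve an fp16 view out of the flat buffer, then copy the fp32 values in.
# Because copy_ is differentiable w.r.t. its source, gradients taken w.r.t.
# the fp16 view propagate back to fp32_w.
w_fp16 = flat_fp16.narrow(0, 0, fp32_w.numel()).view_as(fp32_w)
w_fp16 = w_fp16.copy_(fp32_w)

loss = (w_fp16.float() ** 2).sum()   # stand-in for the cuDNN RNN computation
loss.backward()
print(fp32_w.grad.dtype, fp32_w.grad.shape)  # grads land on the fp32 master

In apex the view construction is done by utils.synthesize_flattened_rnn_weights
(called above), which places each per-layer weight at its proper offset inside
the flat buffer; the narrow/view_as here is only a simplified stand-in.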