in apex/apex/amp/_initialize.py [0:0]
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init

    optimizers_was_list = False
    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
        optimizers = [optimizers]
    elif optimizers is None:
        optimizers = []
    elif isinstance(optimizers, list):
        optimizers_was_list = True
        check_optimizers(optimizers)
    else:
        # Run check_optimizers first so known-unsupported optimizer types get a
        # more specific error than the generic TypeError below.
        check_optimizers([optimizers])
        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
    if isinstance(models, torch.nn.Module):
        models_was_list = False
        models = [models]
    elif isinstance(models, list):
        models_was_list = True
    else:
        raise TypeError("models must be either a single model or a list of models.")

    check_models(models)

    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
    # become an attribute, remember to stash master weights before casting the model.

    if properties.cast_model_type:
        if properties.keep_batchnorm_fp32:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        input_caster = functools.partial(to_type, properties.cast_model_type)
        if cast_model_outputs is not None:
            output_caster = functools.partial(to_type, cast_model_outputs)
        else:
            output_caster = functools.partial(to_type, torch.float32)
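
        # applier recursively walks nested containers (tuples, lists, dicts),
        # leaves strings and non-floating-point tensors untouched, and applies
        # the caster to every floating-point tensor it finds. An illustrative
        # sketch, assuming x is a float32 CUDA tensor and cast_model_type is
        # torch.float16:
        #
        #     applier({"x": x, "step": 5}, input_caster)
        #     # -> {"x": x.half(), "step": 5}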

        for model in models:
            # Patch the forward method to cast incoming data to the correct type, and
            # outgoing data to float32, so "the user never needs to call .half()."
            # I like writing things explicitly more than decorators.
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*applier(args, input_caster),
                                     **applier(kwargs, input_caster))
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())
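        # (The round trip works because Optimizer.load_state_dict casts each
        # floating-point state tensor, e.g. a momentum buffer, to the dtype and
        # device of the param it belongs to, so state created before the cast
        # above is brought in line with the now-possibly-half params.)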
    elif cast_model_outputs is not None:
        output_caster = functools.partial(to_type, cast_model_outputs)

        for model in models:
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*args, **kwargs)
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

    for i, optimizer in enumerate(optimizers):
        # Still need to special case this for the first pass
        if isinstance(optimizer, FusedAdam):
            optimizers[i] = wrap_fused_adam(optimizer, properties)
        else:
            optimizers[i] = _process_optimizer(optimizer, properties)

    _amp_state.loss_scalers = []
    for _ in range(num_losses):
        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
                                                  min_loss_scale=_amp_state.min_loss_scale,
                                                  max_loss_scale=_amp_state.max_loss_scale))
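
    # One LossScaler per loss lets each backward pass back off and recover its
    # scale independently. Callers pick a scaler through the loss_id argument
    # of amp.scale_loss; an illustrative sketch:
    #
    #     with amp.scale_loss(loss0, optimizer, loss_id=0) as scaled_loss:
    #         scaled_loss.backward()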

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
        for optimizer in optimizers:
            # Disable Amp casting for the optimizer step, because it should only be
            # applied to FP32 master params anyway.
            def patch_step(old_step):
                def new_step(*args, **kwargs):
                    with disable_casts():
                        output = old_step(*args, **kwargs)
                    return output
                return new_step

            optimizer.step = patch_step(optimizer.step)
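        # (amp_init is what installs the patched torch functions that perform
        # per-op casts; disable_casts() suspends those patches so the step runs
        # on the unmodified FP32 master params.)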

    if optimizers_was_list:
        if models_was_list:
            return models, optimizers
        else:
            return models[0], optimizers
    else:
        if models_was_list:
            if len(optimizers) == 0:
                return models
            else:
                return models, optimizers[0]
        else:
            if len(optimizers) == 0:
                return models[0]
            else:
                return models[0], optimizers[0]
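

# The usual entry point is apex.amp.initialize, which builds `properties` from
# an opt_level string and forwards to _initialize. A minimal usage sketch
# (illustrative; assumes a CUDA-capable setup):
#
#     import torch
#     from apex import amp
#
#     model = torch.nn.Linear(512, 512).cuda()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#
#     loss = model(torch.randn(8, 512).cuda()).sum()
#     with amp.scale_loss(loss, optimizer) as scaled_loss:
#         scaled_loss.backward()
#     optimizer.step()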