in train.py [0:0]
def f32_storage_getter(getter, name, shape=None, dtype=tf.float32,
                       initializer=None, regularizer=None,
                       trainable=True, *args, **kwargs):
    """Custom variable getter for mixed-precision training.

    Keeps a single float32 "master" copy of each variable (cached in
    ``H.var_cache``) and casts it to the requested compute ``dtype`` on
    retrieval, following NVIDIA's mixed-precision recipe:
    https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/
    index.html#mptrain
    """
    master = H.var_cache.get(name)
    if master is None:
        # Create the float32 master weight outside any ambient control
        # dependencies, then remember it so later calls reuse the same
        # variable instead of creating a duplicate.
        with tf.control_dependencies(None):
            master = getter(name, shape, dtype=tf.float32,
                            initializer=initializer,
                            regularizer=regularizer,
                            trainable=trainable,
                            *args, **kwargs)
        H.var_cache[name] = master
    # When an exponential moving average is active, read the shadow
    # (averaged) value rather than the raw variable.
    if H.ema is not None:
        master = H.ema.average(master)
    # Downcast the float32 master to the compute precision if they differ.
    if dtype != master.dtype.base_dtype:
        master = bs.float_cast(master, dtype=dtype, dx_dtype=dtype,
                               name=f"{name}/cast")
    return master