in utils.py [0:0]
def load_variables(sess, weights, ignore=None, trainable=False, ema=True):
    """Assign values from ``weights`` to graph variables in a single ``sess.run``.

    Every selected variable gets an ``assign`` op whose value is fed through a
    ``tf.placeholder``. All ops are then run together.

    Args:
        sess: session used to run the assign ops.
        weights: mapping from variable name (e.g. ``'scope/w:0'``) to the value
            to load. Keys are normalized with ``os.path.normpath`` before lookup.
        ignore: optional iterable of substrings. Any variable whose normalized
            name contains one of them is skipped.
        trainable: forwarded to ``get_variables`` to choose which variables
            are loaded.
        ema: whether the exponential moving averaged weights are used to
            initialize the true weights. When True, each variable is loaded
            from the ``'<name>/Ema/ema:0'`` entry of ``weights`` if present.
            Otherwise a warning is printed and the raw entry is used. The
            matching EMA variable in the graph, if any, is set to the same
            value.

    Raises:
        KeyError: if a selected variable has no usable entry in ``weights``.
    """
    weights = {os.path.normpath(key): value for key, value in weights.items()}
    ignore = tuple(ignore or ())
    ops = []
    feed_dict = {}

    def _feed(var, value):
        # Create an assign op for ``var`` and feed ``value`` via a placeholder.
        ph = tf.placeholder(dtype=var.dtype, shape=var.shape)
        ops.append(var.assign(ph))
        feed_dict[ph] = value

    if ema:
        # Normalize graph names exactly like the ``weights`` keys so that the
        # EMA lookups below use the same naming on every platform.
        gvs_map = {os.path.normpath(v.name): v for v in tf.global_variables()}

    for var in get_variables(trainable=trainable):
        var_name = os.path.normpath(var.name)
        if any(ignore_substr in var_name for ignore_substr in ignore):
            continue

        if not ema:
            _feed(var, weights[var_name])
            continue

        # Drop the ':<output index>' suffix before appending the EMA suffix.
        base_name = var_name.rsplit(':', 1)[0]
        ema_name = os.path.normpath(f'{base_name}/Ema/ema:0')
        if ema_name in weights:
            value = weights[ema_name]
        else:
            print(f'warning: ema var not found for {var_name}')
            value = weights[var_name]

        # We assign the EMA value (or the fallback) to the current value.
        _feed(var, value)

        # Also assign it to the EMA variable itself, which would otherwise keep
        # the initialized (random) value. Reusing ``value`` means a missing EMA
        # entry in ``weights`` no longer raises here. Graphs without an EMA
        # variable for this name have nothing to seed.
        ema_var = gvs_map.get(ema_name)
        if ema_var is not None:
            _feed(ema_var, value)

    sess.run(ops, feed_dict)