in coinrun/main_utils.py [0:0]
def get_savable_params(loaded_params, scope, keep_heads=False, *, head_names=('v', 'pi')):
    """Pair loaded parameter values with this graph's trainable variables.

    The variables come from ``tf.trainable_variables(scope)`` and are zipped
    index-by-index with ``loaded_params``. Variables under ``<scope>/<head>``
    for any ``head`` in ``head_names`` are dropped (with their loaded values)
    unless ``keep_heads`` is True.

    Args:
        loaded_params: sequence of saved parameter values. It is assumed to be
            in the same order as ``tf.trainable_variables(scope)`` — TODO
            confirm against the code that saved them.
        scope: variable scope whose trainable variables are matched.
        keep_heads: if False, head variables are dropped; if True, everything
            is kept.
        head_names: names of the head sub-scopes directly under ``scope``.
            Defaults to ``('v', 'pi')``, the previously hard-coded heads.

    Returns:
        ``(filtered_loaded, filtered_params)``: two equal-length lists, aligned
        index-by-index, holding the kept loaded values and their variables.

    Raises:
        ValueError: if ``len(loaded_params)`` differs from the number of
            trainable variables in ``scope``.
    """
    import re  # local import so this block does not depend on file-level imports

    params = tf.trainable_variables(scope)
    if len(loaded_params) != len(params):
        print('param mismatch', len(loaded_params), len(params))
        # Raise instead of ``assert``: asserts are stripped under ``python -O``,
        # and zip() below would then silently pair values with wrong variables.
        raise ValueError(
            'param mismatch: %d loaded values vs %d trainable variables in scope %r'
            % (len(loaded_params), len(params), scope))

    # Match a head only as a whole path component, e.g. 'model/v/w:0' or
    # 'model/v:0', but not 'model/vf/w:0' (a plain substring test matched it).
    head_patterns = [
        re.compile(re.escape(scope + '/' + head) + r'(?=[/:]|$)')
        for head in head_names
    ]

    filtered_params = []
    filtered_loaded = []
    for p, loaded_p in zip(params, loaded_params):
        is_head = any(pattern.search(p.name) for pattern in head_patterns)
        if keep_heads or not is_head:
            filtered_params.append(p)
            filtered_loaded.append(loaded_p)
        else:
            print('drop', p)
    return filtered_loaded, filtered_params