in ma_policy/load_policy.py [0:0]
def load_variables(policy, weights):
    """Load a dict of saved weights into the variables of ``policy``.

    Each key of ``weights`` is normalized with ``os.path.normpath`` and then
    re-rooted with ``replace_base_scope(key, policy.scope)``. Those keys are
    matched against ``os.path.normpath(var.name)`` for every variable returned
    by ``policy.get_variables()``.

    * A variable with no matching key is logged and re-initialized by running
      its ``initializer`` in the default session.
    * A variable with a matching key has its shape checked against the saved
      value. All assign ops are then run together in one ``session.run`` call.

    Args:
        policy: object exposing ``scope`` and ``get_variables()``.
        weights: mapping from variable name to an array-like value (anything
            ``np.shape`` accepts).

    Raises:
        SystemExit: with a non-zero status if a saved value's shape does not
            match its variable's shape, or if creating the assign op fails.
            The traceback and an error line are printed to stdout first.
    """
    sess = tf.get_default_session()
    # Normalize saved keys the same way variable names are normalized below,
    # then re-root them under this policy's scope.
    weights = {
        replace_base_scope(os.path.normpath(key), policy.scope): value
        for key, value in weights.items()
    }

    assign_ops = []
    for var in policy.get_variables():
        var_name = os.path.normpath(var.name)
        if var_name not in weights:
            logging.warning(
                "%s was not found in weights dict. This will be reinitialized.",
                var_name,
            )
            sess.run(var.initializer)
            continue

        value = weights[var_name]
        value_shape = tuple(np.shape(value))
        try:
            var_shape = tuple(shape_list(var))
            # Compare whole shapes as tuples. An elementwise numpy comparison
            # broadcasts mismatched ranks (e.g. () vs (5,), or (2,) vs (2, 2))
            # and would wrongly report a match. An explicit raise is also used
            # instead of `assert`, which is stripped under `python -O`.
            if var_shape != value_shape:
                raise ValueError(
                    f"Shape mismatch for {var_name}: variable has shape "
                    f"{var_shape}, saved weights have shape {value_shape}"
                )
            assign_ops.append(var.assign(value))
        except Exception:
            traceback.print_exc(file=sys.stdout)
            print(f"Error assigning weights of shape {value_shape} to {var}")
            # Exit with a non-zero status; a bare sys.exit() would report success.
            sys.exit(1)

    sess.run(assign_ops)