def load_variables()

in ma_policy/load_policy.py [0:0]


def load_variables(policy, weights):
    """Load values from ``weights`` into every variable of ``policy``.

    Weight-dict keys are path-normalized and re-rooted onto ``policy.scope``
    (via ``replace_base_scope``) before being matched against the
    path-normalized names of ``policy.get_variables()``.

    - A variable with no matching key is re-initialized with its own
      initializer, and a warning is logged.
    - Matching variables are shape-checked and assigned in one session run
      at the end.

    Args:
        policy: object exposing ``scope`` and ``get_variables()``.
        weights: mapping from variable name to an array-like value.

    Raises:
        SystemExit: with status 1 if a weight's shape differs from its
            variable's shape, or if building the assign op fails. The
            traceback and an error message are printed to stdout first.
    """
    weights = {os.path.normpath(key): value for key, value in weights.items()}
    weights = {replace_base_scope(key, policy.scope): value for key, value in weights.items()}
    sess = tf.get_default_session()
    assign_ops = []
    for var in policy.get_variables():
        var_name = os.path.normpath(var.name)
        if var_name not in weights:
            logging.warning(f"{var_name} was not found in weights dict. This will be reinitialized.")
            sess.run(var.initializer)
            continue

        value = weights[var_name]
        # Pre-bind so the error message below cannot itself fail.
        value_shape = None
        try:
            # np.shape handles ndarrays and nested lists alike. For an ndarray
            # it equals value.shape, so the error message is unchanged.
            value_shape = tuple(np.shape(value))
            var_shape = tuple(shape_list(var))
            # Exact tuple comparison: an elementwise numpy comparison would
            # broadcast, so [1] vs [1, 1] or [] vs [5] would wrongly pass.
            # An explicit raise (not `assert`) also survives `python -O`.
            if var_shape != value_shape:
                raise ValueError(f"shape mismatch: variable {var_shape} vs weights {value_shape}")
            assign_ops.append(var.assign(value))
        except Exception:
            traceback.print_exc(file=sys.stdout)
            print(f"Error assigning weights of shape {value_shape} to {var}")
            # Non-zero status: a bare sys.exit() would report success (0).
            sys.exit(1)
    sess.run(assign_ops)