in understanding_rl_vision/rl_clarity/loading.py [0:0]
def save_lucid_model(config, params, *, model_path, metadata_path):
    """Rebuild the policy graph, save it with ``Model.save``, and write metadata.

    A single-environment copy of the setup described by ``config`` is created,
    the policy is rebuilt in a fresh TensorFlow graph, ``params`` are loaded
    into it, and the graph is saved to ``model_path``. A metadata dict is
    written to ``metadata_path`` with ``save_joblib``.

    Args:
        config: Settings mapping. It is copied first, so the caller's object
            is not modified. It must contain ``"num_envs"``, which is
            discarded (exactly one environment is created). Every remaining
            entry is forwarded as a keyword argument to ``create_env`` and
            ``get_arch``. ``"library"`` defaults to ``"baselines"``.
        params: Parameters handed to ``load_params`` for the new session.
        model_path: Destination passed to ``Model.save``; read back afterwards
            to produce the returned bytes.
        metadata_path: Destination passed to ``save_joblib``.

    Returns:
        A dict whose first key is ``"model_bytes"`` (the contents of
        ``model_path`` read in binary mode), followed by all metadata
        entries: ``policy_logits_name``, ``value_function_name``,
        ``env_name``, ``gae_gamma``, ``gae_lambda`` and ``action_combos``.

    Raises:
        KeyError: If ``config`` has no ``"num_envs"`` entry.
        ValueError: If the configured library is not ``"baselines"``.
    """
    settings = config.copy()
    settings.pop("num_envs")
    library = settings.get("library", "baselines")
    venv = create_env(1, **settings)
    arch = get_arch(**settings)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            obs_space = venv.observation_space
            obs_input = tf.placeholder(
                dtype=tf.float32, shape=(None,) + obs_space.shape
            )

            if library != "baselines":
                raise ValueError(f"Unsupported library: {library}")

            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                make_policy = build_policy(venv, arch)
                # The saved model declares inputs in [0, 1] (see
                # image_value_range below) while the policy is fed the
                # placeholder scaled by 255 -- presumably because it expects
                # 0-255 pixel values; confirm against the training setup.
                policy = make_policy(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    observ_placeholder=(obs_input * 255),
                )
                policy_dist = policy.pd
                value_fn = policy.vf

            load_params(params, sess=sess)

            logits_name = policy_dist.logits.op.name
            value_name = value_fn.op.name
            Model.save(
                model_path,
                input_name=obs_input.op.name,
                output_names=[logits_name, value_name],
                image_shape=obs_space.shape,
                image_value_range=[0.0, 1.0],
            )

    metadata = {
        "policy_logits_name": logits_name,
        "value_function_name": value_name,
        "env_name": settings.get("env_name"),
        "gae_gamma": settings.get("gamma"),
        "gae_lambda": settings.get("lambda"),
    }

    # Follow the chain of ``.env`` attributes until an object exposing
    # ``combos`` is found or the chain ends.
    layer = venv
    while hasattr(layer, "env"):
        if hasattr(layer, "combos"):
            break
        layer = layer.env
    metadata["action_combos"] = layer.combos if hasattr(layer, "combos") else None

    save_joblib(metadata, metadata_path)

    result = {"model_bytes": read(model_path, cache=False, mode="rb")}
    result.update(metadata)
    return result