in rllib_model_custom_torch.py [0:0]
def __init__(self, obs_space, action_space, num_outputs, model_config,
             name, **model_kwargs):
    """Build a mixture-of-experts (MoE) policy model.

    Constructs three components from ``custom_model_config``:
    a gating network producing one weight per expert, a list of
    expert networks (each emitting action means plus log-stds),
    and a scalar value-function network.

    Args:
        obs_space: Observation space; only its flattened shape is used.
        action_space: Action space (passed through to TorchModelV2).
        num_outputs: Total model output size; must be even because it
            is split into action means and log-stds.
        model_config: RLlib model config dict; user overrides are read
            from its "custom_model_config" entry.
        name: Model name (passed through to TorchModelV2).
        **model_kwargs: Ignored extra keyword arguments.
    """
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

    ''' Load and check configurations '''
    assert num_outputs % 2 == 0, (
        "num_outputs must be divisible by two", num_outputs)
    # Half of the outputs are means; the other half are log-stds
    # appended by FC (append_log_std=True), so each expert only needs
    # num_outputs // 2 "raw" outputs.
    num_outputs = num_outputs // 2

    custom_model_config = MOEPolicyBase.DEFAULT_CONFIG.copy()
    custom_model_config_by_user = model_config.get("custom_model_config")
    if custom_model_config_by_user:
        custom_model_config.update(custom_model_config_by_user)

    log_std_type = custom_model_config.get("log_std_type")
    assert log_std_type in ["constant", "state_independent"]

    sample_std = custom_model_config.get("sample_std")
    assert sample_std > 0.0, "The value should be positive"

    expert_hiddens = custom_model_config.get("expert_hiddens")
    expert_activations = custom_model_config.get("expert_activations")
    expert_init_weights = custom_model_config.get("expert_init_weights")
    expert_log_std_types = custom_model_config.get("expert_log_std_types")
    expert_sample_stds = custom_model_config.get("expert_sample_stds")
    expert_checkpoints = custom_model_config.get("expert_checkpoints")
    expert_learnable = custom_model_config.get("expert_learnable")

    gate_fn_hiddens = custom_model_config.get("gate_fn_hiddens")
    gate_fn_activations = custom_model_config.get("gate_fn_activations")
    gate_fn_init_weights = custom_model_config.get("gate_fn_init_weights")

    value_fn_hiddens = custom_model_config.get("value_fn_hiddens")
    value_fn_activations = custom_model_config.get("value_fn_activations")
    value_fn_init_weights = custom_model_config.get("value_fn_init_weights")

    # np.product is deprecated (removed in NumPy 2.0); np.prod is the
    # supported spelling.
    dim_state = int(np.prod(obs_space.shape))
    num_experts = len(expert_hiddens)

    ''' Construct the gate function '''
    self._gate_fn = FC(
        size_in=dim_state,
        size_out=num_experts,
        hiddens=gate_fn_hiddens,
        activations=gate_fn_activations,
        init_weights=gate_fn_init_weights,
        append_log_std=False)

    ''' Construct experts '''
    # nn.ModuleList (instead of a plain Python list) is required so the
    # experts' parameters are registered with this module and included
    # in .parameters(), .to(device), and state_dict().
    self._experts = nn.ModuleList()
    for i in range(num_experts):
        expert = FC(
            size_in=dim_state,
            size_out=num_outputs,
            hiddens=expert_hiddens[i],
            activations=expert_activations[i],
            init_weights=expert_init_weights[i],
            append_log_std=True,
            log_std_type=expert_log_std_types[i],
            sample_std=expert_sample_stds[i])
        if expert_checkpoints[i]:
            expert.load_state_dict(torch.load(expert_checkpoints[i]))
            expert.eval()
        # Freeze (or keep trainable) each expert per configuration.
        # NOTE: loop variable renamed so it does not shadow the `name`
        # constructor parameter.
        for _param_name, param in expert.named_parameters():
            param.requires_grad = expert_learnable[i]
        self._experts.append(expert)

    ''' Construct the value function '''
    self._value_fn = FC(
        size_in=dim_state,
        size_out=1,
        hiddens=value_fn_hiddens,
        activations=value_fn_activations,
        init_weights=value_fn_init_weights,
        append_log_std=False)

    self._num_experts = num_experts
    # Caches filled during forward(); read back by value_function() /
    # gate-weight accessors (defined elsewhere in the class).
    self._cur_value = None
    self._cur_gate_weight = None