in rllib_model_custom_torch.py [0:0]
def __init__(self, obs_space, action_space, num_outputs, model_config,
             name, **model_kwargs):
    """Build the policy and value networks for this custom Torch model.

    Args:
        obs_space: Observation space; its flattened size is the network
            input dimension.
        action_space: Action space (forwarded to the base class).
        num_outputs: Output size requested by RLlib. Must be even: half
            of the outputs are action means and the other half are the
            log-stds that ``FC`` appends (``append_log_std=True``).
        model_config: RLlib model config dict; custom options are read
            from its ``"custom_model_config"`` entry and merged over
            ``FullyConnectedPolicy.DEFAULT_CONFIG``.
        name: Model name forwarded to the base class.
        **model_kwargs: Extra keyword arguments (currently unused).
    """
    TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                          model_config, name)
    nn.Module.__init__(self)

    ''' Load and check configurations '''
    assert num_outputs % 2 == 0, (
        "num_outputs must be divisible by two", num_outputs)
    # Only the mean dims are produced by the hidden stack; the matching
    # log-std outputs are appended by FC below (append_log_std=True).
    num_outputs = num_outputs // 2

    custom_model_config = FullyConnectedPolicy.DEFAULT_CONFIG.copy()
    custom_model_config_by_user = model_config.get("custom_model_config")
    if custom_model_config_by_user:
        custom_model_config.update(custom_model_config_by_user)

    log_std_type = custom_model_config.get("log_std_type")
    assert log_std_type in ["constant", "state_independent"]

    sample_std = custom_model_config.get("sample_std")
    assert sample_std > 0.0, "The value should be positive"

    # Policy head: hiddens has N layers, so N+1 activations/initializers
    # are needed (one per hidden layer plus the output layer).
    policy_fn_hiddens = custom_model_config.get("policy_fn_hiddens")
    policy_fn_activations = custom_model_config.get("policy_fn_activations")
    policy_fn_init_weights = custom_model_config.get("policy_fn_init_weights")
    assert len(policy_fn_hiddens) > 0
    assert len(policy_fn_hiddens)+1 == len(policy_fn_activations)
    assert len(policy_fn_hiddens)+1 == len(policy_fn_init_weights)

    # Value head: same layer-count consistency checks as the policy head.
    value_fn_hiddens = custom_model_config.get("value_fn_hiddens")
    value_fn_activations = custom_model_config.get("value_fn_activations")
    value_fn_init_weights = custom_model_config.get("value_fn_init_weights")
    assert len(value_fn_hiddens) > 0
    assert len(value_fn_hiddens)+1 == len(value_fn_activations)
    assert len(value_fn_hiddens)+1 == len(value_fn_init_weights)

    # np.product is a deprecated alias removed in NumPy 2.0; np.prod is
    # the supported spelling and is behaviorally identical here.
    dim_state = int(np.prod(obs_space.shape))

    ''' Construct the policy function '''
    self._policy_fn = FC(
        size_in=dim_state,
        size_out=num_outputs,
        hiddens=policy_fn_hiddens,
        activations=policy_fn_activations,
        init_weights=policy_fn_init_weights,
        append_log_std=True,
        log_std_type=log_std_type,
        sample_std=sample_std)

    ''' Construct the value function '''
    self._value_fn = FC(
        size_in=dim_state,
        size_out=1,
        hiddens=value_fn_hiddens,
        activations=value_fn_activations,
        init_weights=value_fn_init_weights,
        append_log_std=False)

    # Cached value-function output; populated by forward(), read by
    # value_function() per the TorchModelV2 contract.
    self._cur_value = None