def __init__()

in rllib_model_custom_torch.py [0:0]


    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **model_kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        ''' Load and check configurations '''

        # RLlib requests 2 * action_dim outputs for a Gaussian policy (a mean
        # and a log_std per action dimension); the experts append their own
        # log_std head, so only the mean dimensions are constructed here.
        assert num_outputs % 2 == 0, \
            f"num_outputs must be divisible by two, got {num_outputs}"
        num_outputs = num_outputs // 2

        custom_model_config = MOEPolicyBase.DEFAULT_CONFIG.copy()
        custom_model_config_by_user = model_config.get("custom_model_config")
        if custom_model_config_by_user:
            custom_model_config.update(custom_model_config_by_user)

        log_std_type = custom_model_config.get("log_std_type")
        assert log_std_type in ["constant", "state_independent"]
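        # Assumed FC semantics: "constant" keeps the log-std fixed at
        # log(sample_std), while "state_independent" learns it as a free
        # parameter that does not depend on the observation.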

        sample_std = custom_model_config.get("sample_std")
        assert sample_std > 0.0, "sample_std must be positive"

        expert_hiddens = custom_model_config.get("expert_hiddens")
        expert_activations = custom_model_config.get("expert_activations")
        expert_init_weights = custom_model_config.get("expert_init_weights")
        expert_log_std_types = custom_model_config.get("expert_log_std_types")
        expert_sample_stds = custom_model_config.get("expert_sample_stds")
        expert_checkpoints = custom_model_config.get("expert_checkpoints")
        expert_learnable = custom_model_config.get("expert_learnable")

        gate_fn_hiddens = custom_model_config.get("gate_fn_hiddens")
        gate_fn_activations = custom_model_config.get("gate_fn_activations")
        gate_fn_init_weights = custom_model_config.get("gate_fn_init_weights")

        value_fn_hiddens = custom_model_config.get("value_fn_hiddens")
        value_fn_activations = custom_model_config.get("value_fn_activations")
        value_fn_init_weights = custom_model_config.get("value_fn_init_weights")

        dim_state = int(np.prod(obs_space.shape))
        num_experts = len(expert_hiddens)
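        # num_experts is inferred from expert_hiddens, so every per-expert
        # list above (activations, init_weights, checkpoints, ...) must have
        # the same length.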

        ''' Construct the gate function '''
        
        # One output per expert: the gate produces mixture weights, so no
        # log-std head is appended.
        self._gate_fn = FC(
            size_in=dim_state,
            size_out=num_experts,
            hiddens=gate_fn_hiddens,
            activations=gate_fn_activations,
            init_weights=gate_fn_init_weights,
            append_log_std=False)

        ''' Construct experts '''

        # nn.ModuleList registers the experts as submodules, so their
        # parameters are visible to .parameters(), .to(device), and
        # state_dict(); a plain Python list would silently skip them.
        self._experts = nn.ModuleList()
        for i in range(num_experts):
            expert = FC(
                size_in=dim_state,
                size_out=num_outputs,
                hiddens=expert_hiddens[i],
                activations=expert_activations[i],
                init_weights=expert_init_weights[i],
                append_log_std=True,
                log_std_type=expert_log_std_types[i],
                sample_std=expert_sample_stds[i])
            if expert_checkpoints[i]:
                # Warm-start the expert from a pre-trained checkpoint.
                expert.load_state_dict(torch.load(expert_checkpoints[i]))
                expert.eval()
            # Freeze or unfreeze the expert as configured. Iterating over
            # parameters() also avoids shadowing the `name` argument.
            for param in expert.parameters():
                param.requires_grad = expert_learnable[i]
            self._experts.append(expert)

        ''' Construct the value function '''
        
        # A separate critic head that outputs a single state value.
        self._value_fn = FC(
            size_in=dim_state,
            size_out=1,
            hiddens=value_fn_hiddens,
            activations=value_fn_activations,
            init_weights=value_fn_init_weights,
            append_log_std=False)

        self._num_experts = num_experts
        # Filled in during the forward pass; _cur_value backs RLlib's
        # value_function() accessor.
        self._cur_value = None
        self._cur_gate_weight = None
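
For reference, a minimal usage sketch (not from the source file) showing how a model like this is wired into RLlib. The class name MOEPolicyBase follows the DEFAULT_CONFIG reference above; the concrete hyperparameter values and the init_weights format are illustrative assumptions.

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("moe_policy", MOEPolicyBase)

config = {
    "model": {
        "custom_model": "moe_policy",
        "custom_model_config": {
            # One entry per expert; their common length sets num_experts.
            "expert_hiddens": [[256, 256], [256, 256]],
            "expert_activations": [["relu", "relu"], ["relu", "relu"]],
            # init_weights format depends on FC; placeholders shown here.
            "expert_init_weights": [[1.0, 1.0], [1.0, 1.0]],
            "expert_log_std_types": ["constant", "constant"],
            "expert_sample_stds": [0.1, 0.1],
            "expert_checkpoints": [None, None],  # no warm-starting
            "expert_learnable": [True, True],
            "gate_fn_hiddens": [128, 128],
            "gate_fn_activations": ["relu", "relu"],
            "gate_fn_init_weights": [1.0, 1.0],
            "value_fn_hiddens": [256, 256],
            "value_fn_activations": ["relu", "relu"],
            "value_fn_init_weights": [1.0, 1.0],
            "log_std_type": "constant",
            "sample_std": 0.1,
        },
    },
}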