def __init__()

in lerobot/common/policies/tdmpc/modeling_tdmpc.py [0:0]


    def __init__(self, config: TDMPCConfig):
        """
        Creates encoders for pixel and/or state modalities.

        One encoder is built per modality present in `config`:
            - `image_enc_layers`: 4-layer conv stack -> flatten -> linear projection to `config.latent_dim`.
              Built only when `config.image_features` is non-empty; only the first image feature is used.
            - `state_enc_layers`: MLP over `config.robot_state_feature`, built only when that feature is set.
            - `env_state_enc_layers`: MLP over `config.env_state_feature`, built only when that feature is set.
        Every encoder ends with LayerNorm followed by Sigmoid, so each latent is bounded by the sigmoid.

        TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
            channel dimension. Re-implement this capability.

        Args:
            config: Policy configuration providing the input feature shapes and the encoder sizes.
        """
        super().__init__()
        self.config = config

        if config.image_features:
            # Only the first image feature is encoded (see TODO in the docstring). Its shape is channel-first:
            # shape[0] feeds Conv2d's in_channels and (1, *shape) is the dummy batch used to size the projection.
            image_shape = next(iter(config.image_features.values())).shape
            hidden_dim = config.image_encoder_hidden_dim
            self.image_enc_layers = nn.Sequential(
                nn.Conv2d(image_shape[0], hidden_dim, 7, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 5, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2),
                nn.ReLU(),
            )
            # Push a dummy batch through the conv stack to find the per-sample feature shape (batch dim
            # dropped), which determines the input size of the projection layer.
            out_shape = get_output_shape(self.image_enc_layers, (1, *image_shape))[1:]
            # Append to the existing Sequential (rather than rebuilding it) so the module indices, and hence
            # the state_dict keys that saved checkpoints rely on, stay exactly the same.
            self.image_enc_layers.extend(
                nn.Sequential(
                    nn.Flatten(),
                    # `np.prod` returns a NumPy scalar; cast so `in_features` is a plain Python int.
                    nn.Linear(int(np.prod(out_shape)), config.latent_dim),
                    nn.LayerNorm(config.latent_dim),
                    nn.Sigmoid(),
                )
            )

        def make_state_encoder(input_dim: int) -> nn.Sequential:
            # Vector-observation MLP shared by the robot-state and env-state encoders.
            return nn.Sequential(
                nn.Linear(input_dim, config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Sigmoid(),
            )

        if config.robot_state_feature:
            self.state_enc_layers = make_state_encoder(config.robot_state_feature.shape[0])

        if config.env_state_feature:
            self.env_state_enc_layers = make_state_encoder(config.env_state_feature.shape[0])