def __init__()

in source/MXNetEnv/training/training_src/networks/qnetworks.py [0:0]


    def __init__(self, state_shape, action_size,
                 starting_channels,
                 dS,
                 d,
                 number_of_hidden_states,
                 kernel_size,
                 repeat_size,
                 activation_type,
                 sequence_length,
                 seed):
        """Create the attention Q-network's sub-blocks and initialize them.

        Every sub-block gets Xavier initialization on the ``ctx`` resolved
        from outside this method. Blocks built without an explicit input size
        (the Conv2D layers, LayerNorms, the Dense head and the GRU) rely on
        Gluon's deferred initialization. Their weight shapes are therefore
        only fixed once data first flows through them.

        Params
        ======
            state_shape (int, int, int): shape of one state; not read in
                this method
            action_size (int): number of actions, i.e. units of the final
                Dense layer
            starting_channels (int): output channels of every Conv2D; also
                multiplies the embedding widths
            dS (int): snake embedding depth (width = dS * starting_channels)
            d (int): depth shared by the health and turn embeddings
                (width = d * starting_channels)
            number_of_hidden_states (int): GRU hidden size; only used when
                sequence_length > 1
            kernel_size (int or tuple of int): kernel size of every Conv2D
            repeat_size (int): stored as ``self.repeat_size``; not used here
            activation_type (str): activation of ``conv`` and ``conv2`` only;
                the other convolutions have no activation
            sequence_length (int): ``self.gru`` is created only when this
                exceeds 1
            seed (int): value handed to ``mx.random.seed``
        """
        super(QNetworkAttention, self).__init__()

        # Flag read outside this method (presumably by the training loop to
        # decide whether forward() receives extra arguments -- confirm).
        self.take_additional_forward_arguments = True

        # Health and turn embeddings share the single `d` depth argument.
        self.dS = dS
        self.dH = self.dT = d

        self.sequence_length = sequence_length
        self.repeat_size = repeat_size

        # Seed MXNet's RNG before any parameter initialization below.
        mx.random.seed(seed)

        def xavier(block):
            """Xavier-initialize every parameter of `block` on `ctx`; return it."""
            block.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
            return block

        def strided_conv(activation=None):
            """Build and initialize a stride-2 Conv2D with the shared geometry."""
            return xavier(gluon.nn.Conv2D(starting_channels,
                                          kernel_size=kernel_size,
                                          strides=2,
                                          activation=activation))

        # Two activated stride-2 convolutions.
        self.conv = strided_conv(activation_type)
        self.conv2 = strided_conv(activation_type)

        # Layer norms for the attention keys and queries.
        self.key_norm = xavier(gluon.nn.LayerNorm())
        self.query_norm = xavier(gluon.nn.LayerNorm())

        # Conv/embedding pairs for the snake, health and turn features.
        # Table sizes 5 / 100 / 10 bound the integer indices each embedding
        # accepts. NOTE(review): MXNet's Embedding maps too-large indices to
        # the last row instead of raising -- confirm inputs stay in range.
        self.conv_snake = strided_conv()
        self.embedding_snake = xavier(
            gluon.nn.Embedding(5, self.dS * starting_channels))

        self.conv_health = strided_conv()
        self.embedding_health = xavier(
            gluon.nn.Embedding(100, self.dH * starting_channels))

        self.conv_turn = strided_conv()
        self.embedding_turn = xavier(
            gluon.nn.Embedding(10, self.dT * starting_channels))

        # Output head: one more stride-2 conv, then a Dense of action_size units.
        self.conv_predict = strided_conv()
        self.predict = xavier(gluon.nn.Dense(action_size))

        # `self.gru` only exists when sequence_length > 1.
        if self.sequence_length > 1:
            self.gru = xavier(gluon.rnn.GRU(number_of_hidden_states,
                                            num_layers=1,
                                            layout='NTC'))