def __init__()

in pyro/nn/auto_reg_nn.py [0:0]


    def __init__(
            self,
            input_dim,
            context_dim,
            hidden_dims,
            param_dims=None,
            permutation=None,
            skip_connections=False,
            nonlinearity=nn.ReLU()):
        """
        Build the masked layers of a conditional autoregressive network.

        :param input_dim: dimensionality of the autoregressive input variable.
        :type input_dim: int
        :param context_dim: dimensionality of the conditioning context; the first
            masked layer (and the optional skip layer) receives
            ``input_dim + context_dim`` features.
        :type context_dim: int
        :param hidden_dims: sizes of the hidden layers. Must be non-empty, and every
            entry must be at least ``input_dim``.
        :type hidden_dims: list[int]
        :param param_dims: sizes of the parameter groups produced per input
            dimension. The final layer emits ``input_dim * sum(param_dims)``
            features, and ``self.param_slices`` records the ``[start, end)`` range
            of each group. ``None`` (the default) means ``[1, 1]``.
        :type param_dims: list[int]
        :param permutation: optional ordering of the input variables, cast to
            ``torch.int64`` and registered as the ``permutation`` buffer. When
            ``None``, a random permutation of ``range(input_dim)`` is drawn.
        :type permutation: torch.LongTensor
        :param skip_connections: if True, add a bias-free masked linear layer that
            maps the ``input_dim + context_dim`` features directly to the outputs.
        :type skip_connections: bool
        :param nonlinearity: module stored as ``self.f``. It is presumably applied
            between layers in ``forward`` (not visible here).
        :type nonlinearity: torch.nn.Module

        :raises ValueError: if ``hidden_dims`` is empty or any hidden dimension is
            smaller than ``input_dim``.
        """
        super().__init__()
        if input_dim == 1:
            warnings.warn('ConditionalAutoRegressiveNN input_dim = 1. Consider using an affine transformation instead.')

        # Use a ``None`` sentinel rather than a mutable ``[1, 1]`` default. The list
        # is stored on ``self``, so a default object would otherwise be shared by
        # every instance constructed without an explicit ``param_dims``.
        if param_dims is None:
            param_dims = [1, 1]

        self.input_dim = input_dim
        self.context_dim = context_dim
        self.hidden_dims = hidden_dims
        self.param_dims = param_dims
        self.count_params = len(param_dims)
        self.output_multiplier = sum(param_dims)
        param_dims_tensor = torch.tensor(param_dims)
        self.all_ones = (param_dims_tensor == 1).all().item()

        # Calculate the [start, end) indices on the output for each parameter group
        ends = torch.cumsum(param_dims_tensor, dim=0)
        starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1]))
        self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)]

        # At least one hidden layer is required: the first and last masked layers
        # below are sized from hidden_dims[0] and hidden_dims[-1].
        if len(hidden_dims) == 0:
            raise ValueError('hidden_dims must contain at least one hidden dimension.')

        # Hidden dimension must be not less than the input otherwise it isn't
        # possible to connect to the outputs correctly
        for h in hidden_dims:
            if h < input_dim:
                raise ValueError('Hidden dimension must not be less than input dimension.')

        if permutation is None:
            # By default set a random permutation of variables, which is important for performance with multiple steps.
            # Draw on the CPU, then move to the device of a default-constructed tensor.
            P = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device)
        else:
            # The permutation is chosen by the user
            P = permutation.type(dtype=torch.int64)
        # Registered as a buffer so the ordering is saved with the module state.
        self.register_buffer('permutation', P)

        # Create masks
        self.masks, self.mask_skip = create_mask(
            input_dim=input_dim, context_dim=context_dim, hidden_dims=hidden_dims, permutation=self.permutation,
            output_dim_multiplier=self.output_multiplier)

        # Create masked layers: (input + context) -> hidden_dims[0] -> ... -> hidden_dims[-1] -> outputs
        layers = [MaskedLinear(input_dim + context_dim, hidden_dims[0], self.masks[0])]
        for i in range(1, len(hidden_dims)):
            layers.append(MaskedLinear(hidden_dims[i - 1], hidden_dims[i], self.masks[i]))
        layers.append(MaskedLinear(hidden_dims[-1], input_dim * self.output_multiplier, self.masks[-1]))
        self.layers = nn.ModuleList(layers)

        if skip_connections:
            self.skip_layer = MaskedLinear(
                input_dim + context_dim,
                input_dim * self.output_multiplier,
                self.mask_skip,
                bias=False)
        else:
            self.skip_layer = None

        # Save the nonlinearity
        self.f = nonlinearity