in pyro/nn/auto_reg_nn.py
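# Excerpt from ConditionalAutoRegressiveNN.__init__; the full module imports
# warnings, torch, and torch.nn as nn, and defines the MaskedLinear layer and
# the create_mask helper used below alongside this class.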
def __init__(
    self,
    input_dim,
    context_dim,
    hidden_dims,
    param_dims=[1, 1],
    permutation=None,
    skip_connections=False,
    nonlinearity=nn.ReLU(),
):
    super().__init__()
    if input_dim == 1:
        warnings.warn(
            'ConditionalAutoRegressiveNN input_dim = 1. '
            'Consider using an affine transformation instead.'
        )
    self.input_dim = input_dim
    self.context_dim = context_dim
    self.hidden_dims = hidden_dims
    self.param_dims = param_dims
    self.count_params = len(param_dims)
    self.output_multiplier = sum(param_dims)
    # True when every parameter is 1-dimensional, in which case the stacked
    # output can be split without slicing
    self.all_ones = (torch.tensor(param_dims) == 1).all().item()
    # Calculate the indices on the output corresponding to each parameter
    ends = torch.cumsum(torch.tensor(param_dims), dim=0)
    starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1]))
    self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)]
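    # e.g. param_dims=[1, 2] gives ends=[1, 3], starts=[0, 1], and hence
    # slices [0:1] and [1:3] over the flattened parameter axis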
    # Hidden dimensions must be no smaller than the input dimension,
    # otherwise the outputs cannot be connected correctly
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError(
                'Hidden dimension must not be less than input dimension.'
            )
    if permutation is None:
        # By default, use a random permutation of the variables; this is
        # important for performance when stacking multiple transforms
        P = torch.randperm(input_dim, device='cpu').to(torch.Tensor().device)
    else:
        # The permutation is chosen by the user
        P = permutation.type(dtype=torch.int64)
    self.register_buffer('permutation', P)
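    # A buffer (rather than a parameter) keeps the permutation on the
    # module's device and in its state_dict without making it learnable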
    # Create masks
    self.masks, self.mask_skip = create_mask(
        input_dim=input_dim,
        context_dim=context_dim,
        hidden_dims=hidden_dims,
        permutation=self.permutation,
        output_dim_multiplier=self.output_multiplier,
    )
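    # The masks enforce the MADE-style autoregressive property: the outputs
    # for variable i may depend on the context and only on those inputs that
    # precede i in the permutation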
    # Create masked layers
    layers = [MaskedLinear(input_dim + context_dim, hidden_dims[0], self.masks[0])]
    for i in range(1, len(hidden_dims)):
        layers.append(MaskedLinear(hidden_dims[i - 1], hidden_dims[i], self.masks[i]))
    layers.append(
        MaskedLinear(hidden_dims[-1], input_dim * self.output_multiplier, self.masks[-1])
    )
    self.layers = nn.ModuleList(layers)
    if skip_connections:
        # Direct masked connection from (input, context) to the output
        self.skip_layer = MaskedLinear(
            input_dim + context_dim,
            input_dim * self.output_multiplier,
            self.mask_skip,
            bias=False,
        )
    else:
        self.skip_layer = None
    # Save the nonlinearity
    self.f = nonlinearity
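
# A minimal usage sketch (illustrative, not part of the source file), assuming
# this constructor belongs to pyro.nn.ConditionalAutoRegressiveNN; with
# param_dims=[1, 1] the network returns two tensors per call, e.g. the mean
# and log-scale of an affine autoregressive transform:
#
#     from pyro.nn import ConditionalAutoRegressiveNN
#     arn = ConditionalAutoRegressiveNN(
#         input_dim=10, context_dim=5, hidden_dims=[40, 40], param_dims=[1, 1]
#     )
#     x = torch.randn(100, 10)      # inputs
#     z = torch.randn(100, 5)       # conditioning context
#     mean, log_scale = arn(x, z)   # each of shape (100, 10)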