examples/vae/utils/custom_mlp.py

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from inspect import isclass

import torch
import torch.nn as nn

from pyro.distributions.util import broadcast_shape


class Exp(nn.Module):
    """
    a custom module for exponentiation of tensors
    """

    def __init__(self):
        super().__init__()

    def forward(self, val):
        return torch.exp(val)


class ConcatModule(nn.Module):
    """
    a custom module for concatenation of tensors
    """

    def __init__(self, allow_broadcast=False):
        self.allow_broadcast = allow_broadcast
        super().__init__()

    def forward(self, *input_args):
        # we have a single object
        if len(input_args) == 1:
            # regardless of type,
            # we don't care about single objects
            # we just index into the object
            input_args = input_args[0]

        # don't concat things that are just single objects
        if torch.is_tensor(input_args):
            return input_args
        else:
            if self.allow_broadcast:
                shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
                input_args = [s.expand(shape) for s in input_args]
            return torch.cat(input_args, dim=-1)


class ListOutModule(nn.ModuleList):
    """
    a custom module for outputting a list of tensors from a list of nn modules
    """

    def __init__(self, modules):
        super().__init__(modules)

    def forward(self, *args, **kwargs):
        # loop over modules in self, apply same args
        return [mm.forward(*args, **kwargs) for mm in self]


def call_nn_op(op):
    """
    a helper function that adds appropriate parameters when calling
    an nn module representing an operation like Softmax

    :param op: the nn.Module operation to instantiate
    :return: instantiation of the op module with appropriate parameters
    """
    if op in [nn.Softmax, nn.LogSoftmax]:
        return op(dim=1)
    else:
        return op()


class MLP(nn.Module):
    def __init__(
        self,
        mlp_sizes,
        activation=nn.ReLU,
        output_activation=None,
        post_layer_fct=lambda layer_ix, total_layers, layer: None,
        post_act_fct=lambda layer_ix, total_layers, layer: None,
        allow_broadcast=False,
        use_cuda=False,
    ):
        # init the module object
        super().__init__()

        assert len(mlp_sizes) >= 2, "Must have input and output layer sizes defined"

        # get our inputs, outputs, and hidden
        input_size, hidden_sizes, output_size = (
            mlp_sizes[0],
            mlp_sizes[1:-1],
            mlp_sizes[-1],
        )

        # assume int or list
        assert isinstance(
            input_size, (int, list, tuple)
        ), "input_size must be int, list, tuple"

        # everything in MLP will be concatted if it's multiple arguments
        last_layer_size = input_size if type(input_size) == int else sum(input_size)

        # everything sent in will be concatted together by default
        all_modules = [ConcatModule(allow_broadcast)]

        # loop over l
        for layer_ix, layer_size in enumerate(hidden_sizes):
            assert type(layer_size) == int, "Hidden layer sizes must be ints"

            # get our nn layer module (in this case nn.Linear by default)
            cur_linear_layer = nn.Linear(last_layer_size, layer_size)

            # for numerical stability -- initialize the layer properly
            cur_linear_layer.weight.data.normal_(0, 0.001)
            cur_linear_layer.bias.data.normal_(0, 0.001)

            # use GPUs to share data during training (if available)
            if use_cuda:
                cur_linear_layer = nn.DataParallel(cur_linear_layer)

            # add our linear layer
            all_modules.append(cur_linear_layer)

            # handle post_linear
            post_linear = post_layer_fct(layer_ix + 1, len(hidden_sizes), all_modules[-1])

            # if we send something back, add it to sequential
            # here we could return a batch norm for example
            if post_linear is not None:
                all_modules.append(post_linear)

            # handle activation (assumed no params -- deal with that later)
            all_modules.append(activation())

            # now handle after activation
            post_activation = post_act_fct(layer_ix + 1, len(hidden_sizes), all_modules[-1])

            # handle post_activation if not null
            # could add batch norm for example
            if post_activation is not None:
                all_modules.append(post_activation)

            # save the layer size we just created
            last_layer_size = layer_size

        # now we have all of our hidden layers
        # we handle outputs
        assert isinstance(
            output_size, (int, list, tuple)
        ), "output_size must be int, list, tuple"

        if type(output_size) == int:
            all_modules.append(nn.Linear(last_layer_size, output_size))
            if output_activation is not None:
                all_modules.append(
                    call_nn_op(output_activation)
                    if isclass(output_activation)
                    else output_activation
                )
        else:
            # we're going to have a bunch of separate layers we can spit out (a tuple of outputs)
            out_layers = []

            # multiple outputs? handle separately
            for out_ix, out_size in enumerate(output_size):
                # for a single output object, we create a linear layer and some weights
                split_layer = []

                # we have an activation function
                split_layer.append(nn.Linear(last_layer_size, out_size))

                # then we get our output activation (either we repeat all or we index into a same sized array)
                act_out_fct = (
                    output_activation
                    if not isinstance(output_activation, (list, tuple))
                    else output_activation[out_ix]
                )

                if act_out_fct:
                    # we check if it's a class. if so, instantiate the object
                    # otherwise, use the object directly (e.g. pre-instantiated)
                    split_layer.append(
                        call_nn_op(act_out_fct) if isclass(act_out_fct) else act_out_fct
                    )

                # our outputs is just a sequential of the two
                out_layers.append(nn.Sequential(*split_layer))

            all_modules.append(ListOutModule(out_layers))

        # now we have all of our modules, we're ready to build our sequential!
        # process mlps in order, pretty standard here
        self.sequential_mlp = nn.Sequential(*all_modules)

    # pass through our sequential for the output!
    def forward(self, *args, **kwargs):
        return self.sequential_mlp.forward(*args, **kwargs)
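

if __name__ == "__main__":
    # Usage sketch (not part of the original Pyro module): a small, assumed
    # configuration showing how MLP wires ConcatModule, the hidden layers, and
    # ListOutModule together. The layer sizes, the [nn.Softmax, None] output
    # activations, and the random inputs below are illustrative only.
    mlp = MLP(
        mlp_sizes=[[784, 10], 200, [10, 1]],   # two inputs, one hidden layer, two output heads
        activation=nn.ReLU,
        output_activation=[nn.Softmax, None],  # Softmax head of size 10, raw head of size 1
        allow_broadcast=False,
        use_cuda=False,
    )

    x = torch.randn(5, 784)
    y = torch.randn(5, 10)

    # nn.Sequential accepts a single input, so multiple tensors are passed as one
    # list; ConcatModule unpacks and concatenates them along the last dimension,
    # and ListOutModule returns one tensor per output head.
    out_softmax, out_raw = mlp([x, y])
    print(out_softmax.shape, out_raw.shape)  # torch.Size([5, 10]) torch.Size([5, 1])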