# mico/utils/net_utils.py
import torch
import torch.nn as nn


class FeedForward(nn.Module):
def __init__(self, dim_input, dim_hidden, dim_output, num_layers,
activation='relu', dropout_rate=0, layer_norm=False,
residual_connection=False):
"""This model wraps a (residual) neural network in a easy-access way.
"""
super().__init__()
        assert num_layers >= 0  # num_layers == 0 reduces the model to a single linear layer
if num_layers > 0:
assert dim_hidden > 0
if residual_connection:
assert dim_hidden == dim_input
self.residual_connection = residual_connection
self.stack = nn.ModuleList()
for layer_idx in range(num_layers):
layer = []
if layer_norm:
layer.append(nn.LayerNorm(dim_input if layer_idx == 0 else dim_hidden))
layer.append(nn.Linear(dim_input if layer_idx == 0 else dim_hidden,
dim_hidden))
layer.append({'tanh': nn.Tanh(), 'relu': nn.ReLU()}[activation])
if dropout_rate > 0:
layer.append(nn.Dropout(dropout_rate))
self.stack.append(nn.Sequential(*layer))
self.out = nn.Linear(dim_input if num_layers < 1 else dim_hidden,
dim_output)
def forward(self, x):
for layer in self.stack:
            x = (x + layer(x)) if self.residual_connection else layer(x)
return self.out(x)
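
# A minimal usage sketch of FeedForward (the dimensions below are hypothetical, not
# part of the original module): two hidden layers map a 768-dim input to 10 outputs.
#
#   ff = FeedForward(dim_input=768, dim_hidden=256, dim_output=10, num_layers=2,
#                    activation='relu', dropout_rate=0.1, layer_norm=True)
#   y = ff(torch.randn(4, 768))  # y.shape == torch.Size([4, 10])
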
class ConditionalDistributionZ(nn.Module):
def __init__(self, number_clusters, dim_input, num_layers, dim_hidden):
"""This model maps the input (BERT representation of a sentence) to the probability vector for
query routing or document assignment to the clusters.
Parameters
----------
number_clusters : int
The number of clusters to which we are going to assign the documents or route the queries.
dim_input : int
            The dimension of the input (the representation from the BERT base model is
            768-dimensional).
num_layers : int
How many layers are used in this model.
dim_hidden : int
How many neurons are in the hidden layer of this model.
"""
        super().__init__()
        self.number_clusters = number_clusters
        self.softmax = nn.Softmax(dim=-1)
        self.ff = FeedForward(dim_input, dim_hidden, number_clusters, num_layers)
def forward(self, inputs):
logits = self.ff(inputs).view(inputs.size(0), 1, self.number_clusters)
probability_vector = self.softmax(logits)
return probability_vector
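
# A minimal usage sketch of ConditionalDistributionZ (hypothetical sizes and variable
# name, not part of the original module): route a batch of 4 BERT vectors over 8 clusters.
#
#   posterior = ConditionalDistributionZ(number_clusters=8, dim_input=768,
#                                        num_layers=1, dim_hidden=256)
#   p = posterior(torch.randn(4, 768))  # p.shape == (4, 1, 8); each row sums to 1
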
class MarginalDistributionZ(nn.Module):
def __init__(self, number_clusters):
"""This is the function approximating the distribution of cluster sizes for document assignment.
We use `softmax` to make sure the output is a distribution.
Parameters
----------
number_clusters : int
The number of clusters to which we are going to assign the documents.
"""
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        # A single-row nn.Embedding acts as a free parameter vector of cluster logits.
        self.theta = nn.Embedding(1, number_clusters)
def forward(self):
"""Return the approximated distribution.
"""
logits = self.theta.weight
probability_vector = self.softmax(logits)
return probability_vector
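
# A minimal usage sketch of MarginalDistributionZ (the cluster count is hypothetical):
# the model takes no input and returns a single learned distribution over clusters.
#
#   prior = MarginalDistributionZ(number_clusters=8)
#   q = prior()  # q.shape == (1, 8); entries sum to 1
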
def get_init_function(init_value):
    """Build an initializer for the prior and posterior models.

    The scale of the uniform initialization is controlled by `init_value`;
    a non-positive `init_value` leaves the default PyTorch initialization untouched.
    """
    def init_function(m):
        if init_value > 0.:
            if getattr(m, 'weight', None) is not None:
                m.weight.data.uniform_(-init_value, init_value)
            # `bias` may be absent or None (e.g. nn.Linear(..., bias=False)).
            if getattr(m, 'bias', None) is not None:
                m.bias.data.fill_(0.)
    return init_function
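
# A minimal usage sketch of get_init_function (the scale 0.1 is hypothetical), applied
# recursively to every submodule via `nn.Module.apply`:
#
#   model = ConditionalDistributionZ(number_clusters=8, dim_input=768,
#                                    num_layers=1, dim_hidden=256)
#   model.apply(get_init_function(0.1))  # uniform weights in [-0.1, 0.1], zero biases
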
def cross_entropy_p_q(p, q):
    """Compute the cross entropy H(p, q) = -sum(p * log(q)) between two distributions.

    If `p` and `q` each contain multiple distributions, the cross entropy is computed
    pairwise and averaged over the batch; a 2-dimensional `q` is first broadcast
    across the batch dimension of `p`.
    """
    if len(q.size()) == 2:
        q = q.repeat(p.size(0), 1, 1)
    return (- (p * torch.log(q)).sum(dim=(1, 2))).mean()
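
# A minimal end-to-end sketch (reusing the hypothetical `posterior` and `prior` from
# the sketches above): the posterior yields per-example distributions of shape
# (batch, 1, K) and the prior a single distribution of shape (1, K), which
# cross_entropy_p_q broadcasts across the batch before averaging to a scalar loss.
#
#   p = posterior(torch.randn(4, 768))  # shape (4, 1, 8)
#   q = prior()                         # shape (1, 8), repeated to (4, 1, 8)
#   loss = cross_entropy_p_q(p, q)      # scalar tensor
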