in models/future_prediction.py [0:0]
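# NOTE: assumed module-level imports for this excerpt (not shown in the
# original snippet): torch's `nn`, `transformers`, `hydra`, and the
# repo-local `KmeansAssigner`.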
def __init__(
        self,
        in_features: int,
        output_len: int = -1,
        output_len_eval: int = -1,  # Same as output_len, used during eval
        avg_last_n: int = -1,
        inter_dim: int = 768,
        future_pred_loss: hydra.types.TargetConf = None,
        return_past_too: bool = False,
        drop_last_n: int = 0,
        quantize_before_rollout: bool = False,
        # These are only relevant when in_features == 1 and the input is
        # already clustered, or if on-the-fly cluster assignment is requested
        assign_to_centroids: str = None,
        num_cluster_centers: int = 50000,
        freeze_encoder_decoder: bool = False,
        **kwargs):
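    """Future feature prediction head built on a GPT-2 trunk.

    Argument summary, as inferred from this constructor (the full docs are
    not in this excerpt): `in_features` is the input feature dimension
    (1 means the inputs are pre-computed cluster IDs); `assign_to_centroids`
    optionally configures a KmeansAssigner for on-the-fly quantization into
    `num_cluster_centers` clusters; `inter_dim` is the GPT-2 hidden size;
    `future_pred_loss` is a Hydra target config for the training loss.
    Remaining `**kwargs` are forwarded to `transformers.GPT2Config`.
    """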
    super().__init__()
    self.assign_to_centroids = assign_to_centroids
    if self.assign_to_centroids:
        # Since we will be assigning the features to centroids here, the
        # input must be raw features, not a single cluster ID
        assert in_features != 1
        self.assigner = KmeansAssigner(assign_to_centroids)
        assert self.assigner.num_clusters == num_cluster_centers
    if in_features == 1 or assign_to_centroids:
        # Input is (or will be quantized to) cluster IDs; embed them
        self.encoder = nn.Embedding(num_cluster_centers, inter_dim)
    else:
        self.encoder = nn.Linear(in_features, inter_dim, bias=False)
    self.decoder = nn.Linear(inter_dim, in_features, bias=False)
    # If the encoder is an embedding, tie the decoder weight to it, so the
    # decoder scores each cluster center
    if isinstance(self.encoder, nn.Embedding):
        self.decoder.weight = self.encoder.weight
    if freeze_encoder_decoder:
        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
    # This already applies the LayerNorm inside the residual branches, as
    # Naman suggested.
    self.gpt_model = transformers.GPT2Model(
        transformers.GPT2Config(n_embd=inter_dim,
                                vocab_size=in_features,
                                use_cache=True,
                                **kwargs))
    # The token embedding table is not needed; the encoder takes care of
    # mapping inputs into the model dimension.
    del self.gpt_model.wte
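    # (Presumably the forward pass feeds self.encoder outputs into the
    # GPT-2 trunk via `inputs_embeds`, which is why `wte` can be dropped;
    # this is an inference, since the forward method is not in this excerpt.)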
    self.output_len = output_len
    self.output_len_eval = output_len_eval
    self.avg_last_n = avg_last_n
    self.inter_dim = inter_dim
    self.in_features = in_features
    if future_pred_loss is not None:
        self.future_pred_loss = hydra.utils.instantiate(future_pred_loss,
                                                        reduction='none')
    else:
        self.future_pred_loss = None
    self.return_past_too = return_past_too
    self.drop_last_n = drop_last_n
    # Set this to quantize each prediction (using the top-1 cluster) and
    # re-encode it before the next rollout step, as opposed to feeding
    # back the soft predicted feature
    self.quantize_before_rollout = quantize_before_rollout
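
# Minimal usage sketch (hypothetical: the class name `FuturePredictor`, the
# 512-dim feature size, and the GPT-2 sizes below are illustrative, not taken
# from this repo; n_layer/n_head flow through **kwargs into GPT2Config):
#   model = FuturePredictor(in_features=512, output_len=4, inter_dim=768,
#                           n_layer=6, n_head=8)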