in janus/models/vq_model.py [0:0]
def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
    """Create the codebook embedding and store the quantizer settings.

    Args:
        n_e: Number of codebook entries (rows of ``self.embedding``).
        e_dim: Dimensionality of each codebook entry (columns of
            ``self.embedding``).
        beta: Stored as ``self.beta``; presumably the commitment-loss weight
            consumed by ``forward`` (not visible here -- TODO confirm).
        entropy_loss_ratio: Stored as ``self.entropy_loss_ratio``; presumably
            weights an entropy regulariser in ``forward`` (TODO confirm).
        l2_norm: If truthy, every codebook row is L2-normalised after the
            uniform initialisation.
        show_usage: If truthy, a float32 buffer ``codebook_used`` of 65536
            zeros is registered (saved in ``state_dict``).
    """
    super().__init__()
    self.n_e = n_e
    self.e_dim = e_dim
    self.beta = beta
    self.entropy_loss_ratio = entropy_loss_ratio
    self.l2_norm = l2_norm
    self.show_usage = show_usage

    self.embedding = nn.Embedding(self.n_e, self.e_dim)
    # Initialise in place under no_grad instead of poking at ``.data``:
    # same values, but autograd is explicitly kept out of the picture.
    with torch.no_grad():
        # Small symmetric init in [-1/n_e, 1/n_e].
        self.embedding.weight.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            # Project every codebook row onto the unit L2 sphere.
            self.embedding.weight.copy_(
                F.normalize(self.embedding.weight, p=2, dim=-1)
            )
    if self.show_usage:
        # Register a plain tensor, NOT an nn.Parameter: a Parameter here
        # would carry requires_grad=True, so in-place updates to the buffer
        # would raise (or, after a device move, build an autograd graph).
        # NOTE(review): 65536 is a fixed size independent of n_e; presumably
        # a rolling record of recently used code indices -- confirm in forward.
        self.register_buffer("codebook_used", torch.zeros(65536))