in janus/models/vq_model.py [0:0]
def __init__(self, config: ModelArgs):
    """Build the VQ autoencoder's submodules from a ``ModelArgs`` config.

    Registers, in this order: ``encoder``, ``decoder``, ``quantize`` and the
    two 1x1 convolutions ``quant_conv`` / ``post_quant_conv``.

    Args:
        config: Model hyper-parameters. Fields read here are
            ``encoder_ch_mult``, ``decoder_ch_mult``, ``z_channels``,
            ``dropout_p``, ``codebook_size``, ``codebook_embed_dim``,
            ``commit_loss_beta``, ``entropy_loss_ratio``,
            ``codebook_l2_norm`` and ``codebook_show_usage``. The object
            itself is stored as ``self.config``.
    """
    super().__init__()
    self.config = config

    latent_ch = config.z_channels
    embed_dim = config.codebook_embed_dim

    # Encoder and decoder take the same latent width and dropout; only their
    # channel multipliers differ.
    shared_kwargs = {"z_channels": latent_ch, "dropout": config.dropout_p}

    # NOTE: keep this construction order. It fixes submodule registration
    # order, and constructors presumably draw initial weights from the global
    # RNG, so reordering would likely change initialisation.
    self.encoder = Encoder(ch_mult=config.encoder_ch_mult, **shared_kwargs)
    self.decoder = Decoder(ch_mult=config.decoder_ch_mult, **shared_kwargs)

    # Passed positionally, so this tuple's order must match
    # VectorQuantizer's parameter order (not visible here -- keep in sync).
    quantizer_args = (
        config.codebook_size,
        embed_dim,
        config.commit_loss_beta,
        config.entropy_loss_ratio,
        config.codebook_l2_norm,
        config.codebook_show_usage,
    )
    self.quantize = VectorQuantizer(*quantizer_args)

    # 1x1 projections: z_channels -> codebook_embed_dim, and back again.
    # Presumably applied immediately before / after quantisation; the
    # forward pass is not shown here -- confirm.
    self.quant_conv = nn.Conv2d(latent_ch, embed_dim, kernel_size=1)
    self.post_quant_conv = nn.Conv2d(embed_dim, latent_ch, kernel_size=1)