in modules/SwissArmyTransformer/sat/tokenization/cogview/vqvae/api.py [0:0]
def test_decode_default(device=0, testcase_dir="/dataset/fd5061f6/cogview/zwd/vqgan/testcase/"):
    """Run ``test_decode`` with the three default decoders on every test case.

    Per the original author's note, the three decoders compared are:
    original / l1+ms-ssim / l1+ms-ssim+perceptual. The first uses the
    ``.vqvae_zc.Decoder`` target; the other two use
    ``vqvae.vqvae_diffusion.Decoder`` with identical architecture params and
    differ only in checkpoint.

    Args:
        device: Device identifier stored in every decoder config and also
            passed to ``test_decode``.
        testcase_dir: Directory whose entries are each joined onto it and
            passed to ``test_decode``. Defaults to the previously hard-coded
            location, so existing callers are unaffected.
    """

    def _diffusion_decoder_config(ckpt):
        # Shared config for the two diffusion-style decoders. A fresh dict is
        # built on every call, so the configs never share mutable state
        # (e.g. the ``ch_mult`` list).
        return {
            "target": "vqvae.vqvae_diffusion.Decoder",
            "params": {
                "double_z": False,
                "z_channels": 256,
                "resolution": 256,
                "in_channels": 3,
                "out_ch": 3,
                "ch": 128,
                "ch_mult": [1, 1, 2, 4],  # num_down = len(ch_mult)-1
                "num_res_blocks": 2,
                "attn_resolutions": [16],
                "dropout": 0.0,
            },
            "ckpt": ckpt,
            "ckpt_prefix": "dec",
            "device": device,
        }

    # Testing 3 decoders: original / l1+ms-ssim / l1+ms-ssim+perceptual.
    configs = [
        {
            "target": ".vqvae_zc.Decoder",
            "params": {
                "in_channel": 256,
                "out_channel": 3,
                "channel": 512,
                "n_res_block": 0,
                "n_res_channel": 32,
                "stride": 4,
                "simple": True,
            },
            "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt",
            "ckpt_prefix": "module.dec",
            "device": device,
        },
        _diffusion_decoder_config(
            "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim/checkpoints/last.ckpt"
        ),
        _diffusion_decoder_config(
            "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt"
        ),
    ]

    # os.listdir returns entries in arbitrary order; sort so that runs
    # process test cases reproducibly.
    for name in sorted(os.listdir(testcase_dir)):
        test_decode(configs, os.path.join(testcase_dir, name), device)