def test_decode_default()

in modules/SwissArmyTransformer/sat/tokenization/cogview/vqvae/api.py [0:0]


def test_decode_default(device=0):
    """Run ``test_decode`` over every test case with the default decoder set.

    Three decoder configurations are compared, in this order:
    original / l1+ms-ssim / l1+ms-ssim+perceptual.  Each configuration names
    a decoder target, its constructor params, a checkpoint file and the
    prefix recorded alongside that checkpoint.

    Args:
        device: Device identifier stored in every decoder config and also
            forwarded to ``test_decode``.
    """

    def diffusion_decoder(checkpoint):
        # The two diffusion-style decoders differ only in their checkpoint.
        # A new dict (and new lists) is built on every call so the resulting
        # configs never share mutable state with each other.
        return {
            "target": "vqvae.vqvae_diffusion.Decoder",
            "params": {
                "double_z": False,
                "z_channels": 256,
                "resolution": 256,
                "in_channels": 3,
                "out_ch": 3,
                "ch": 128,
                # Original author's note: num_down = len(ch_mult)-1
                "ch_mult": [1, 1, 2, 4],
                "num_res_blocks": 2,
                "attn_resolutions": [16],
                "dropout": 0.0,
            },
            "ckpt": checkpoint,
            "ckpt_prefix": "dec",
            "device": device,
        }

    # Decoder #1: the original decoder.
    original_decoder = {
        "target": ".vqvae_zc.Decoder",
        "params": {
            "in_channel": 256,
            "out_channel": 3,
            "channel": 512,
            "n_res_block": 0,
            "n_res_channel": 32,
            "stride": 4,
            "simple": True,
        },
        "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt",
        "ckpt_prefix": "module.dec",
        "device": device,
    }

    configs = [
        original_decoder,
        # Decoder #2: l1+ms-ssim checkpoint.
        diffusion_decoder(
            "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim/checkpoints/last.ckpt"
        ),
        # Decoder #3: l1+ms-ssim+perceptual checkpoint.
        diffusion_decoder(
            "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt"
        ),
    ]

    testcase_dir = "/dataset/fd5061f6/cogview/zwd/vqgan/testcase/"
    # Entries are visited in whatever order os.listdir yields them.
    for entry_name in os.listdir(testcase_dir):
        testcase_path = os.path.join(testcase_dir, entry_name)
        test_decode(configs, testcase_path, device)