def __init__()

in torchbenchmark/models/dcgan/__init__.py [0:0]


    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        """Build the DCGAN generator/discriminator pair for benchmarking.

        Args:
            test: Benchmark mode. ``"train"`` creates Adam optimizers for both
                networks; ``"eval"`` sets up discriminator-only inference.
            device: Forwarded to the base class. The networks themselves are
                placed on ``self.dcgan.device`` (this name is rebound below).
            jit: Forwarded to the base class.
            batch_size: Forwarded to the base class; ``self.batch_size``
                (presumably set by the base class) sizes the surrogate inputs.
            extra_args: Only forwarded to the base class here.
        """
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.debug_print = False

        self.root = str(Path(__file__).parent)
        self.dcgan = DCGAN(self)

        dcgan = self.dcgan

        # Hyper-parameters come from the DCGAN config object. Note that this
        # rebinds the ``device`` argument to the config's device.
        device = dcgan.device
        ngpu = dcgan.ngpu
        nz = dcgan.nz
        lr = dcgan.lr
        beta1 = dcgan.beta1

        # Both networks are wrapped in DataParallel only for multi-GPU CUDA runs.
        use_data_parallel = (dcgan.device == 'cuda') and (ngpu > 1)
        gpu_ids = list(range(ngpu))

        # Create the generator
        self.netG = Generator(dcgan).to(device)
        if use_data_parallel:
            self.netG = nn.DataParallel(self.netG, gpu_ids)

        # Apply weights_init to randomly initialize all weights (the DCGAN
        # reference scheme is mean=0, std=0.02 -- TODO confirm in weights_init).
        self.netG.apply(weights_init)

        if self.debug_print:
            # Print the model
            print(self.netG)

        # Create the Discriminator
        netD = Discriminator(dcgan).to(device)
        if use_data_parallel:
            # Wrap the local ``netD``; there is no ``self.netD`` attribute.
            netD = nn.DataParallel(netD, gpu_ids)

        # Same random weight initialization as the generator.
        netD.apply(weights_init)

        if self.debug_print:
            # Print the model
            print(netD)

        # Initialize BCELoss function
        self.criterion = nn.BCELoss()

        # Fixed batch of 64 latent vectors, shape (64, nz, 1, 1), used to
        # visualize the progression of the generator.
        self.fixed_noise = torch.randn(64, nz, 1, 1, device=device)

        # Establish convention for real and fake labels during training
        self.real_label = 1.
        self.fake_label = 0.

        # Random values of shape (batch_size, 3, 64, 64) as a surrogate for a
        # batch of photos. NOTE: the attribute name is misspelled ("exmaple");
        # it is kept as-is because other methods presumably read it.
        self.exmaple_inputs = torch.randn(self.batch_size, 3, 64, 64, device=self.device)
        self.model = netD
        if test == "train":
            # Setup Adam optimizers for both G and D
            self.optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
            self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, 0.999))
        elif test == "eval":
            # Inference would just run the discriminator, so that is what we do too.
            # (Attribute name spelling kept for compatibility.)
            self.inference_just_descriminator = True
            # Toggle: set the flag above to False to also pre-generate latent
            # noise for running the generator during eval.
            if not self.inference_just_descriminator:
                self.eval_noise = torch.randn(self.batch_size, nz, 1, 1, device=self.device)