in torchbenchmark/models/dcgan/__init__.py [0:0]
def __init__(self, test, device, jit=False, batch_size=None, extra_args=None):
    """Build the DCGAN generator/discriminator pair used by this benchmark.

    The discriminator is exposed as ``self.model``; the generator is kept on
    ``self.netG``. Adam optimizers for both networks are created only when
    ``test == "train"``.

    Args:
        test: Benchmark mode; this method special-cases ``"train"`` and
            ``"eval"``.
        device: Forwarded to the base-class constructor. Note that the
            networks themselves are placed on ``self.dcgan.device``.
        jit: Forwarded to the base-class constructor.
        batch_size: Forwarded to the base-class constructor; the surrogate
            inputs below are sized by ``self.batch_size``.
        extra_args: Forwarded to the base-class constructor. ``None`` (the
            default) is replaced by a fresh empty list so that calls never
            share one mutable default.
    """
    if extra_args is None:
        extra_args = []
    super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
    # Hard-coded off; flip to True to print both network architectures.
    self.debug_print = False

    self.root = str(Path(__file__).parent)
    self.dcgan = DCGAN(self)
    dcgan = self.dcgan

    # Settings are read from self.dcgan; note this rebinds the `device`
    # argument to dcgan.device for everything below.
    device = dcgan.device
    ngpu = dcgan.ngpu
    nz = dcgan.nz
    lr = dcgan.lr
    beta1 = dcgan.beta1

    # Wrap both networks in DataParallel only for the exact device string
    # 'cuda' with more than one GPU.
    # NOTE(review): a device such as 'cuda:0' will not match — confirm intended.
    use_data_parallel = (device == 'cuda') and (ngpu > 1)

    # --- Generator ---
    self.netG = Generator(dcgan).to(device)
    if use_data_parallel:
        self.netG = nn.DataParallel(self.netG, list(range(ngpu)))
    # nn.Module.apply calls weights_init on every submodule recursively.
    # NOTE(review): the original comment claimed mean=0, stdev=0.2; the
    # reference DCGAN recipe uses stdev=0.02 — check weights_init itself.
    self.netG.apply(weights_init)
    if self.debug_print:
        print(self.netG)

    # --- Discriminator ---
    netD = Discriminator(dcgan).to(device)
    if use_data_parallel:
        # Wrap the local module: there is no `self.netD` attribute, so the
        # previous `nn.DataParallel(self.netD, ...)` raised AttributeError.
        netD = nn.DataParallel(netD, list(range(ngpu)))
    netD.apply(weights_init)
    if self.debug_print:
        print(netD)

    # Binary cross-entropy loss for the real/fake classification.
    self.criterion = nn.BCELoss()
    # Fixed batch of 64 latent vectors (nz x 1 x 1 each), used to visualize
    # the generator's progression.
    self.fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    # Label convention for real and fake samples during training.
    self.real_label = 1.
    self.fake_label = 0.
    # Random tensor standing in for a batch of 3x64x64 photos. The misspelled
    # attribute name is kept because code outside this method presumably
    # reads it.
    self.exmaple_inputs = torch.randn(self.batch_size, 3, 64, 64, device=self.device)
    # The discriminator is the module this benchmark exposes.
    self.model = netD

    if test == "train":
        # Adam optimizers for both D and G.
        self.optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, 0.999))
    elif test == "eval":
        # Inference only runs the discriminator, so that is what we measure.
        self.inference_just_descriminator = True
        # Currently unreachable because the flag is hard-coded True just
        # above; kept so generator-driven eval can be re-enabled by flipping it.
        if not self.inference_just_descriminator:
            self.eval_noise = torch.randn(self.batch_size, nz, 1, 1, device=self.device)