in qlearn/toys/bootstrapped_agent.py [0:0]
def __init__(self, args, env):
    """Build the online/target networks, the optimiser and tensor aliases.

    Args:
        args: Configuration namespace. This constructor reads
            ``batch_size``, ``discount``, ``nheads``, ``double_q``,
            ``model``, ``lr``, ``adam_eps`` and ``cuda``. The whole object
            is also passed unchanged to ``BoostrappedDQN``.
        env: Environment object exposing ``action_space.n``.
    """
    self.action_space = env.action_space.n
    self.batch_size = args.batch_size
    self.discount = args.discount
    self.nheads = args.nheads
    self.double_q = args.double_q

    # Online network: optionally restored from ``args.model`` when that is
    # set and names an existing file.
    self.online_net = BoostrappedDQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
        # Remap every storage to CPU while deserialising so a checkpoint
        # written on a GPU machine still loads on a CPU-only one; the
        # ``.cuda()`` call below moves the weights afterwards if requested.
        # NOTE(review): torch.load unpickles the file -- only point
        # ``args.model`` at checkpoints from a trusted source.
        state_dict = torch.load(
            args.model, map_location=lambda storage, loc: storage
        )
        self.online_net.load_state_dict(state_dict)
    self.online_net.train()

    # Target network: kept in eval mode with gradients disabled.
    self.target_net = BoostrappedDQN(args, self.action_space)
    # update_target_net is defined elsewhere in the class; presumably it
    # copies the online weights into the target net -- confirm there.
    self.update_target_net()
    self.target_net.eval()
    for param in self.target_net.parameters():
        param.requires_grad = False

    # Move the networks to the GPU *before* building the optimiser, as the
    # torch.optim documentation recommends.
    if args.cuda:
        self.online_net.cuda()
        self.target_net.cuda()
    self.optimiser = optim.Adam(
        self.online_net.parameters(), lr=args.lr, eps=args.adam_eps
    )

    # Device-appropriate tensor constructors stored on the instance.
    tensor_types = torch.cuda if args.cuda else torch
    self.FloatTensor = tensor_types.FloatTensor
    self.LongTensor = tensor_types.LongTensor
    self.ByteTensor = tensor_types.ByteTensor