def __init__()

in qlearn/atari/prior_bootstrapped_agent.py [0:0]


    def __init__(self, args, input_dim, num_actions):
        """Build the online, prior and target networks and the optimiser.

        Args:
            args: Configuration namespace. This constructor reads
                ``batch_size``, ``discount``, ``double_q``, ``nheads``,
                ``beta``, ``lr`` and ``cuda`` from it, and forwards it to
                every ``AtariBootstrappedDQN`` it creates.
            input_dim: Stored on the agent and forwarded to every
                ``AtariBootstrappedDQN``.
            num_actions: Number of discrete actions; stored on the agent and
                forwarded to every ``AtariBootstrappedDQN``.
        """
        self.num_actions = num_actions
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.double_q = args.double_q
        self.input_dim = input_dim
        self.nheads = args.nheads
        # NOTE(review): presumably the scale applied to the prior network's
        # output elsewhere in this class — confirm.
        self.beta = args.beta

        # The only network that is trained.
        self.online_net = AtariBootstrappedDQN(args, input_dim, num_actions)
        self.online_net.train()

        # Separately initialised network that is never optimised (frozen below).
        self.prior = AtariBootstrappedDQN(args, input_dim, num_actions)

        self.target_net = AtariBootstrappedDQN(args, input_dim, num_actions)
        # update_target_net() is defined elsewhere in this class; presumably it
        # copies the online weights into target_net — confirm.
        self.update_target_net()

        # Neither the prior nor the target network is optimised: switch them to
        # eval mode and stop autograd from tracking their parameters.
        for frozen_net in (self.prior, self.target_net):
            frozen_net.eval()
            for param in frozen_net.parameters():
                param.requires_grad = False

        # Move the networks to the GPU *before* building the optimiser, as the
        # torch.optim docs recommend, so the optimiser references the
        # parameters in their final location.
        if args.cuda:
            self.online_net.cuda()
            self.target_net.cuda()
            self.prior.cuda()

        # Only the online network's parameters are updated by the optimiser.
        self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr)

        # Tensor-type aliases for the selected device.
        self.FloatTensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if args.cuda else torch.LongTensor
        self.ByteTensor = torch.cuda.ByteTensor if args.cuda else torch.ByteTensor