def __init__()

in nni/algorithms/compression/pytorch/pruning/amc/amc_pruner.py [0:0]


    def __init__(
            self,
            model,
            config_list,
            evaluator,
            val_loader,
            suffix=None,
            model_type='mobilenet',
            dataset='cifar10',
            flops_ratio=0.5,
            lbound=0.2,
            rbound=1.,
            reward='acc_reward',
            n_calibration_batches=60,
            n_points_per_layer=10,
            channel_round=8,
            hidden1=300,
            hidden2=300,
            lr_c=1e-3,
            lr_a=1e-4,
            warmup=100,
            discount=1.,
            bsize=64,
            rmsize=100,
            window_length=1,
            tau=0.01,
            init_delta=0.5,
            delta_decay=0.99,
            max_episode_length=1e9,
            output_dir='./logs',
            debug=False,
            train_episode=800,
            epsilon=50000,
            seed=None):
        """Set up the AMC search: log folder, channel-pruning environment and DDPG agent.

        Parameters
        ----------
        model
            Model to prune. Its ``state_dict()`` is deep-copied *before* the base
            pruner is initialised, and that copy is passed to ``ChannelPruningEnv``
            as the checkpoint.
        config_list
            Pruning configuration, forwarded unchanged to the base class.
        evaluator
            Stored as ``self.evaluator`` and passed to ``ChannelPruningEnv``.
        val_loader
            Stored as ``self.val_loader`` and passed to ``ChannelPruningEnv``.
        suffix : str or None
            If given, logs are written to ``<output_dir>/<base>-<suffix>``.
            Otherwise the folder is chosen by ``get_output_folder(output_dir, base)``.
        model_type, dataset, flops_ratio
            Used to build ``base = '{model_type}_{dataset}_r{flops_ratio}_search'``.
            ``model_type`` and ``flops_ratio`` (as ``preserve_ratio``) are also
            forwarded to the environment. ``dataset`` is only used in the name.
        lbound, rbound, reward, n_calibration_batches, n_points_per_layer, channel_round
            Forwarded to ``ChannelPruningEnv`` through ``self.env_args``.
        hidden1, hidden2, lr_c, lr_a, warmup, discount, bsize, rmsize, window_length,
        tau, init_delta, delta_decay, max_episode_length, debug, train_episode, epsilon
            Forwarded to ``DDPG`` through ``self.ddpg_args``. ``rmsize`` is given per
            prunable layer and is multiplied by ``len(self.env.prunable_idx)``
            before being forwarded.
        output_dir : str
            Parent directory for the search folder. The final folder is created
            if it does not exist.
        seed
            When not None, seeds NumPy, PyTorch and PyTorch CUDA RNGs.

        Notes
        -----
        ``self.text_writer`` is opened on ``<search folder>/log.txt`` in write mode
        and deliberately left open, since it is stored on the instance. It is not
        closed in this method.
        """

        self.val_loader = val_loader
        self.evaluator = evaluator

        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        # Snapshot the weights before the base pruner touches the model.
        # NOTE(review): presumably Pruner.__init__ wraps modules and changes
        # state_dict keys, hence the early copy; confirm against the base class.
        checkpoint = deepcopy(model.state_dict())

        super().__init__(model, config_list, optimizer=None)

        # build folder and logs
        base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
        if suffix is not None:
            self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
        else:
            self.output_dir = get_output_folder(output_dir, base_folder_name)
        # Nothing on the suffix branch creates the folder. Create it explicitly
        # so the env (which gets it as ``output``) and the ``log.txt`` open below
        # do not depend on SummaryWriter having created it as a side effect.
        os.makedirs(self.output_dir, exist_ok=True)

        self.env_args = Namespace(
            model_type=model_type,
            preserve_ratio=flops_ratio,
            lbound=lbound,
            rbound=rbound,
            reward=reward,
            n_calibration_batches=n_calibration_batches,
            n_points_per_layer=n_points_per_layer,
            channel_round=channel_round,
            output=self.output_dir
        )
        self.env = ChannelPruningEnv(
            self, evaluator, val_loader, checkpoint, args=self.env_args)
        _logger.info('=> Saving logs to %s', self.output_dir)
        self.tfwriter = SummaryWriter(log_dir=self.output_dir)
        # Kept open on the instance for later writes; not closed here.
        self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
        _logger.info('=> Output path: %s...', self.output_dir)

        # State size is the width of the environment's per-layer embedding.
        nb_states = self.env.layer_embedding.shape[1]
        nb_actions = 1  # just 1 action here

        # The user-facing rmsize is per prunable layer; scale it to the total.
        rmsize = rmsize * len(self.env.prunable_idx)  # for each layer
        _logger.info('** Actual replay buffer size: %d', rmsize)

        self.ddpg_args = Namespace(
            hidden1=hidden1,
            hidden2=hidden2,
            lr_c=lr_c,
            lr_a=lr_a,
            warmup=warmup,
            discount=discount,
            bsize=bsize,
            rmsize=rmsize,
            window_length=window_length,
            tau=tau,
            init_delta=init_delta,
            delta_decay=delta_decay,
            max_episode_length=max_episode_length,
            debug=debug,
            train_episode=train_episode,
            epsilon=epsilon
        )
        self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)