in nni/algorithms/compression/pytorch/pruning/amc/amc_pruner.py [0:0]
def __init__(
        self,
        model,
        config_list,
        evaluator,
        val_loader,
        suffix=None,
        model_type='mobilenet',
        dataset='cifar10',
        flops_ratio=0.5,
        lbound=0.2,
        rbound=1.,
        reward='acc_reward',
        n_calibration_batches=60,
        n_points_per_layer=10,
        channel_round=8,
        hidden1=300,
        hidden2=300,
        lr_c=1e-3,
        lr_a=1e-4,
        warmup=100,
        discount=1.,
        bsize=64,
        rmsize=100,
        window_length=1,
        tau=0.01,
        init_delta=0.5,
        delta_decay=0.99,
        max_episode_length=1e9,
        output_dir='./logs',
        debug=False,
        train_episode=800,
        epsilon=50000,
        seed=None):
    self.val_loader = val_loader
    self.evaluator = evaluator

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    # keep a copy of the original weights so the pruning environment can
    # restore the model between search episodes
    checkpoint = deepcopy(model.state_dict())

    super().__init__(model, config_list, optimizer=None)

    # build folder and logs
    base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
    if suffix is not None:
        self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
    else:
        self.output_dir = get_output_folder(output_dir, base_folder_name)
    # arguments for the layer-wise channel pruning environment
    self.env_args = Namespace(
        model_type=model_type,
        preserve_ratio=flops_ratio,
        lbound=lbound,
        rbound=rbound,
        reward=reward,
        n_calibration_batches=n_calibration_batches,
        n_points_per_layer=n_points_per_layer,
        channel_round=channel_round,
        output=self.output_dir
    )
    self.env = ChannelPruningEnv(
        self, evaluator, val_loader, checkpoint, args=self.env_args)
    _logger.info('=> Saving logs to %s', self.output_dir)
    self.tfwriter = SummaryWriter(log_dir=self.output_dir)
    self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
    _logger.info('=> Output path: %s...', self.output_dir)

    # the DDPG state is the per-layer embedding; the single action is the
    # preserve ratio chosen for the current layer
    nb_states = self.env.layer_embedding.shape[1]
    nb_actions = 1  # just 1 action here
    rmsize = rmsize * len(self.env.prunable_idx)  # scale buffer size by number of prunable layers
    _logger.info('** Actual replay buffer size: %d', rmsize)
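    # e.g. with the default rmsize=100 and a model that exposes 20 prunable
    # layers (an assumed count), the buffer holds 100 * 20 = 2000 transitions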
    # DDPG agent hyperparameters
    self.ddpg_args = Namespace(
        hidden1=hidden1,
        hidden2=hidden2,
        lr_c=lr_c,
        lr_a=lr_a,
        warmup=warmup,
        discount=discount,
        bsize=bsize,
        rmsize=rmsize,
        window_length=window_length,
        tau=tau,
        init_delta=init_delta,
        delta_decay=delta_decay,
        max_episode_length=max_episode_length,
        debug=debug,
        train_episode=train_episode,
        epsilon=epsilon
    )
    self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
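
A minimal driver sketch, not part of amc_pruner.py, showing how this constructor is typically reached through the public AMCPruner class. The CIFAR-10 loader, the evaluator body, and the config_list contents are assumptions based on common NNI pruner usage (an evaluator called as evaluator(val_loader, model) that returns top-1 accuracy, and compress() to start the search); only the keyword arguments mirror the signature above.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
# assumes the pruning package re-exports AMCPruner from amc/amc_pruner.py
from nni.algorithms.compression.pytorch.pruning import AMCPruner

def evaluator(val_loader, model):
    # assumed signature: return top-1 accuracy on the validation set
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return 100.0 * correct / total

val_loader = DataLoader(
    datasets.CIFAR10('./data', train=False, download=True,
                     transform=transforms.ToTensor()),
    batch_size=64, shuffle=False)

model = models.mobilenet_v2(num_classes=10)
config_list = [{'op_types': ['Conv2d']}]  # assumed: search over all Conv2d layers

pruner = AMCPruner(model, config_list, evaluator, val_loader,
                   model_type='mobilenet', dataset='cifar10',
                   flops_ratio=0.5, train_episode=800)
pruner.compress()  # runs the DDPG-driven channel pruning search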