torchbenchmark/models/maml/learner.py (127 lines of code) (raw):
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from typing import List
class Learner(nn.Module):
"""
"""
def __init__(self, config, imgc, imgsz):
"""
:param config: network config file, type:list of (string, list)
:param imgc: 1 or 3
:param imgsz: 28 or 84
"""
super(Learner, self).__init__()
self.config = config
# this dict contains all tensors needed to be optimized
self.vars = nn.ParameterList()
# running_mean and running_var
self.vars_bn = nn.ParameterList()
for i, (name, param) in enumerate(self.config):
if name == 'conv2d':
# [ch_out, ch_in, kernelsz, kernelsz]
w = nn.Parameter(torch.ones(*param[:4]))
# gain=1 according to cbfin's implementation
torch.nn.init.kaiming_normal_(w)
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
elif name == 'convt2d':
# [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
w = nn.Parameter(torch.ones(*param[:4]))
# gain=1 according to cbfin's implementation
torch.nn.init.kaiming_normal_(w)
self.vars.append(w)
# [ch_in, ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[1])))
elif name == 'linear':
# [ch_out, ch_in]
w = nn.Parameter(torch.ones(*param))
# gain=1 according to cbfinn's implementation
torch.nn.init.kaiming_normal_(w)
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
elif name == 'bn':
# [ch_out]
w = nn.Parameter(torch.ones(param[0]))
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
# must set requires_grad=False
running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
self.vars_bn.extend([running_mean, running_var])
elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
continue
else:
raise NotImplementedError
def extra_repr(self):
info = ''
for name, param in self.config:
if name == 'conv2d':
tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\
%(param[1], param[0], param[2], param[3], param[4], param[5],)
info += tmp + '\n'
elif name == 'convt2d':
tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\
%(param[0], param[1], param[2], param[3], param[4], param[5],)
info += tmp + '\n'
elif name == 'linear':
tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0])
info += tmp + '\n'
elif name == 'leakyrelu':
tmp = 'leakyrelu:(slope:%f)'%(param[0])
info += tmp + '\n'
elif name == 'avg_pool2d':
tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
info += tmp + '\n'
elif name == 'max_pool2d':
tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
info += tmp + '\n'
elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
tmp = name + ':' + str(tuple(param))
info += tmp + '\n'
else:
raise NotImplementedError
return info
def forward(self, x, vars=None, bn_training=True):
"""
This function can be called by finetunning, however, in finetunning, we dont wish to update
running_mean/running_var. Thought weights/bias of bn == updated, it has been separated by fast_weights.
Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
:param x: [b, 1, 28, 28]
:param vars:
:param bn_training: set False to not update
:return: x, loss, likelihood, kld
"""
if vars == None:
vars = self.vars
idx = 0
bn_idx = 0
for name, param in self.config:
if name == 'conv2d':
w, b = vars[idx], vars[idx + 1]
# remember to keep synchrozied of forward_encoder and forward_decoder!
x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
idx += 2
# print(name, param, '\tout:', x.shape)
elif name == 'convt2d':
w, b = vars[idx], vars[idx + 1]
# remember to keep synchrozied of forward_encoder and forward_decoder!
x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
idx += 2
# print(name, param, '\tout:', x.shape)
elif name == 'linear':
w, b = vars[idx], vars[idx + 1]
x = F.linear(x, w, b)
idx += 2
# print('forward:', idx, x.norm().item())
elif name == 'bn':
w, b = vars[idx], vars[idx + 1]
running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
idx += 2
bn_idx += 2
elif name == 'flatten':
# print(x.shape)
x = x.view(x.size(0), -1)
elif name == 'reshape':
# [b, 8] => [b, 2, 2, 2]
x = x.view(x.size(0), *param)
elif name == 'relu':
x = F.relu(x, inplace=param[0])
elif name == 'leakyrelu':
x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
elif name == 'tanh':
x = F.tanh(x)
elif name == 'sigmoid':
x = torch.sigmoid(x)
elif name == 'upsample':
x = F.upsample_nearest(x, scale_factor=param[0])
elif name == 'max_pool2d':
x = F.max_pool2d(x, param[0], param[1], param[2])
elif name == 'avg_pool2d':
x = F.avg_pool2d(x, param[0], param[1], param[2])
else:
raise NotImplementedError
# make sure variable == used properly
assert idx == len(vars)
assert bn_idx == len(self.vars_bn)
return x
def zero_grad(self, vars=None):
"""
:param vars:
:return:
"""
with torch.no_grad():
if vars == None:
for p in self.vars:
if not p.grad == None:
p.grad.zero_()
else:
for p in vars:
if not p.grad == None:
p.grad.zero_()
def parameters(self):
"""
override this function since initial parameters will return with a generator.
:return:
"""
return self.vars