in ppuda/deepnets1m/genotypes.py [0:0]
def sample_genotype(steps=1, only_pool=False, allow_none=True, drop_concat=True, allow_transformer=False):
    """Randomly sample a DARTS-style genotype (a normal and a reduction cell).

    Small Gaussian "architecture weights" (alphas) are drawn for every
    (edge, operation) pair of each cell. Each cell is then decoded the way
    DARTS decodes a trained search cell:
    - for every intermediate node, the two incoming edges whose best allowed
      operation has the largest weight are kept;
    - on each kept edge, the best allowed operation is kept.
    Because the alphas are random noise, this amounts to drawing operations
    and connections at random.

    All randomness comes from torch's global RNG (``torch.randn`` /
    ``torch.rand``). The order of RNG calls is preserved so that a given
    ``torch.manual_seed`` keeps producing the same genotype.

    Args:
        steps: number of intermediate nodes per cell. Intermediate node ``i``
            can take input from any of the ``i + 2`` preceding nodes.
        only_pool: if True, the reduction cell is restricted to the first four
            primitives ('none', 'max_pool_3x3', 'avg_pool_3x3',
            'skip_connect'), i.e. operators without learnable parameters.
        allow_none: if False, the 'none' operation is never selected and never
            counts when ranking edges.
        drop_concat: if True, each intermediate node except the last is kept
            in the cell output with probability 0.5, and the last is always
            kept. If False, all intermediate nodes are concatenated.
        allow_transformer: if False, the 'msa' operation is never selected in
            either cell.

    Returns:
        ``Genotype(normal, normal_concat, reduce, reduce_concat)``.
        - ``normal`` / ``reduce`` are lists of ``(op_name, input_node)`` pairs,
          two per intermediate node.
        - Both cells share the same concat indices: a list when
          ``drop_concat`` is True, otherwise a ``range``.
    """
    # Extended set of primitives based on https://github.com/quark0/darts/blob/master/cnn/genotypes.py
    PRIMITIVES_DARTS_EXT = [
        'none',
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'conv_1x1',
        'conv_7x1_1x7',
        'conv_3x3',
        'conv_5x5',
        'conv_7x7',
        'msa',
        'cse',
    ]
    multiplier = steps
    num_ops = len(PRIMITIVES_DARTS_EXT)
    # One row of alphas per edge: intermediate node i has (2 + i) incoming edges.
    num_edges = sum(2 + i for i in range(steps))

    # Plain tensors replace the deprecated torch.autograd.Variable wrapper.
    # Keep these two draws first and in this order so that seeded runs are unchanged.
    alphas_normal = 1e-3 * torch.randn(num_edges, num_ops)
    alphas_reduce = 1e-3 * torch.randn(num_edges, num_ops)

    if only_pool:
        assert PRIMITIVES_DARTS_EXT[3] == 'skip_connect', PRIMITIVES_DARTS_EXT
        assert PRIMITIVES_DARTS_EXT[4] == 'sep_conv_3x3', PRIMITIVES_DARTS_EXT
        # Prevent sampling operators with learnable params, so that reduction
        # cells resemble the best DARTS cell. Under softmax, -1000 underflows
        # to 0, so these ops can never win the arg-max below.
        alphas_reduce[:, 4:] = -1000
    if not allow_transformer:
        ind = PRIMITIVES_DARTS_EXT.index('msa')
        assert ind == len(PRIMITIVES_DARTS_EXT) - 2, (ind, PRIMITIVES_DARTS_EXT)
        alphas_normal[:, ind] = -1000
        alphas_reduce[:, ind] = -1000

    # Operation indices eligible for selection, computed once instead of
    # re-looking-up the index of 'none' inside every inner loop.
    none_idx = PRIMITIVES_DARTS_EXT.index('none')
    allowed_ops = [op for op in range(num_ops) if allow_none or op != none_idx]

    def _parse(weights):
        """Decode a (num_edges, num_ops) weight matrix into (op_name, input_node) pairs."""
        # Based on https://github.com/quark0/darts/blob/master/cnn/model_search.py#L135
        gene = []
        start = 0
        for i in range(steps):
            end = start + i + 2  # rows of the i + 2 edges entering node i
            W = weights[start:end]
            # An edge's strength is the weight of its best allowed operation.
            # sorted() is stable, so ties keep the lower input index first.
            edges = sorted(range(i + 2),
                           key=lambda x: -max(W[x][op] for op in allowed_ops))[:2]
            for j in edges:
                # max() returns the first maximal element, matching DARTS tie-breaking.
                best_op = max(allowed_ops, key=lambda op: W[j][op])
                gene.append((PRIMITIVES_DARTS_EXT[best_op], j))
            start = end
        return gene

    gene_normal = _parse(F.softmax(alphas_normal, dim=-1).numpy())
    gene_reduce = _parse(F.softmax(alphas_reduce, dim=-1).numpy())

    first_node = 2 + steps - multiplier
    last_node = steps + 1
    if drop_concat:
        # Always add the last node, otherwise the features from the previous
        # sum nodes would be lost. The short-circuit `or` means no random number
        # is drawn for the last node; the RNG call count stays as before.
        concat = [i for i in range(first_node, last_node + 1)
                  if i == last_node or torch.rand(1).item() > 0.5]
    else:
        concat = range(first_node, last_node + 1)

    genotype = Genotype(
        normal=gene_normal, normal_concat=concat,
        reduce=gene_reduce, reduce_concat=concat
    )
    return genotype