def sample_genotype()

in ppuda/deepnets1m/genotypes.py [0:0]


def sample_genotype(steps=1, only_pool=False, allow_none=True, drop_concat=True, allow_transformer=False):
    """Sample a random DARTS-style cell genotype over an extended primitive set.

    Small random architecture weights are drawn for the normal and the
    reduction cell, some primitives are optionally masked out, and the weights
    are decoded in the DARTS manner: every intermediate node keeps its two
    highest-scoring incoming edges and, on each of them, the highest-scoring
    primitive.

    Args:
        steps: Number of intermediate nodes in each cell.
        only_pool: If True, every primitive from index 4 ('sep_conv_3x3')
            onwards is masked out in the reduction cell only; the normal cell
            is not affected.
        allow_none: If False, the 'none' primitive is ignored both when
            ranking edges and when choosing the primitive of an edge.
        drop_concat: If True, each intermediate node except the last one is
            kept in the concat list with probability 0.5 (the last node is
            always kept). If False, all intermediate nodes are concatenated.
        allow_transformer: If False, the 'msa' primitive is masked out in
            both cells ('cse' is left sampleable).

    Returns:
        A ``Genotype`` with fields ``normal``, ``normal_concat``, ``reduce``
        and ``reduce_concat``. ``normal`` / ``reduce`` are lists of
        ``(primitive_name, input_node_index)`` tuples, two per intermediate
        node. Both concat fields refer to the same object: a ``list`` when
        ``drop_concat`` is True, otherwise a ``range``.
    """

    # Extended set of primitives based on https://github.com/quark0/darts/blob/master/cnn/genotypes.py
    PRIMITIVES_DARTS_EXT = [
        'none',
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'conv_1x1',
        'conv_7x1_1x7',
        'conv_3x3',
        'conv_5x5',
        'conv_7x7',
        'msa',
        'cse'
    ]

    multiplier = steps  # every intermediate node is a candidate for the output concat
    num_ops = len(PRIMITIVES_DARTS_EXT)
    # One row of weights per edge; intermediate node i has 2 + i incoming edges.
    num_edges = sum(2 + i for i in range(steps))

    # Plain tensors are used instead of the deprecated autograd.Variable wrapper.
    # The draw order (normal first, then reduce) is kept so that seeded runs
    # produce the same genotypes as before.
    alphas_normal = 1e-3 * torch.randn(num_edges, num_ops)
    alphas_reduce = 1e-3 * torch.randn(num_edges, num_ops)

    if only_pool:
        assert PRIMITIVES_DARTS_EXT[3] == 'skip_connect', PRIMITIVES_DARTS_EXT
        assert PRIMITIVES_DARTS_EXT[4] == 'sep_conv_3x3', PRIMITIVES_DARTS_EXT
        alphas_reduce[:, 4:] = -1000  # prevent sampling operators with learnable params to sample the architectures similar to the best DARTS cell

    if not allow_transformer:
        ind = PRIMITIVES_DARTS_EXT.index('msa')
        assert ind == len(PRIMITIVES_DARTS_EXT) - 2, (ind, PRIMITIVES_DARTS_EXT)
        alphas_normal[:, ind] = -1000
        alphas_reduce[:, ind] = -1000

    # Primitive indices that may be selected; computed once instead of
    # re-evaluating PRIMITIVES_DARTS_EXT.index('none') for every (edge, op) pair.
    none_ind = PRIMITIVES_DARTS_EXT.index('none')
    candidate_ops = [op for op in range(num_ops) if allow_none or op != none_ind]

    def _parse(weights):
        # Based on https://github.com/quark0/darts/blob/master/cnn/model_search.py#L135
        gene = []
        start = 0
        for i in range(steps):
            n_inputs = i + 2
            # Rows: incoming edges of node i; columns: selectable primitives only.
            W = weights[start:start + n_inputs][:, candidate_ops]
            strongest = W.max(axis=1)
            # The sort is stable, so ties keep the lower input index first.
            edges = sorted(range(n_inputs), key=lambda x: -strongest[x])[:2]
            for j in edges:
                # argmax returns the first maximum, i.e. the lowest primitive index on ties.
                best_op = candidate_ops[W[j].argmax()]
                gene.append((PRIMITIVES_DARTS_EXT[best_op], j))
            start += n_inputs
        return gene

    gene_normal = _parse(F.softmax(alphas_normal, dim=-1).numpy())
    gene_reduce = _parse(F.softmax(alphas_reduce, dim=-1).numpy())

    concat_candidates = range(2 + steps - multiplier, steps + 2)
    if drop_concat:
        last_node = steps + 1
        # always add the last otherwise the features from the previous sum nodes will be lost
        # (the `or` short-circuits, so no random number is drawn for the last node)
        concat = [i for i in concat_candidates
                  if i == last_node or torch.rand(1).item() > 0.5]
    else:
        concat = concat_candidates

    genotype = Genotype(
        normal=gene_normal, normal_concat=concat,
        reduce=gene_reduce, reduce_concat=concat
    )

    return genotype