def init_3d_from_2d_dict()

in network/initializer.py [0:0]


import json
import logging

import numpy as np
import torch


def init_3d_from_2d_dict(net, state_dict, method='inflation'):
    logging.debug("Initializer:: loading from 2D neural network, filling method: `{}' ...".format(method))

    # filling method
    def filling_kernel(src, dshape, method):
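        # src   : 2D conv weight, shape (C_out, C_in, kH, kW)
        # dshape: target 3D conv weight shape (C_out, C_in, kT, kH, kW)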
        assert method in ['inflation', 'random'], \
            "filling method: {} is unknown!".format(method)

        if method == 'inflation':
            dst = torch.FloatTensor(dshape)
            # normalize by the temporal size so the temporal sum of the
            # inflated kernel equals the original 2D kernel
            src = src / float(dshape[2])
            src = src.view(dshape[0], dshape[1], 1, dshape[3], dshape[4])
            # broadcast the normalized 2D kernel to every temporal slice
            dst.copy_(src)
        elif method == 'random':
            dst = torch.FloatTensor(dshape)
            tmp = torch.FloatTensor(src.shape)
            # normalize by the temporal size
            src = src / float(dshape[2])
            # random range derived from the magnitude of the 2D kernel
            scale = src.abs().mean().item()
            # first temporal slice is the normalized 2D kernel itself
            dst[:, :, 0, :, :].copy_(src)
            i = 1
            while i < dshape[2]:
                if i + 2 < dshape[2]:
                    # fill a (+noise, kernel, -noise) triplet; the noise
                    # cancels in the temporal sum
                    torch.nn.init.uniform_(tmp, a=-scale, b=scale)
                    dst[:, :, i, :, :].copy_(tmp)
                    dst[:, :, i + 1, :, :].copy_(src)
                    dst[:, :, i + 2, :, :].copy_(-tmp)
                    i += 3
                elif i + 1 < dshape[2]:
                    # fill a (+noise, -noise) pair
                    torch.nn.init.uniform_(tmp, a=-scale, b=scale)
                    dst[:, :, i, :, :].copy_(tmp)
                    dst[:, :, i + 1, :, :].copy_(-tmp)
                    i += 2
                else:
                    # last remaining slice: copy the kernel
                    dst[:, :, i, :, :].copy_(src)
                    i += 1
            # shuffle each kernel position independently along the temporal axis
            tmp = dst.numpy().swapaxes(2, -1)
            shp = tmp.shape[:-1]
            for ndx in np.ndindex(shp):
                np.random.shuffle(tmp[ndx])
            # swap the axes back before converting to a tensor
            dst = torch.from_numpy(tmp.swapaxes(2, -1))
        else:
            raise NotImplementedError

        return dst


    # customized partial loading of the 2D state dict into the 3D network
    src_state_keys = list(state_dict.keys())
    dst_state_keys = list(net.state_dict().keys())
    for name, param in state_dict.items():
        if name in dst_state_keys:
            src_param_shape = param.shape
            dst_param_shape = net.state_dict()[name].shape
            if src_param_shape != dst_param_shape:
                if name.startswith('classifier'):
                    continue
                assert len(src_param_shape) == 4 and len(dst_param_shape) == 5, "{} mismatch".format(name)
                if list(src_param_shape) == [dst_param_shape[i] for i in [0, 1, 3, 4]]:
                    if dst_param_shape[2] != 1:
                        param = filling_kernel(src=param, dshape=dst_param_shape, method=method)
                    else:
                        param = param.view(dst_param_shape)
                assert dst_param_shape == param.shape, \
                    "Initializer:: error({}): {} != {}".format(name, dst_param_shape, param.shape)
            net.state_dict()[name].copy_(param)
            src_state_keys.remove(name)
            dst_state_keys.remove(name)

    # report unused source keys and target keys that were not initialized
    if src_state_keys:
        out = "['" + "', '".join(src_state_keys) + "']"
        logging.info("Initializer:: >> {} params are unused: {}".format(len(src_state_keys),
                     out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))
    if dst_state_keys:
        logging.info("Initializer:: >> failed to load: \n{}".format(
                     json.dumps(dst_state_keys, indent=4, sort_keys=True)))
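
Usage example (a minimal sketch, not part of the original module): the toy 2D/3D pair below only illustrates the call; any 2D state dict whose conv weights match the 3D network's (C_out, C_in, kH, kW) slices would work the same way.

import torch
import torch.nn as nn

from network.initializer import init_3d_from_2d_dict

# toy 2D network and a 3D counterpart with identical parameter names
net2d = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1))
net3d = nn.Sequential(nn.Conv3d(3, 8, kernel_size=3, padding=1))

# inflate the 2D weights into the 3D network
init_3d_from_2d_dict(net3d, net2d.state_dict(), method='inflation')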