def JoinEmbed()

in source/lib/text_processing.py [0:0]


def JoinEmbed(if_embed, sid_fname, of_embed, dim=1024):
    """Average embeddings that share a sentence ID and write the result.

    Reads a raw float32 matrix from ``if_embed`` (rows of ``dim`` values) and
    one integer ID per line from ``sid_fname``; row ``i`` of the input belongs
    to ID ``sid[i]``.  All rows with the same ID are averaged and the result,
    a float32 matrix with ``max(ID) + 1`` rows, is written raw to
    ``of_embed``.  Nothing is done if ``of_embed`` already exists.

    Args:
        if_embed: path of the raw float32 input embeddings.
        sid_fname: path of the text file holding one integer ID per line.
        of_embed: path of the raw float32 output file.
        dim: number of values per embedding row.

    Raises:
        ValueError: if the input size is not a multiple of ``dim`` or a line
            of ``sid_fname`` is not an integer.
        SystemExit: (status 1, after printing an error) if the number of IDs
            differs from the number of input rows, the input is empty, an ID
            is negative, or some ID in ``0..max(ID)`` never occurs.
    """
    if os.path.isfile(of_embed):
        print(' - JoinEmbed: {} already exists'.format(of_embed))
        return
    # read the input embeddings
    em_in = np.fromfile(if_embed, dtype=np.float32, count=-1).reshape(-1, dim)
    ninp = em_in.shape[0]
    print(' - Combine embeddings:')
    print('                input: {} {:d} sentences'.format(if_embed, ninp))

    # get all sentence IDs
    with open(sid_fname, 'r') as fp_sid:
        sid = np.array([int(line) for line in fp_sid], dtype=np.int64)

    # Every input row needs exactly one ID; a mismatch means the two files
    # are out of sync and any result would be silently wrong.
    if sid.shape[0] != ninp:
        print('ERROR: {:d} sentence IDs for {:d} embeddings'
              .format(sid.shape[0], ninp))
        sys.exit(1)
    if ninp == 0:
        print('ERROR: no embeddings to combine')
        sys.exit(1)
    # A negative ID would otherwise wrap around and pollute the last rows.
    if sid.min() < 0:
        print('ERROR: negative sentence ID')
        sys.exit(1)
    nout = int(sid.max()) + 1
    print('                IDs: {}, {:d} sentences'.format(sid_fname, nout))

    # combining
    cnt = np.bincount(sid, minlength=nout)
    if (cnt == 0).any():
        print('ERROR: missing lines')
        sys.exit(1)
    em_out = np.zeros((nout, dim), dtype=np.float32)
    # Unbuffered add: rows with a repeated ID are all accumulated, in order.
    np.add.at(em_out, sid, em_in)

    # normalize
    em_out /= cnt[:, np.newaxis].astype(np.float32)

    print('                output: {}'.format(of_embed))
    em_out.tofile(of_embed)