std::vector<uint8_t> packInterleaved()

in faiss/gpu/impl/InterleavedCodes.cpp [464:556]


/// Packs per-vector codes into the interleaved, 32-vector-block layout.
///
/// Input `data` is vector-major: the code for dimension `j` of vector `v`
/// starts at element `v * dims + j` (in units of the code's padded word).
/// Each code is padded to a whole number of bytes, so 4/5/6-bit codes occupy
/// one byte each, 16-bit codes two bytes, and 32-bit codes four bytes.
///
/// The output consists of ceil(numVecs / 32) blocks, each covering 32
/// vectors. Each block holds, for each dimension in turn,
/// `32 * bitsPerCode / 8` bytes containing that dimension's codes for the
/// block's 32 vectors. On the 4/5/6-bit paths, slots for vectors past
/// `numVecs` in the last block are written as zero. The output buffer
/// starts zero-filled for the 8/16/32-bit paths as well.
///
/// @param data        input codes, in the layout described above
/// @param numVecs     number of encoded vectors in `data`
/// @param dims        number of codes per vector
/// @param bitsPerCode one of 4, 5, 6, 8, 16 or 32; any other value asserts
/// @return            packed buffer of `bytesPerBlock * numBlocks` bytes
std::vector<uint8_t> packInterleaved(
        std::vector<uint8_t> data,
        int numVecs,
        int dims,
        int bitsPerCode) {
    int bytesPerDimBlock = 32 * bitsPerCode / 8;
    int bytesPerBlock = bytesPerDimBlock * dims;
    int numBlocks = utils::divUp(numVecs, 32);
    size_t totalSize = static_cast<size_t>(bytesPerBlock) * numBlocks;

    // Bit codes are padded to whole bytes. The product is computed in size_t
    // because numVecs * dims can exceed INT_MAX for large inputs, and signed
    // int overflow is undefined behaviour.
    FAISS_ASSERT(
            data.size() ==
            static_cast<size_t>(numVecs) * dims *
                    utils::divUp(bitsPerCode, 8));

    // packs based on blocks
    std::vector<uint8_t> out(totalSize, 0);

    // Helpers for the sub-byte (4/5/6-bit) paths, where each input code is
    // stored in one byte. Both helpers widen to size_t before multiplying,
    // so element offsets past INT_MAX remain valid. This matches totalSize,
    // which is already a size_t.

    // Returns the code of vector `vec` for dimension `dim`, or 0 when `vec`
    // is a padding slot past the last real vector.
    auto codeAt = [&](int vec, int dim) -> uint8_t {
        return vec < numVecs ? data[static_cast<size_t>(vec) * dims + dim]
                             : 0;
    };

    // Returns output byte `k` of dimension `dim` within block `block`.
    auto outByte = [&](int block, int dim, int k) -> uint8_t& {
        return out
                [static_cast<size_t>(block) * bytesPerBlock +
                 static_cast<size_t>(dim) * bytesPerDimBlock + k];
    };

    if (bitsPerCode == 8) {
        packInterleavedWord<uint8_t>(
                data.data(), out.data(), numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 16) {
        packInterleavedWord<uint16_t>(
                (uint16_t*)data.data(),
                (uint16_t*)out.data(),
                numVecs,
                dims,
                bitsPerCode);
    } else if (bitsPerCode == 32) {
        packInterleavedWord<uint32_t>(
                (uint32_t*)data.data(),
                (uint32_t*)out.data(),
                numVecs,
                dims,
                bitsPerCode);
    } else if (bitsPerCode == 4) {
#pragma omp parallel for
        for (int i = 0; i < numBlocks; ++i) {
            for (int j = 0; j < dims; ++j) {
                for (int k = 0; k < bytesPerDimBlock; ++k) {
                    // Byte k holds vector 2k in its low nibble and vector
                    // 2k + 1 in its high nibble.
                    int loVec = i * 32 + k * 2;
                    int hiVec = loVec + 1;

                    uint8_t lo = codeAt(loVec, j);
                    uint8_t hi = codeAt(hiVec, j);

                    outByte(i, j, k) = (hi << 4) | (lo & 0xf);
                }
            }
        }
    } else if (bitsPerCode == 5) {
#pragma omp parallel for
        for (int i = 0; i < numBlocks; ++i) {
            for (int j = 0; j < dims; ++j) {
                for (int k = 0; k < bytesPerDimBlock; ++k) {
                    // Output byte k spans bits [8k, 8k + 8). With 5-bit codes
                    // this can touch up to three vectors: e.g. byte 1
                    // (bits 8..15) overlaps vectors 1, 2 and 3.
                    int loVec = i * 32 + (k * 8) / 5;
                    int hiVec = loVec + 1;
                    int hiVec2 = hiVec + 1;

                    uint8_t lo = codeAt(loVec, j);
                    uint8_t hi = codeAt(hiVec, j);
                    uint8_t hi2 = codeAt(hiVec2, j);

                    outByte(i, j, k) = pack5(k, lo, hi, hi2);
                }
            }
        }
    } else if (bitsPerCode == 6) {
#pragma omp parallel for
        for (int i = 0; i < numBlocks; ++i) {
            for (int j = 0; j < dims; ++j) {
                for (int k = 0; k < bytesPerDimBlock; ++k) {
                    // With 6-bit codes an 8-bit output byte overlaps at most
                    // two vectors.
                    int loVec = i * 32 + (k * 8) / 6;
                    int hiVec = loVec + 1;

                    uint8_t lo = codeAt(loVec, j);
                    uint8_t hi = codeAt(hiVec, j);

                    outByte(i, j, k) = pack6(k, lo, hi);
                }
            }
        }
    } else {
        // unimplemented
        FAISS_ASSERT(false);
    }

    return out;
}