ErrorCode CPUTensorConverter::convert()

in source/backend/cpu/CPUTensorConvert.cpp [53:252]


// Convert a raw tensor buffer from layout `source` to layout `dest`.
//
// Parameters:
//   inputRaw / outputRaw : source and destination buffers.
//   source / dest        : data layouts of the two buffers.
//   batch, area, channel : logical extents of the tensor.
//   bitLength            : size of one element in BYTES (it is used as a byte
//                          multiplier for every offset below, despite the name).
//   core                 : table of pack/unpack kernels plus `pack` and `bytes`.
//   tId / numberThread   : index of the calling thread and total thread count;
//                          the function is meant to be invoked once per thread.
//
// Behaviour:
//   * source == dest        : thread 0 copies the buffer with memcpy.
//   * NHWC <-> NCHW         : thread 0 transposes; element sizes 1, 2 and 4 only.
//   * NC4HW4 on either side : work is split across the threads and the matching
//                             pack/unpack kernel from `core` is invoked.
//
// Returns NOT_SUPPORT when no pack/unpack kernel is available for `bitLength`,
// NO_ERROR otherwise (including for threads that have no work to do).
ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN_DATA_FORMAT source, MNN_DATA_FORMAT dest, int batch, int area, int channel, int bitLength, const CoreFunctions* core, int tId, int numberThread) {
    // the case when source and dest data layout are the same
    // This case occurs in BackendTest of BF16 data.
    if(source == dest) {
        if (tId == 0) {
            // Compute the byte count in size_t so that large tensors do not
            // overflow the int multiplication.
            // NOTE(review): for NC4HW4 the buffer presumably holds the channel
            // count rounded up to core->pack; this copies only the unpadded
            // element count, as before - confirm against callers.
            const size_t totalBytes = static_cast<size_t>(batch) * static_cast<size_t>(area) *
                                      static_cast<size_t>(channel) * static_cast<size_t>(bitLength);
            ::memcpy(outputRaw, inputRaw, totalBytes);
        }
        return NO_ERROR;
    }
    if (MNN_DATA_FORMAT_NHWC == source && MNN_DATA_FORMAT_NCHW == dest) {
        // Not parallelised: only thread 0 performs the transpose.
        if (tId == 0) {
            switch (bitLength) {
                case 1:
                    NHWC2NCHW((int8_t*)inputRaw, (int8_t*)outputRaw, batch, channel, area);
                    break;
                case 2:
                    NHWC2NCHW((int16_t*)inputRaw, (int16_t*)outputRaw, batch, channel, area);
                    break;
                case 4:
                    NHWC2NCHW((float*)inputRaw, (float*)outputRaw, batch, channel, area);
                    break;
                default:
                    // NOTE(review): other element sizes leave outputRaw untouched
                    // and still report NO_ERROR.
                    break;
            }
        }
        return NO_ERROR;
    }
    if (MNN_DATA_FORMAT_NCHW == source && MNN_DATA_FORMAT_NHWC == dest) {
        // Not parallelised: only thread 0 performs the transpose.
        if (tId == 0) {
            switch (bitLength) {
                case 1:
                    NCHW2NHWC((int8_t*)inputRaw, (int8_t*)outputRaw, batch, channel, area);
                    break;
                case 2:
                    NCHW2NHWC((int16_t*)inputRaw, (int16_t*)outputRaw, batch, channel, area);
                    break;
                case 4:
                    NCHW2NHWC((float*)inputRaw, (float*)outputRaw, batch, channel, area);
                    break;
                default:
                    // NOTE(review): other element sizes leave outputRaw untouched
                    // and still report NO_ERROR.
                    break;
            }
        }
        return NO_ERROR;
    }
    // Need Pack
    PackProc proc = nullptr;
    int inside = area;
    int outside = batch;
    if (MNN_DATA_FORMAT_NHWC == source || MNN_DATA_FORMAT_NHWC == dest) {
        // Channel-last on one side: treat batch * area as a single outer extent
        // and use the transposing kernels below.
        inside = 1;
        outside = batch * area;
    }
    //MNN_PRINT("bytes = %d, from %d -> %d, %d - %d - %d\n", bitLength, source, dest, inside, outside, channel);
    if (MNN_DATA_FORMAT_NC4HW4 == source) {
        if (1 == inside) {
            // NC4HW4 -> NHWC: split the outer extent across threads.
            int offset[2] = {
                outside,
                outside
            };
            int step = UP_DIV(outside, numberThread);
            int start = tId * step;
            int end = std::min(start + step, outside);
            if (end <= start) {
                // This thread has no slice to process.
                return NO_ERROR;
            }
            auto inputStart = (int8_t*)inputRaw + (start * core->pack * bitLength);
            auto outputStart = (int8_t*)outputRaw + (start * channel * bitLength);
            // Prefer the kernel for the backend's native element size, then
            // fall back to the fixed 1-byte / 2-byte variants.
            if (core->bytes == bitLength) {
                proc = decltype(proc)(core->MNNUnpackCUnitTranspose);
            } else if (bitLength == 1) {
                proc = decltype(proc)(core->MNNUnpackCUnitTransposeInt8);
            } else if (bitLength == 2) {
                proc = decltype(proc)(core->MNNUnpackCUnitTransposeInt16);
            }
            if (nullptr == proc) {
                return NOT_SUPPORT;
            }
            proc((float*)outputStart, (const float*)inputStart, end - start, channel, offset);
        } else {
            // NC4HW4 -> NCHW.
            if (core->bytes == bitLength) {
                proc = decltype(proc)(core->MNNUnpackCUnit);
            } else if (bitLength == 1) {
                proc = decltype(proc)(core->MNNUnpackCUnitInt8);
            } else if (bitLength == 2) {
                proc = decltype(proc)(core->MNNUnpackCUnitInt16);
            }
            if (nullptr == proc) {
                return NOT_SUPPORT;
            }
            if (batch != 1) {
                // Divide in batch
                int offset[2] = {
                    outside * inside,
                    area
                };
                int step = UP_DIV(batch, numberThread);
                int start = tId * step;
                int end = std::min(start + step, batch);
                if (end <= start) {
                    return NO_ERROR;
                }
                // One kernel call per batch index owned by this thread.
                for (int v=start; v<end; ++v) {
                    auto inputStart = (int8_t*)inputRaw + (v * core->pack * bitLength * area);
                    auto outputStart = (int8_t*)outputRaw + (v * channel * bitLength * area);
                    proc((float*)outputStart, (const float*)inputStart, area, channel, offset);
                }
            } else {
                // Divide in area
                int offset[2] = {
                    area,
                    area
                };
                int step = UP_DIV(area, numberThread);
                int start = tId * step;
                int end = std::min(start + step, area);
                if (end <= start) {
                    return NO_ERROR;
                }
                auto inputStart = (int8_t*)inputRaw + (start * core->pack * bitLength);
                auto outputStart = (int8_t*)outputRaw + (start * bitLength);
                proc((float*)outputStart, (const float*)inputStart, end - start, channel, offset);
            }
        }
        return NO_ERROR;
    }
    if (MNN_DATA_FORMAT_NC4HW4 == dest) {
        if (1 == inside) {
            // NHWC -> NC4HW4: split the outer extent across threads.
            int offset[2] = {
                outside,
                outside
            };
            int step = UP_DIV(outside, numberThread);
            int start = tId * step;
            int end = std::min(start + step, outside);
            if (end <= start) {
                // This thread has no slice to process.
                return NO_ERROR;
            }
            if (core->bytes == bitLength) {
                proc = decltype(proc)(core->MNNPackCUnitTranspose);
            } else if (bitLength == 1) {
                proc = decltype(proc)(core->MNNPackCUnitTransposeInt8);
            } else if (bitLength == 2) {
                proc = decltype(proc)(core->MNNPackCUnitTransposeInt16);
            }
            if (nullptr == proc) {
                return NOT_SUPPORT;
            }
            auto outputStart = (int8_t*)outputRaw + (start * core->pack * bitLength);
            auto inputStart = (int8_t*)inputRaw + (start * channel * bitLength);
            // Cast the byte pointers exactly like every other kernel call in
            // this function; int8_t* does not convert implicitly to float*.
            proc((float*)outputStart, (const float*)inputStart, end - start, channel, offset);
        } else {
            // NCHW -> NC4HW4.
            if (core->bytes == bitLength) {
                proc = decltype(proc)(core->MNNPackCUnit);
            } else if (bitLength == 1) {
                proc = decltype(proc)(core->MNNPackCUnitInt8);
            } else if (bitLength == 2) {
                proc = decltype(proc)(core->MNNPackCUnitInt16);
            }
            if (nullptr == proc) {
                return NOT_SUPPORT;
            }
            if (batch != 1) {
                // Divide in batch
                int offset[2] = {
                    area,
                    outside * inside
                };
                int step = UP_DIV(batch, numberThread);
                int start = tId * step;
                int end = std::min(start + step, batch);
                if (end <= start) {
                    return NO_ERROR;
                }
                // One kernel call per batch index owned by this thread.
                for (int v=start; v<end; ++v) {
                    auto outputStart = (int8_t*)outputRaw + (v * core->pack * bitLength * area);
                    auto inputStart = (int8_t*)inputRaw + (v * channel * bitLength * area);
                    proc((float*)outputStart, (const float*)inputStart, area, channel, offset);
                }
            } else {
                // Divide in area
                int offset[2] = {
                    area,
                    area
                };
                int step = UP_DIV(area, numberThread);
                int start = tId * step;
                int end = std::min(start + step, area);
                if (end <= start) {
                    return NO_ERROR;
                }
                auto outputStart = (int8_t*)outputRaw + (start * core->pack * bitLength);
                auto inputStart = (int8_t*)inputRaw + (start * bitLength);
                proc((float*)outputStart, (const float*)inputStart, end - start, channel, offset);
            }
        }
        return NO_ERROR;
    }
    // NOTE(review): any layout pair not handled above falls through here and
    // reports NO_ERROR without writing to outputRaw.
    return NO_ERROR;
}