void NeuralNetworkShaper::shapeSliceLayer()

in mlmodel/src/Validation/NeuralNetwork/NeuralNetworkShapes.cpp [1093:1250]


// Propagates shape constraints between the input and output blobs of a Slice layer.
//
// The layer keeps elements [startIndex, endIndex) with step `stride` along the
// channel, height or width axis. As interpreted here, a negative startIndex
// counts back from the axis length N, and an endIndex <= 0 also counts back
// from N (endIndex == 0 meaning "through the last element").
//
// Sequence/batch ranges and the two axes that are not sliced are propagated in
// both directions. For the sliced axis:
//   * both indices anchored at the same end of the axis -> the output length is
//     fixed and the input gets a lower bound;
//   * start from the front, end from the back -> the output length is a range
//     derived from the input range, and the input gets a lower bound;
//   * start from the back, end from the front -> the output length shrinks as
//     N grows, which is not expressible with the range operations used here,
//     so the sliced axis is left unconstrained.
//
// Throws std::runtime_error for an unknown axis, a non-positive stride, or a
// fixed-length slice that selects no elements; the validator is expected to
// reject these before shaping.
void NeuralNetworkShaper::shapeSliceLayer(const Specification::NeuralNetworkLayer& specLayer) {

    // Constraints for the input and output blobs.
    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Sequence and batch pass through the slice unchanged: propagate both ways.
    outputShape.updateSequenceRange(inputShape.sequenceRange());
    outputShape.updateBatchRange(inputShape.batchRange());

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    using SliceParams = Specification::SliceLayerParams;

    // Bind by reference; there is no need to copy the protobuf message.
    const SliceParams& slice = specLayer.slice();
    const int start = static_cast<int>(slice.startindex());
    const int end = static_cast<int>(slice.endindex());
    const int stride = static_cast<int>(slice.stride());
    const auto axis = slice.axis();

    if (axis != SliceParams::CHANNEL_AXIS &&
        axis != SliceParams::HEIGHT_AXIS &&
        axis != SliceParams::WIDTH_AXIS) {
        throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
    }
    // stride is unsigned in the spec. Zero would divide by zero below, and a
    // value too large for an int can wrap to a non-positive number.
    if (stride <= 0) {
        throw std::runtime_error("Slice layer stride must be positive -- should be caught in validator.");
    }

    // Per-axis dispatch helpers. `axis` was validated above, so the `default`
    // labels only ever see WIDTH_AXIS.
    auto inputSlicedRange = [&inputShape, axis]() -> ShapeRange {
        switch (axis) {
            case SliceParams::CHANNEL_AXIS:
                return inputShape.channelRange();
            case SliceParams::HEIGHT_AXIS:
                return inputShape.heightRange();
            case SliceParams::WIDTH_AXIS:
            default:
                return inputShape.widthRange();
        }
    };
    auto setOutputSlicedSize = [&outputShape, axis](size_t length) {
        switch (axis) {
            case SliceParams::CHANNEL_AXIS:
                outputShape.setChannel(length);
                break;
            case SliceParams::HEIGHT_AXIS:
                outputShape.setHeight(length);
                break;
            case SliceParams::WIDTH_AXIS:
            default:
                outputShape.setWidth(length);
                break;
        }
    };
    auto updateOutputSlicedRange = [&outputShape, axis](ShapeRange range) {
        switch (axis) {
            case SliceParams::CHANNEL_AXIS:
                outputShape.updateChannelRange(range);
                break;
            case SliceParams::HEIGHT_AXIS:
                outputShape.updateHeightRange(range);
                break;
            case SliceParams::WIDTH_AXIS:
            default:
                outputShape.updateWidthRange(range);
                break;
        }
    };
    auto lowerBoundInputSliced = [&inputShape, axis](size_t bound) {
        switch (axis) {
            case SliceParams::CHANNEL_AXIS:
                inputShape.lowerBoundChannel(bound);
                break;
            case SliceParams::HEIGHT_AXIS:
                inputShape.lowerBoundHeight(bound);
                break;
            case SliceParams::WIDTH_AXIS:
            default:
                inputShape.lowerBoundWidth(bound);
                break;
        }
    };
    // Applies `from`'s ranges for the two axes that are not sliced to `to`.
    auto updateUnslicedAxes = [axis](ShapeConstraint& to, ShapeConstraint& from) {
        if (axis != SliceParams::CHANNEL_AXIS) {
            to.updateChannelRange(from.channelRange());
        }
        if (axis != SliceParams::HEIGHT_AXIS) {
            to.updateHeightRange(from.heightRange());
        }
        if (axis != SliceParams::WIDTH_AXIS) {
            to.updateWidthRange(from.widthRange());
        }
    };

    const bool startFromFront = (start >= 0);
    const bool endFromFront = (end > 0);

    if (startFromFront == endFromFront) {
        // Both indices are anchored at the same end of the axis, so the slice
        // length does not depend on N: [start, end) or [N + start, N + end).
        const int span = end - start;
        if (span <= 0) {
            throw std::runtime_error("Slice layer start index must precede end index -- should be caught in validator.");
        }
        const int outsize = (span - 1) / stride + 1;
        // Minimum N: `end` when indexing from the front, `-start` from the back.
        const int inLowerBound = startFromFront ? end : -start;

        // TODO: add check that this size is possible from the input size
        setOutputSlicedSize(static_cast<size_t>(outsize));
        updateUnslicedAxes(outputShape, inputShape);

        lowerBoundInputSliced(static_cast<size_t>(inLowerBound));
        updateUnslicedAxes(inputShape, outputShape);
    }
    else if (startFromFront) {
        // start counts from the front and end (<= 0) from the back: the slice is
        // [start, N + end), of length N + end - start, so it grows with N.
        // N must be at least `trimmed` for the slice to be non-empty.
        const int trimmed = start + 1 - end;
        ShapeRange inRange = inputSlicedRange();
        ShapeRange outRange = (inRange - trimmed) / stride + 1;

        updateOutputSlicedRange(outRange);
        updateUnslicedAxes(outputShape, inputShape);

        lowerBoundInputSliced(static_cast<size_t>(trimmed));
        updateUnslicedAxes(inputShape, outputShape);
    }
    else {
        // start (< 0) counts from the back and end (> 0) from the front: the
        // slice is [N + start, end), of length end - start - N. It shrinks as N
        // grows and is non-empty only while N < end - start. Describing that
        // needs an upper bound on N and a range that decreases with it, which the
        // operations used above do not express. In particular, the lower bound
        // used for the other mixed case (|start| + 1 + end) would only admit
        // values of N for which this slice is empty. Leave the sliced axis
        // unconstrained and propagate the remaining axes.
        updateUnslicedAxes(outputShape, inputShape);
        updateUnslicedAxes(inputShape, outputShape);
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}