in mlmodel/src/Validation/NeuralNetwork/NeuralNetworkShapes.cpp [1093:1250]
/**
 * Propagates shape constraints through a slice layer.
 *
 * Looks up the constraints for the layer's first input and first output in
 * `blobShapes` and tightens them as follows:
 *   - sequence and batch ranges are synchronized in both directions;
 *   - the two dimensions (of channel/height/width) that are NOT sliced are
 *     synchronized in both directions;
 *   - on the sliced axis, the output is given either a fixed size (when the
 *     start/end indices fall in the same sign class, see below) or a range
 *     derived from the input range, and the input is given a lower bound.
 *
 * @param specLayer Layer specification. Only `input(0)`, `output(0)`,
 *                  `name()` (verbose builds) and `slice()` are read.
 * @throws std::runtime_error if the stride is not positive after narrowing to
 *         `int`, or if the slice axis is not CHANNEL/HEIGHT/WIDTH. Both are
 *         expected to have been rejected earlier by the validator.
 */
void NeuralNetworkShaper::shapeSliceLayer(const Specification::NeuralNetworkLayer& specLayer) {
    // Fetch the constraints for the (single) input and output blobs.
    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Sequence and batch are never sliced: propagate forward, then backward.
    outputShape.updateSequenceRange(inputShape.sequenceRange());
    outputShape.updateBatchRange(inputShape.batchRange());
    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    // Bind by reference: only four scalar fields are read, no copy is needed.
    const Specification::SliceLayerParams& slice = specLayer.slice();
    int start = static_cast<int>(slice.startindex());
    int end = static_cast<int>(slice.endindex());
    const int stride = static_cast<int>(slice.stride());
    const auto axis = slice.axis();

    // The stride is used as an integer divisor below; a zero (or, after
    // narrowing, negative) value would be undefined behaviour or nonsense.
    if (stride <= 0) {
        throw std::runtime_error("Slice layer stride must be positive -- should be caught in validator.");
    }

    int outsize = 0;
    int inLowerBound = 0;
    bool fixedSize = false;

    if (start >= 0 && end > 0) {
        // Both indices counted from the front: output size is independent of
        // the input size, and the input must hold at least `end` elements.
        fixedSize = true;
        outsize = (end - 1 - start) / stride + 1;
        inLowerBound = end;
    }
    else if (start < 0 && end <= 0) {
        // Both indices counted from the back (end == 0 is handled here too):
        // again a fixed output size; the input must hold at least |start|.
        fixedSize = true;
        inLowerBound = -1 * start;
        outsize = (-start - 1 + end) / stride + 1;
    }

    // TODO: add check that this size is possible from the input size
    if (fixedSize) {
        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS:
                outputShape.setChannel(static_cast<size_t>(outsize));
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(inputShape.widthRange());
                inputShape.lowerBoundChannel(static_cast<size_t>(inLowerBound));
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.updateWidthRange(outputShape.widthRange());
                break;
            case Specification::SliceLayerParams::HEIGHT_AXIS:
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.setHeight(static_cast<size_t>(outsize));
                outputShape.updateWidthRange(inputShape.widthRange());
                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.lowerBoundHeight(static_cast<size_t>(inLowerBound));
                inputShape.updateWidthRange(outputShape.widthRange());
                break;
            case Specification::SliceLayerParams::WIDTH_AXIS:
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.setWidth(static_cast<size_t>(outsize));
                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.lowerBoundWidth(static_cast<size_t>(inLowerBound));
                break;
            default:
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
        }
    }
    else {
        // Mixed sign classes: the output size depends on the input size, so
        // it is expressed as a range computed from the input's range.
        ShapeRange size;
        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS:
                size = inputShape.channelRange();
                break;
            case Specification::SliceLayerParams::HEIGHT_AXIS:
                size = inputShape.heightRange();
                break;
            case Specification::SliceLayerParams::WIDTH_AXIS:
                size = inputShape.widthRange();
                break;
            default:
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
        }

        // Replace whichever index is non-positive by its magnitude so both
        // are non-negative for the arithmetic below.
        if (end <= 0) {
            end = (-1 * end);
        }
        else { // start < 0 and end > 0; the same-sign cases were handled above
            start = (-1 * start);
        }

        // NOTE(review): the same expression is used for both mixed-sign
        // combinations (start >= 0 / end <= 0 and start < 0 / end > 0);
        // confirm it matches the intended slice semantics for the latter.
        ShapeRange outrange = (size - (start + 1 + end)) / stride + 1;
        inLowerBound = start + 1 + end;

        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS: {
                outputShape.updateChannelRange(outrange);
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(inputShape.widthRange());
                inputShape.lowerBoundChannel(static_cast<size_t>(inLowerBound));
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.updateWidthRange(outputShape.widthRange());
                break;
            }
            case Specification::SliceLayerParams::HEIGHT_AXIS: {
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(outrange);
                outputShape.updateWidthRange(inputShape.widthRange());
                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.lowerBoundHeight(static_cast<size_t>(inLowerBound));
                inputShape.updateWidthRange(outputShape.widthRange());
                break;
            }
            case Specification::SliceLayerParams::WIDTH_AXIS: {
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(outrange);
                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.lowerBoundWidth(static_cast<size_t>(inLowerBound));
                break;
            }
            default: {
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
            }
        }
    } // signs don't match

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif
}