in ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java [37:140]
/**
 * Maps this dimensional space to a new {@link RelativeDimensionalSpace} by applying the given
 * indices, consumed left to right against the dimensions of this space.
 *
 * <p>Each kind of index is handled as follows:
 *
 * <ul>
 *   <li><b>Point indices</b>: a run of consecutive point indices removes the dimensions it
 *       targets. Their offsets (as computed by {@link Index#mapCoordinate}) are summed. If no
 *       dimension has been emitted yet, the sum is added to the initial offset of the resulting
 *       space. Otherwise it is folded into the last emitted dimension through a
 *       {@code ReducedDimension}.
 *   <li><b>New axis</b>: inserts a new dimension of size 1 without consuming any dimension of
 *       this space.
 *   <li><b>Ellipsis</b>: copies as many dimensions of this space, unchanged, as needed so that
 *       the indices following it (excluding new axes) address the trailing dimensions. At most
 *       one ellipsis is allowed.
 *   <li><b>Any other index</b>: is applied to the next dimension via {@link Index#apply}.
 * </ul>
 *
 * <p>Dimensions not addressed by any index are copied unchanged to the new space.
 *
 * @param indices indices to apply to this space
 * @return the mapped space, with its segmentation index and initial offset
 * @throws ArrayIndexOutOfBoundsException if this space has no dimensions ({@code dimensions} is
 *     null), or if the indices address more dimensions than this space has
 * @throws IllegalArgumentException if more than one ellipsis index is provided
 */
public RelativeDimensionalSpace mapTo(Index[] indices) {
  if (dimensions == null) {
    throw new ArrayIndexOutOfBoundsException();
  }
  int dimIdx = 0;
  int indexIdx = 0;
  int newDimIdx = 0;
  int segmentationIdx = -1;
  long initialOffset = 0;

  // Pre-scan: new axes add dimensions without consuming any, so the output array must be sized
  // for them; also reject multiple ellipses up front.
  int newAxes = 0;
  boolean seenEllipsis = false;
  for (Index idx : indices) {
    if (idx.isNewAxis()) {
      newAxes += 1;
    }
    if (idx.isEllipsis()) {
      if (seenEllipsis) {
        throw new IllegalArgumentException("Only one ellipsis allowed");
      } else {
        seenEllipsis = true;
      }
    }
  }
  int newLength = dimensions.length + newAxes;
  Dimension[] newDimensions = new Dimension[newLength];

  while (indexIdx < indices.length) {
    if (indices[indexIdx].isPoint()) {
      // When an index targets a single point in a given dimension, calculate the offset of this
      // point and cumulate the offset of any subsequent point as well
      long offset = 0;
      do {
        offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]);
        dimIdx++;
      } while (++indexIdx < indices.length && indices[indexIdx].isPoint());

      // If no dimension has been emitted yet, the offset is the position of the whole dimension
      // space within the original one. If not, then we apply the offset to the last vectorial
      // dimension.
      if (newDimIdx == 0) {
        // Accumulate rather than assign: several leading point runs can occur when they are
        // separated by an ellipsis that expands to zero dimensions (e.g. point, ellipsis, point
        // on a rank-2 space). Each run must contribute its offset.
        initialOffset += offset;
      } else {
        long reducedSize = dimensions[dimIdx - 1].elementSize();
        newDimensions[newDimIdx - 1] =
            new ReducedDimension(newDimensions[newDimIdx - 1], offset, reducedSize);
        segmentationIdx = newDimIdx - 1;
      }

    } else if (indices[indexIdx].isNewAxis()) {
      long newSize;
      if (dimIdx == 0) {
        // includes everything. Should really include future reduction (at()) but that doesn't
        // seem to cause issues elsewhere.
        // NOTE(review): this reads dimensions[0], which throws for a rank-0 space; confirm
        // whether new axes on scalars are meant to be supported.
        newSize = dimensions[0].numElements() * dimensions[0].elementSize();
      } else {
        newSize = dimensions[dimIdx - 1].elementSize();
      }
      newDimensions[newDimIdx] = new Axis(1, newSize);
      // NOTE(review): marks the new size-1 axis as the segmentation point; original author was
      // unsure this is correct. Verify against RelativeDimensionalSpace's use of it.
      segmentationIdx = newDimIdx;
      ++newDimIdx;
      ++indexIdx;

    } else if (indices[indexIdx].isEllipsis()) {
      int remainingDimensions = dimensions.length - dimIdx;
      // Indices after the ellipsis that consume a dimension (new axes do not)
      int requiredDimensions = 0;
      for (int i = indexIdx + 1; i < indices.length; i++) {
        if (!indices[i].isNewAxis()) {
          requiredDimensions++;
        }
      }
      // Copy dimensions through unchanged until only those needed by the remaining indices are
      // left; this may copy nothing at all.
      while (remainingDimensions > requiredDimensions) {
        Dimension dim = dimensions[dimIdx++];
        if (dim.isSegmented()) {
          segmentationIdx = newDimIdx;
        }
        newDimensions[newDimIdx++] = dim;
        remainingDimensions--;
      }
      indexIdx++;

    } else {
      // Map any other index to the appropriate dimension of this space
      Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]);
      newDimensions[newDimIdx] = newDimension;
      if (newDimension.isSegmented()) {
        segmentationIdx = newDimIdx;
      }
      ++newDimIdx;
      ++indexIdx;
    }
  }

  // When the number of indices provided is smaller than the number of dimensions in this space,
  // we copy the remaining dimensions directly to the new space as well.
  for (; dimIdx < dimensions.length; ++dimIdx, ++newDimIdx) {
    Dimension dim = dimensions[dimIdx];
    newDimensions[newDimIdx] = dim;
    if (dim.isSegmented()) {
      segmentationIdx = newDimIdx;
    }
  }
  // Point indices remove dimensions, so the output may be shorter than allocated: trim it.
  return new RelativeDimensionalSpace(
      Arrays.copyOf(newDimensions, newDimIdx), segmentationIdx, initialOffset);
}