in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/deserializers/ProtobufResponseRowDeserializer.scala [81:105]
/**
 * Converts a deserialized [[Float32Tensor]] into a value suitable for a Spark row.
 *
 * Dispatch order:
 *  - a tensor carrying keys is sparsely encoded and becomes a [[SparseVector]] whose size is the
 *    tensor's single shape entry (checked first, so a sparse tensor with one non-zero entry is
 *    still returned as a vector rather than collapsing to a scalar);
 *  - a key-less tensor with exactly one value becomes a scalar `Double`;
 *  - any other key-less tensor becomes a [[DenseVector]] of its values, in order.
 *
 * @param valuesTensor the tensor to convert; must contain at least one value.
 * @return a `Double`, `SparseVector` or `DenseVector` as described above.
 * @throws IllegalArgumentException if the tensor has no values, or if a sparse tensor does not
 *                                  have exactly one shape entry, has a dimension larger than
 *                                  `Int.MaxValue`, or has a different number of keys than values.
 */
private def getValueFromTensor(valuesTensor: Float32Tensor) : Any = {
  val valuesCount = valuesTensor.getValuesCount
  require(valuesCount > 0, "Can't get value from deserialized tensor: values list is empty.")
  val keyCount = valuesTensor.getKeysCount

  // Read values through the indexed accessor to avoid boxing via the Java list view.
  def valuesAsDoubles: Array[Double] =
    Array.tabulate(valuesCount)(i => valuesTensor.getValues(i).toDouble)

  if (keyCount > 0) {
    // Tensor is sparsely encoded. We can only represent sparsely-encoded vectors
    // (not higher-order tensors), so exactly one shape entry (the dimension) is required.
    val shapeCount = valuesTensor.getShapeCount
    require(shapeCount == 1,
      s"Cannot deserialize tensor to vector. Expected exactly one shape value, found $shapeCount.")
    val dimension = valuesTensor.getShape(0)
    // Spark vectors are Int-indexed; refuse rather than silently truncate a larger dimension.
    require(dimension <= Int.MaxValue,
      s"Cannot deserialize tensor to vector. Dimension $dimension exceeds ${Int.MaxValue}.")
    require(keyCount == valuesCount,
      s"Cannot deserialize tensor to vector. Keys count $keyCount does not match " +
        s"values count $valuesCount.")
    val indices = Array.tabulate(keyCount)(i => valuesTensor.getKeys(i).toInt)
    new SparseVector(dimension.toInt, indices, valuesAsDoubles)
  } else if (valuesCount == 1) {
    // Single dense value: expose it as a scalar.
    valuesTensor.getValues(0).toDouble
  } else {
    // Vector is densely encoded.
    new DenseVector(valuesAsDoubles)
  }
}