private def getValueFromTensor()

in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/deserializers/ProtobufResponseRowDeserializer.scala [81:105]


  /**
   * Converts a deserialized protobuf `Float32Tensor` into a value that can be placed in a Row.
   *
   *  - If the tensor has keys, it is sparsely encoded and becomes a `SparseVector`. The vector
   *    size is the single shape entry. This holds even when only one value is stored: a sparse
   *    vector with a single non-zero entry is still a vector, not a scalar.
   *  - If there are no keys and exactly one value, that value is returned as a scalar `Double`.
   *  - If there are no keys and several values, they are returned as a `DenseVector`.
   *
   * @param valuesTensor the tensor to convert; it must contain at least one value.
   * @return a `Double`, `SparseVector` or `DenseVector`, as described above.
   * @throws IllegalArgumentException if the values list is empty, or if the tensor is sparsely
   *                                  encoded but its shape list does not have exactly one entry.
   */
  private def getValueFromTensor(valuesTensor: Float32Tensor) : Any = {
    val valuesCount = valuesTensor.getValuesCount
    require(valuesCount > 0, "Can't get value from deserialized tensor: values list is empty.")

    // Only materialized on the vector paths; the scalar path reads the single value directly.
    def valuesAsDoubles: Array[Double] =
      asScalaBufferConverter(valuesTensor.getValuesList).asScala.toArray.map(_.toDouble)

    val keyCount = valuesTensor.getKeysCount
    if (keyCount > 0) {
      // The tensor is sparsely encoded. Only sparsely-encoded vectors (not higher-order tensors)
      // can be represented, so the shape must hold exactly one dimension: the vector size.
      // Keys are checked before the single-value case so that a sparse vector with one
      // non-zero entry is not mistaken for a scalar.
      require(valuesTensor.getShapeCount == 1,
        "Cannot deserialize tensor to vector. Shape list must contain exactly one value.")
      val size = valuesTensor.getShape(0).toInt
      val indices = asScalaBufferConverter(valuesTensor.getKeysList).asScala.toArray.map(_.toInt)
      new SparseVector(size, indices, valuesAsDoubles)
    } else if (valuesCount == 1) {
      // A single densely-encoded value is treated as a scalar.
      valuesTensor.getValues(0).toDouble
    } else {
      // The vector is densely encoded.
      new DenseVector(valuesAsDoubles)
    }
  }