func updateOutputShapes()

in Sources/TensorFlow/Core/LazyTensorShapeInference.swift [46:142]


  func updateOutputShapes() {
    let status = TF_NewStatus()
    defer { TF_DeleteStatus(status) }

    /// Returns the shape of `handle` if it is already known, without forcing
    /// any computation.
    ///
    /// - For a concrete handle, this is the shape reported by the underlying
    ///   `TFETensorHandle`.
    /// - For a symbolic handle, this is the entry at the handle's output
    ///   index in the producing operation's `outputShapes`, which may be
    ///   `nil` when that shape is unknown.
    func shape(for handle: LazyTensorHandle) -> TensorShape? {
      let knownShape: TensorShape?
      switch handle.handle {
      case let .concrete(tfeHandle, _):
        knownShape = tfeHandle.shape
      case let .symbolic(op, index, _):
        knownShape = op.outputShapes[index]
      }
      return knownShape
    }

    let inputShapes: [TensorShape?] = inputs.lazy.flatMap { (input) -> [TensorShape?] in
      switch input {
      case .single(let handle): return [shape(for: handle)]
      case .list(let values): return values.lazy.map { shape(for: $0) }
      }
    }
    let inputShapeList = TF_NewShapeAndTypeList( /*num_shapes*/Int32(inputShapes.count))
    defer { TF_DeleteShapeAndTypeList(inputShapeList) }
    for (i, shape) in inputShapes.enumerated() {
      guard let shape = shape else {
        TF_ShapeAndTypeListSetUnknownShape(inputShapeList, Int32(i))
        continue
      }
      let int64_dimensions = shape.dimensions.map { Int64($0) }
      int64_dimensions.withUnsafeBufferPointer { buffer in
        TF_ShapeAndTypeListSetShape(
          inputShapeList,
          /*index*/Int32(i),
          buffer.baseAddress,
          Int32(int64_dimensions.count))
      }
    }

    /// Returns the `CTensor` behind `handle`, materializing it only when allowed.
    ///
    /// - Concrete handles are always resolved with `TFE_TensorHandleResolve`.
    /// - Symbolic handles are resolved (and therefore materialized) only when
    ///   their producing operation is named "Pack".
    /// - Every other symbolic handle yields `nil`, so its value is not fed to
    ///   shape inference.
    ///
    /// `checkOk(status)` is called after every resolve.
    ///
    /// NOTE(review): a `CTensor` returned by `TFE_TensorHandleResolve` appears
    /// to be caller-owned. The visible caller collects these into
    /// `inputTensors` and never releases them (e.g. with `TF_DeleteTensor`).
    /// Confirm ownership and free them once `TFE_InferShapes` has run.
    func cTensor(handle: LazyTensorHandle) -> CTensor? {
      let resolved: CTensor?
      switch handle.handle {
      case let .concrete(tfeHandle, _):
        resolved = TFE_TensorHandleResolve(tfeHandle._cTensorHandle, status)
      case let .symbolic(op, _, _) where op.name == "Pack":
        // TODO(https://bugs.swift.org/browse/TF-765): "Pack" is used for
        // creating tensors from array literals, so materializing it lets us
        // recover shapes for array literals. This heuristic should be
        // revisited.
        resolved = TFE_TensorHandleResolve(handle._cTensorHandle, status)
      case .symbolic:
        // Any other lazy operation stays unmaterialized.
        return nil
      }
      checkOk(status)
      return resolved
    }

    // Create `inputTensors` consisting of *only* materialized inputs.
    var inputTensors: [CTensor?] = []
    for input in inputs {
      switch input {
      case .single(let v):
        inputTensors.append(cTensor(handle: v))
      case .list(let values):
        inputTensors.append(contentsOf: values.lazy.map { cTensor(handle: $0) })
      }
    }

    // This will be filled in by `TFE_InferShapes` and should be freed later.
    var outputShapeListPtr = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
    defer { TF_DeleteShapeAndTypeList(outputShapeListPtr) }

    let tfeOp = self.tfeOp
    defer {
      TFE_DeleteOp(tfeOp.op)
      TF_DeleteStatus(tfeOp.status)
    }

    inputTensors.withUnsafeMutableBufferPointer { buffer in
      TFE_InferShapes(
        tfeOp.op,
        /*input_shapes*/inputShapeList,
        /*input_tensors*/buffer.baseAddress!,
        /*input_tensors_as_shapes*/nil,
        /*input_resource_shapes_and_types*/nil,
        /*output_shapes*/&outputShapeListPtr,
        /*output_resource_shapes_and_types*/nil,
        status)
      checkOk(status)
    }

    precondition(outputShapeListPtr != nil, "TFE_InferShapes returned nil for output shapes")
    let outputShapeList = outputShapeListPtr!.pointee
    outputShapes = (0..<outputShapeList.num_items).lazy.map { index -> TensorShape? in
      let outputShape = outputShapeList.items![Int(index)]
      if outputShape.num_dims == -1 { return nil }
      let dims = (0..<outputShape.num_dims).lazy.map { Int(outputShape.dims![Int($0)]) }
      let hasUnknownDims = dims.contains { $0 == -1 }
      return hasUnknownDims ? nil : TensorShape(dims)
    }