in Sources/TensorFlow/Core/LazyTensorShapeInference.swift [46:142]
func updateOutputShapes() {
let status = TF_NewStatus()
defer { TF_DeleteStatus(status) }
/// Looks up the shape recorded for `handle` without computing anything.
///
/// - Parameter handle: The lazy tensor handle to inspect.
/// - Returns: For a symbolic handle, the entry stored at its output index in
///   the producing operation's `outputShapes` (which may be `nil`); for a
///   concrete handle, the `shape` of the underlying handle.
func shape(for handle: LazyTensorHandle) -> TensorShape? {
  let knownShape: TensorShape?
  switch handle.handle {
  case .concrete(let eagerHandle, _):
    knownShape = eagerHandle.shape
  case .symbolic(let producer, let outputIndex, _):
    knownShape = producer.outputShapes[outputIndex]
  }
  return knownShape
}
// Flatten the operation's inputs into one optional shape per tensor: a single
// input contributes one entry, a list input contributes one entry per element,
// in order. Entries are `nil` where `shape(for:)` has no shape to report.
let inputShapes: [TensorShape?] = inputs.reduce(into: []) { shapes, input in
  switch input {
  case .single(let handle):
    shapes.append(shape(for: handle))
  case .list(let handles):
    shapes.append(contentsOf: handles.map { shape(for: $0) })
  }
}
// Mirror `inputShapes` into a C shape list, one slot per input tensor. The
// list is released when the enclosing scope exits.
let inputShapeList = TF_NewShapeAndTypeList( /*num_shapes*/Int32(inputShapes.count))
defer { TF_DeleteShapeAndTypeList(inputShapeList) }
for (position, maybeShape) in inputShapes.enumerated() {
  let slot = Int32(position)
  if let knownShape = maybeShape {
    // The C API takes the dimensions as a contiguous buffer of 64-bit integers.
    let dims64 = knownShape.dimensions.map { Int64($0) }
    dims64.withUnsafeBufferPointer { dimsBuffer in
      TF_ShapeAndTypeListSetShape(
        inputShapeList,
        /*index*/slot,
        dimsBuffer.baseAddress,
        Int32(dimsBuffer.count))
    }
  } else {
    // No shape is available for this input; mark the slot as unknown.
    TF_ShapeAndTypeListSetUnknownShape(inputShapeList, slot)
  }
}
// Produces the `CTensor` behind `handle`, or `nil` when the handle is left
// unmaterialized. Concrete handles are always resolved; symbolic handles are
// resolved only when they are produced by a "Pack" operation. Any error
// reported through `status` is passed to `checkOk`.
func cTensor(handle: LazyTensorHandle) -> CTensor? {
  switch handle.handle {
  case .symbolic(let producer, _, _):
    // TODO(https://bugs.swift.org/browse/TF-765): "Pack" is used
    // for creating tensors from array literals. So, allow
    // materialization for 'Pack' so that we can get the shape for
    // array literals. We should revisit this heuristic.
    guard producer.name == "Pack" else { return nil }
    let resolved = TFE_TensorHandleResolve(handle._cTensorHandle, status)
    checkOk(status)
    return resolved
  case .concrete(let eagerHandle, _):
    let resolved = TFE_TensorHandleResolve(eagerHandle._cTensorHandle, status)
    checkOk(status)
    return resolved
  }
}
// Create `inputTensors` consisting of *only* materialized inputs; inputs that
// `cTensor(handle:)` declines to materialize are recorded as `nil`.
var inputTensors: [CTensor?] = []
// Every non-nil entry was returned by `TFE_TensorHandleResolve`, which hands
// ownership of the `TF_Tensor` to the caller. Release them when the enclosing
// scope exits so repeated shape inference does not leak tensors. The `defer`
// is registered before the loop but reads the array's final contents.
defer {
  for case let tensor? in inputTensors {
    TF_DeleteTensor(tensor)
  }
}
for input in inputs {
  switch input {
  case .single(let handle):
    inputTensors.append(cTensor(handle: handle))
  case .list(let handles):
    inputTensors.append(contentsOf: handles.lazy.map { cTensor(handle: $0) })
  }
}
// This will be filled in by `TFE_InferShapes` and should be freed later.
// It starts out `nil`; its address is passed as the `output_shapes` argument
// below, and the `defer` releases whatever it points to on scope exit.
var outputShapeListPtr = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
defer { TF_DeleteShapeAndTypeList(outputShapeListPtr) }
// Read `self.tfeOp` once into a local so that the op passed to
// `TFE_InferShapes` and the `op`/`status` released by the `defer` are the
// same values.
// NOTE(review): presumably each access to `tfeOp` creates a new op whose `op`
// and `status` the caller must free — confirm against its definition.
let tfeOp = self.tfeOp
defer {
TFE_DeleteOp(tfeOp.op)
TF_DeleteStatus(tfeOp.status)
}
// Run shape inference with the input shapes and materialized input tensors
// gathered above. The remaining optional inputs and the resource-shape output
// are not used and are passed as `nil`.
// NOTE(review): `buffer.baseAddress!` assumes the array buffer has a non-nil
// base address even when `inputTensors` is empty — confirm for ops without inputs.
inputTensors.withUnsafeMutableBufferPointer { buffer in
TFE_InferShapes(
tfeOp.op,
/*input_shapes*/inputShapeList,
/*input_tensors*/buffer.baseAddress!,
/*input_tensors_as_shapes*/nil,
/*input_resource_shapes_and_types*/nil,
/*output_shapes*/&outputShapeListPtr,
/*output_resource_shapes_and_types*/nil,
status)
// Check the status reported by the inference call before leaving the closure.
checkOk(status)
}
// Convert the C shape list produced by inference into `outputShapes`, one
// entry per reported item. An entry is `nil` unless the shape is fully known.
precondition(outputShapeListPtr != nil, "TFE_InferShapes returned nil for output shapes")
let inferredList = outputShapeListPtr!.pointee
outputShapes = (0..<inferredList.num_items).lazy.map { itemIndex -> TensorShape? in
  let entry = inferredList.items![Int(itemIndex)]
  // A `num_dims` of -1 means no dimensions were reported for this output.
  guard entry.num_dims != -1 else { return nil }
  let dims = (0..<entry.num_dims).lazy.map { Int(entry.dims![Int($0)]) }
  // A shape with any dimension equal to -1 (unknown) is also treated as unavailable.
  if dims.contains(-1) { return nil }
  return TensorShape(dims)
}