Sources/Concurrency/Executor/ConcurrentSequenceExecutor.swift (90 lines of code) (raw):
//
// Copyright (c) 2018. Uber Technologies
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
import Foundation
/// A `SequenceExecutor` that runs the tasks of a sequence on a concurrent
/// dispatch queue, optionally bounding how many tasks may be in flight at
/// the same time.
///
/// - seeAlso: `SequenceExecutor`.
/// - seeAlso: `Task`.
public class ConcurrentSequenceExecutor: SequenceExecutor {

    /// Initializer.
    ///
    /// - parameter name: The name of the executor. It is appended to the
    /// label of the underlying dispatch queue.
    /// - parameter qos: The quality of service of this executor. This
    /// defaults to `userInitiated`.
    /// - parameter shouldTrackTaskId: `true` if the ID of each task should
    /// be recorded right before the task is executed, `false` otherwise.
    /// With tracking enabled, a timeout while waiting on the completion of
    /// a task sequence reports the ID of the task that was being executed
    /// when the timeout occurred. The tracking does incur a minor
    /// performance cost. This value defaults to `false`.
    /// - parameter maxConcurrentTasks: The optional maximum number of tasks
    /// the executor can execute concurrently. `nil` if the executor should
    /// not limit the maximum number of concurrent tasks. Defaults to `nil`.
    public init(name: String, qos: DispatchQoS = .userInitiated, shouldTrackTaskId: Bool = false, maxConcurrentTasks: Int? = nil) {
        self.shouldTrackTaskId = shouldTrackTaskId
        // A semaphore only exists when a concurrency limit was requested.
        taskSemaphore = maxConcurrentTasks.map { AutoReleasingSemaphore(value: $0) }
        taskQueue = DispatchQueue(label: "Executor.taskQueue-\(name)", qos: qos, attributes: .concurrent)
    }

    /// Execute a sequence of tasks concurrently from the given initial task.
    ///
    /// - parameter initialTask: The root task of the sequence of tasks
    /// to be executed.
    /// - parameter execution: The execution defining the sequence of tasks.
    /// When a task completes its execution, this closure is invoked with
    /// the task and its produced result. This closure is invoked from
    /// multiple threads concurrently, therefore it must be thread-safe.
    /// The tasks provided by this closure are executed concurrently.
    /// - returns: The execution handle that allows control and monitoring
    /// of the sequence of tasks being executed.
    public func executeSequence<SequenceResultType>(from initialTask: Task, with execution: @escaping (Task, Any) -> SequenceExecution<SequenceResultType>) -> SequenceExecutionHandle<SequenceResultType> {
        let sequenceHandle = SynchronizedSequenceExecutionHandle<SequenceResultType>()
        execute(initialTask, with: sequenceHandle, execution)
        return sequenceHandle
    }

    // MARK: - Private

    private let taskQueue: DispatchQueue
    private let taskSemaphore: AutoReleasingSemaphore?
    private let shouldTrackTaskId: Bool

    /// Schedule a single task of the sequence on the task queue and, once
    /// it has run, either schedule the follow-up task or report the
    /// outcome of the whole sequence to the handle.
    ///
    /// - parameter task: The task to execute.
    /// - parameter sequenceHandle: The handle that is consulted for
    /// cancellation and receives the sequence's result or error.
    /// - parameter execution: The closure that maps a finished task and
    /// its result to the next step of the sequence.
    private func execute<SequenceResultType>(_ task: Task, with sequenceHandle: SynchronizedSequenceExecutionHandle<SequenceResultType>, _ execution: @escaping (Task, Any) -> SequenceExecution<SequenceResultType>) {
        // When a limit is configured, a permit is acquired on the calling
        // thread before the work is handed to the queue.
        taskSemaphore?.wait()

        taskQueue.async {
            // A cancelled sequence skips the task. The handle is not
            // notified of a completion or an error on this path.
            if sequenceHandle.isCancelled {
                self.taskSemaphore?.signal()
                return
            }

            if self.shouldTrackTaskId {
                sequenceHandle.willBeginExecuting(taskId: task.id)
            }

            let taskResult: Any
            do {
                taskResult = try task.typeErasedExecute()
            } catch {
                self.taskSemaphore?.signal()
                sequenceHandle.sequenceDidError(with: error)
                return
            }

            let nextStep = execution(task, taskResult)

            // Release this task's permit before scheduling the next one:
            // `execute` acquires a permit itself, so holding on to the
            // current one here could leave the sequence waiting on itself
            // when the limit is small.
            self.taskSemaphore?.signal()

            switch nextStep {
            case let .continueSequence(followUpTask):
                self.execute(followUpTask, with: sequenceHandle, execution)
            case let .endOfSequence(sequenceResult):
                sequenceHandle.sequenceDidComplete(with: sequenceResult)
            }
        }
    }
}
/// The execution handle handed out by `ConcurrentSequenceExecutor`.
///
/// The executor reports the outcome of a sequence through
/// `sequenceDidComplete(with:)` or `sequenceDidError(with:)`, while callers
/// observe it through `await(withTimeout:)`. A latch with a count of one
/// connects the two sides.
private class SynchronizedSequenceExecutionHandle<SequenceResultType>: SequenceExecutionHandle<SequenceResultType> {

    private let latch = CountDownLatch(count: 1)
    private let didCancel = AtomicBool(initialValue: false)
    private let currentTaskId = AtomicInt(initialValue: nonTrackingDefaultTaskId)

    // The outcome is written by `sequenceDidComplete`/`sequenceDidError`
    // and read by `await`, which may be invoked on a different thread.
    // Every access to `result` and `error` therefore goes through this lock.
    private let resultLock = NSRecursiveLock()
    private var result: SequenceResultType?
    private var error: Error?

    /// `true` once `cancel()` has been invoked on this handle.
    fileprivate var isCancelled: Bool {
        return didCancel.value
    }

    /// Record the ID of the task that is about to run, so that a timeout
    /// in `await(withTimeout:)` can report it.
    ///
    /// - parameter taskId: The ID of the task about to be executed.
    fileprivate func willBeginExecuting(taskId: Int) {
        currentTaskId.value = taskId
    }

    /// Wait for the latch to be counted down, then hand back the outcome
    /// of the sequence.
    ///
    /// - parameter timeout: The timeout value passed on to the latch.
    /// - returns: The result stored by `sequenceDidComplete(with:)`.
    /// - throws: `SequenceExecutionError.awaitTimeout` carrying the most
    /// recently recorded task ID if the latch reports that the wait did
    /// not complete, or the error stored by `sequenceDidError(with:)`.
    fileprivate override func await(withTimeout timeout: TimeInterval?) throws -> SequenceResultType {
        guard latch.await(timeout: timeout) else {
            throw SequenceExecutionError.awaitTimeout(currentTaskId.value)
        }

        resultLock.lock()
        defer {
            resultLock.unlock()
        }

        if let failure = self.error {
            throw failure
        }
        // The latch is only counted down after either an error or a result
        // has been stored. With no error present, the result must have
        // been set, so this forced-unwrap is safe.
        return result!
    }

    /// Store the final result of the sequence and release any waiter.
    ///
    /// - parameter result: The result produced by the sequence.
    fileprivate func sequenceDidComplete(with result: SequenceResultType) {
        finish { self.result = result }
    }

    /// Store the error that ended the sequence and release any waiter.
    ///
    /// - parameter error: The error that ended the sequence.
    fileprivate func sequenceDidError(with error: Error) {
        finish { self.error = error }
    }

    /// Flag the sequence as cancelled. The flag is only ever flipped from
    /// `false` to `true`.
    fileprivate override func cancel() {
        didCancel.compareAndSet(expect: false, newValue: true)
    }

    /// Run the given outcome-storing closure while holding the result
    /// lock, then count down the latch after the lock has been released.
    ///
    /// - parameter storeOutcome: The closure that writes `result` or `error`.
    private func finish(_ storeOutcome: () -> Void) {
        resultLock.lock()
        storeOutcome()
        resultLock.unlock()
        latch.countDown()
    }
}