Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Concurrency/ConcurrentTaskQueue.swift
1 views
//
// Copyright 2024 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation

/// A Task-based replacement for `OperationQueue.maxConcurrentOperationCount`.
///
/// A `ConcurrentTaskQueue` with `concurrentLimit` set to `1` ("CTQ1") is
/// similar to, but not identical to, a `SerialTaskQueue` ("STQ"). The
/// former will reuse the current Task, whereas the latter creates a new
/// Task for each block that's enqueued. A CTQ1 can be replaced anywhere it
/// appears with an STQ without impacting execution ordering guarantees.
/// However, STQs can't (always) be replaced by CTQ1s because the former
/// support bridging synchronous and asynchronous contexts.
///
/// A ConcurrentTaskQueue is intended to be (somewhat) "invisible to the
/// caller", meaning that it doesn't violate the caller's expectations. For
/// example, ConcurrentTaskQueue supports cooperative cancellation, so it's
/// safe to use in methods intended to be cancellable. (This is different
/// than STQs -- those break the "chain of cancellation", and that might be
/// surprising to async callers of an async method.)
///
/// A ConcurrentTaskQueue serves as a suspension point/cancellation check.
/// This is different than invoking `block` directly (it's not "invisible"),
/// but it's roughly equivalent to calling `Task.checkCancellation()` just
/// prior to invoking `block`. Various other APIs -- `Task.sleep`, network
/// requests, Monitors -- similarly serve as cancellation points.
///
/// A ConcurrentTaskQueue doesn't "slow down" the reaction speed for
/// cancellations. If a block is executing when it's canceled, the reaction
/// speed will match that of the block that's executing. If a block hasn't
/// started executing, cancellation is immediate.
public final actor ConcurrentTaskQueue {
    /// The number of free slots. Starts at `concurrentLimit`, is decremented
    /// when a block takes a slot, and is incremented when that block finishes.
    private var remainingCount: Int

    /// Callers that are suspended waiting for a slot, keyed by the index they
    /// were assigned when they started waiting. A waiter is resumed with
    /// `false` when it has been handed a slot and with `true` when it was
    /// canceled while waiting.
    private var pendingContinuations = [Int: CheckedContinuation<Bool, Never>]()

    /// The index that will be assigned to the next caller that has to wait.
    private var pendingContinuationIndex = 0

    /// The next index to examine when a slot is released. Every index below
    /// this one has already been handed a slot or was found to be missing
    /// (i.e., removed by `cancelContinuation`) and skipped.
    private var resumedContinuationIndex = 0

    /// Creates a queue that runs at most `concurrentLimit` blocks at a time.
    ///
    /// NOTE(review): `concurrentLimit` isn't validated here. With a value of
    /// zero or less, no caller ever takes the fast path in `_run`, and slots
    /// are only handed off after some block finishes, so non-cancellable
    /// callers would seemingly wait forever -- confirm callers always pass a
    /// value of at least 1.
    public init(concurrentLimit: Int) {
        self.remainingCount = concurrentLimit
    }

    /// Executes `block` when fewer than `concurrentLimit` blocks are running.
    ///
    /// This method doesn't check for cancellation while waiting.
    ///
    /// - Throws: Only the error thrown from `block`.
    public func runWithoutTaskCancellationHandler<T, E>(_ block: () async throws(E) -> T) async throws(E) -> T {
        let result: Result<T, E>
        do throws(CancellationError) {
            result = try await _run(isCancellable: false, block)
        } catch {
            // `_run` only throws after `cancelContinuation` resumes a waiter
            // with `true`, and that is only scheduled when `isCancellable` is
            // true, so this branch is not expected to be reached.
            owsFail("Can't throw an error when isCancellable is false.")
        }
        return try result.get()
    }

    /// Executes `block` when fewer than `concurrentLimit` blocks are running.
    ///
    /// This method throws a `CancellationError` immediately (i.e., out of
    /// "order") if canceled while waiting.
    ///
    /// From a cancellation perspective, its behavior is an efficient
    /// implementation (i.e., without polling) of the following code:
    ///
    ///     while runningBlocks >= concurrentLimit {
    ///         try Task.checkCancellation()
    ///         await Task.yield()
    ///     }
    ///     return try await block()
    ///
    /// - Throws: An error thrown from `block` or a `CancellationError`.
    public func run<T>(_ block: () async throws -> T) async throws -> T {
        return try await _run(isCancellable: true, block).get()
    }

    /// See documentation for method above; variant for non-throwing blocks that just throws CancellationError.
    public func run<T>(_ block: () async -> T) async throws(CancellationError) -> T {
        return try await _run(isCancellable: true, block).get()
    }

    /// Shared implementation for the public `run` variants.
    ///
    /// This throws a CancellationError and returns a Result to aid in compiler
    /// enforcement for the implementation of the two public-facing methods.
    ///
    /// - Parameters:
    ///   - isCancellable: Whether canceling the calling Task while it waits
    ///     for a slot should stop the wait. When false, the cancellation
    ///     handler does nothing and the caller keeps waiting for its turn.
    ///   - block: The work to execute once a slot has been obtained.
    /// - Returns: The outcome of `block` (its value or its thrown error).
    /// - Throws: `CancellationError` if the wait for a slot was canceled;
    ///   `block` is not executed in that case.
    private func _run<T, E>(
        isCancellable: Bool,
        _ block: () async throws(E) -> T,
    ) async throws(CancellationError) -> Result<T, E> {
        if self.remainingCount > 0 {
            // Fast path: a slot is free, so take it without suspending.
            self.remainingCount -= 1
        } else {
            // Slow path: claim the next index and wait until a finishing
            // block hands over its slot (or until the wait is canceled).
            let continuationIndex = self.pendingContinuationIndex
            self.pendingContinuationIndex += 1
            let isCanceled = await withTaskCancellationHandler(
                operation: {
                    return await withCheckedContinuation { continuation in
                        // Stored under this caller's index so it can be found
                        // by either the hand-off loop or `cancelContinuation`.
                        self.pendingContinuations[continuationIndex] = continuation
                    }
                },
                onCancel: {
                    if isCancellable {
                        // `cancelContinuation` is actor-isolated, so the
                        // handler reaches it through a new Task.
                        Task { await self.cancelContinuation(continuationIndex: continuationIndex) }
                    }
                },
            )
            if isCanceled {
                // A canceled waiter never obtained a slot. The `defer` below
                // hasn't been reached yet, so `remainingCount` is untouched.
                throw CancellationError()
            }
            // Otherwise, the block that resumed this waiter has already
            // decremented `remainingCount` on its behalf.
        }
        defer {
            // Release this block's slot...
            self.remainingCount += 1
            // ...and hand it to the oldest waiter that is still pending.
            // Indexes with no stored continuation are skipped.
            while resumedContinuationIndex < pendingContinuationIndex {
                let continuationToResume = self.pendingContinuations.removeValue(forKey: resumedContinuationIndex)
                resumedContinuationIndex += 1
                if let continuationToResume {
                    // Take the slot for the waiter before resuming it; the
                    // waiter doesn't decrement the count itself.
                    self.remainingCount -= 1
                    continuationToResume.resume(returning: false)
                    break
                }
            }
        }
        // Errors from `block` are captured in the Result so that this method's
        // own thrown error type can remain `CancellationError`.
        return await Result(catching: { () async throws(E) in try await block() })
    }

    /// Resumes the waiter at `continuationIndex` with `true` (canceled) if it
    /// is still stored in `pendingContinuations`.
    ///
    /// Does nothing if no continuation is stored for that index (e.g., it has
    /// already been handed a slot and removed).
    private func cancelContinuation(continuationIndex: Int) {
        self.pendingContinuations.removeValue(forKey: continuationIndex)?.resume(returning: true)
    }
}