Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Network/OWSChatConnection.swift
1 views
//
// Copyright 2021 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import LibSignalClient

/// Distinguishes the identified chat connection from the unidentified one.
public enum OWSChatConnectionType: Int, CaseIterable, CustomDebugStringConvertible {
    case identified = 0
    case unidentified = 1

    /// Bracketed label produced when the type is interpolated into a string.
    public var debugDescription: String {
        switch self {
        case .identified: "[type: identified]"
        case .unidentified: "[type: unidentified]"
        }
    }
}

// MARK: -

/// Externally visible lifecycle state of a chat connection.
public enum OWSChatConnectionState: Int, CustomDebugStringConvertible {
    case closed = 0
    case connecting = 1
    case open = 2

    /// Lowercase state name, used in log output.
    public var debugDescription: String {
        switch self {
        case .closed: "closed"
        case .connecting: "connecting"
        case .open: "open"
        }
    }
}

// MARK: -

public class OWSChatConnection {
    // String-backed alert identifiers.
    // NOTE(review): not referenced in the visible part of this file — confirm where these raw values are consumed.
    enum AlertType: String {
        case idlePrimaryDevice = "idle-primary-device"
    }

    // Posted on the main thread by `notifyStatusChange(newState:)`.
    public static let chatConnectionStateDidChange = Notification.Name("chatConnectionStateDidChange")
    // `userInfo` key under which that notification carries the new `OWSChatConnectionState`.
    public static let chatConnectionStateKey: String = "chatConnectionState"

    // Queue on which socket state is read and mutated; the methods that touch
    // that state check this with `assertOnQueue(serialQueue)`.
    fileprivate let serialQueue: DispatchQueue

    // MARK: -

    // Identity and dependencies supplied at init.
    fileprivate let type: OWSChatConnectionType
    fileprivate let appExpiry: AppExpiry
    fileprivate let appReadiness: AppReadiness
    fileprivate let db: any DB

    /// Whether the initial queue has been emptied. This base implementation
    /// always answers `false`.
    public var hasEmptiedInitialQueue: Bool {
        get async { false }
    }

    /// Prefix for log lines emitted by this connection: the connection type
    /// wrapped in square brackets.
    fileprivate var logPrefix: String {
        return "[\(type)]"
    }

    // MARK: -

    /// Creates a connection of the given type. Must be called on the main thread.
    ///
    /// Notification observers are not registered here. `appDidBecomeReady()` is
    /// scheduled through `appReadiness.runNowOrWhenAppDidBecomeReadySync` instead.
    public init(
        type: OWSChatConnectionType,
        appExpiry: AppExpiry,
        appReadiness: AppReadiness,
        db: any DB,
    ) {
        AssertIsOnMainThread()

        self.type = type
        self.appExpiry = appExpiry
        self.appReadiness = appReadiness
        self.db = db
        self.serialQueue = DispatchQueue(label: "org.signal.chat-connection-\(type)")

        appReadiness.runNowOrWhenAppDidBecomeReadySync { [weak self] in
            guard let self else {
                return
            }
            self.appDidBecomeReady()
        }
    }

    // MARK: - Notifications

    // Observers are registered lazily so that the data store is not accessed
    // in [application: didFinishLaunchingWithOptions:].
    fileprivate func appDidBecomeReady() {
        AssertIsOnMainThread()

        // Each entry pairs a handler with the notification that triggers it.
        let observations: [(selector: Selector, name: Notification.Name)] = [
            (#selector(isCensorshipCircumventionActiveDidChange), .isCensorshipCircumventionActiveDidChange),
            (#selector(appExpiryDidChange), AppExpiry.AppExpiryDidChange),
        ]
        for observation in observations {
            NotificationCenter.default.addObserver(
                self,
                selector: observation.selector,
                name: observation.name,
                object: nil,
            )
        }
    }

    // MARK: -

    /// Backing state written by `notifyStatusChange(newState:)`. Every
    /// assignment is forwarded asynchronously to `currentState` on the main queue.
    fileprivate var _currentState: OWSChatConnectionState = .closed {
        didSet {
            // Snapshot the value now so the main-queue hop publishes exactly
            // what was just assigned, not whatever is current when it runs.
            let snapshot = _currentState
            DispatchQueue.main.async {
                self.currentState = snapshot
            }
        }
    }

    // Main-actor mirror of `_currentState`. It is updated from lifecycle
    // events, so it should be accurate (with the usual caveats about races).
    @MainActor
    public private(set) var currentState: OWSChatConnectionState = .closed {
        didSet {
            AssertIsOnMainThread()
        }
    }

    // Waiter storage for `openCondition` (referenced through the `\.onOpen` key path).
    private var onOpen = [NSObject: Monitor.Continuation]()

    // Satisfied while `_currentState` is `.open`. Awaited by `waitForOpen()` and
    // re-notified from `notifyStatusChange(newState:)`.
    private let openCondition = Monitor.Condition<OWSChatConnection>(
        isSatisfied: { $0._currentState == .open },
        waiters: \.onOpen,
    )

    /// Records `newState`, logs the transition if the state actually changed,
    /// notifies `openCondition`, and posts `chatConnectionStateDidChange`.
    fileprivate func notifyStatusChange(newState: OWSChatConnectionState) {
        // Technically this would be safe to call from anywhere, but requiring
        // the serial queue makes it less likely that a caller checks a
        // condition that is immediately out of date (a race).
        assertOnQueue(serialQueue)

        let previousState = _currentState
        _currentState = newState

        let didTransition = previousState != newState
        if didTransition {
            Logger.info("\(logPrefix): \(previousState) -> \(newState)")
        }

        Monitor.notifyOnQueue(serialQueue, state: self, conditions: openCondition)

        let userInfo = [Self.chatConnectionStateKey: newState]
        NotificationCenter.default.postOnMainThread(
            name: Self.chatConnectionStateDidChange,
            object: nil,
            userInfo: userInfo,
        )
    }

    /// Suspends until `openCondition` is satisfied, i.e. `_currentState == .open`.
    ///
    /// - Throws: `CancellationError`, the only error type this function can throw.
    func waitForOpen() async throws(CancellationError) {
        try await Monitor.waitForCondition(openCondition, in: self, on: serialQueue)
    }

    /// Waits for the connection to become usable, then runs `operation`.
    ///
    /// The wait races three things: the socket opening, the socket being told
    /// it should be closed, and a 30-second timer. Afterwards, if
    /// `canOpenWebSocketError` is set, that error is thrown instead of running
    /// `operation`.
    ///
    /// - Parameter operation: The work to perform once the wait has finished.
    /// - Returns: Whatever `operation` returns.
    /// - Throws: `OWSHTTPError.networkFailure(.genericFailure)` if the timer
    ///   fires first, the current `canOpenWebSocketError` if one is set, or any
    ///   error thrown by the wait or by `operation`.
    fileprivate func waitUntilReadyAndPerformRequest<Output>(
        operation: () async throws -> Output,
    ) async throws -> Output {
        let timeout: TimeInterval = 30
        do {
            // NOTE(review): presumably returns as soon as the first racer finishes — confirm against withCooperativeRace.
            try await withCooperativeRace(
                { try await self.waitForOpen() },
                { try await self.waitUntilSocketShouldBeClosed() },
                { try await Task.sleep(nanoseconds: timeout.clampedNanoseconds); throw CooperativeTimeoutError() },
            )

            // `canOpenWebSocketError` must be read on `serialQueue`, so hop there to check it.
            try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
                self.serialQueue.async {
                    if let canOpenWebSocketError = self.canOpenWebSocketError {
                        continuation.resume(throwing: canOpenWebSocketError)
                    } else {
                        continuation.resume()
                    }
                }
            }
        } catch is CooperativeTimeoutError {
            // Callers see a timeout while waiting as a generic network failure.
            throw OWSHTTPError.networkFailure(.genericFailure)
        }

        let output = try await operation()
        // Only reached when `operation` did not throw.
        OutageDetection.shared.reportConnectionSuccess()
        return output
    }

    /// Abstract hook: this base implementation always fails via `owsFail`, so
    /// subclasses must override it.
    func waitForDisconnectIfClosed() async {
        owsFail("Subclasses must provide an implementation.")
    }

    // Access on serialQueue.
    // Waiter storage for `shouldBeClosedCondition` (referenced through the
    // `\.onSocketShouldBeClosed` key path).
    private var onSocketShouldBeClosed = [NSObject: Monitor.Continuation]()

    /// Suspends until `shouldBeClosedCondition` is satisfied, i.e. until
    /// `shouldSocketBeOpen()` returns `false`.
    ///
    /// - Throws: `CancellationError`, the only error type this function can throw.
    func waitUntilSocketShouldBeClosed() async throws(CancellationError) {
        return try await Monitor.waitForCondition(shouldBeClosedCondition, in: self, on: serialQueue)
    }

    // MARK: - Socket LifeCycle

    /// Tracks app-wide, "fatal" errors that block web sockets.
    ///
    /// If this property is nonnil, the app shouldn't attempt to open a
    /// connection to the server. If `makeRequest` is called while this property
    /// is nonnil, the request will fail with this error.
    ///
    /// This property is used for "fatal" errors: "the user isn't registered",
    /// "the app has expired", "this extension doesn't ever use web sockets",
    /// etc. Transient errors ("no network", "the server returned a 5xx", etc.)
    /// don't use this property.
    ///
    /// The initial value is a generic network failure, so the socket is
    /// considered blocked until `_updateCanOpenWebSocket()` replaces it with
    /// the result of `_canOpenWebSocketError()`.
    ///
    /// Must be accessed on `serialQueue`.
    private var canOpenWebSocketError: (any Error)? = OWSHTTPError.networkFailure(.genericFailure)

    /// Enqueues `_updateCanOpenWebSocket()` on `serialQueue`, so it can be
    /// called without already being on that queue.
    func updateCanOpenWebSocket() {
        serialQueue.async(_updateCanOpenWebSocket)
    }

    /// Recomputes `canOpenWebSocketError` and re-applies the desired socket
    /// state only when "may open" flipped between blocked and allowed.
    fileprivate func _updateCanOpenWebSocket() {
        assertOnQueue(serialQueue)

        let couldOpenBefore = canOpenWebSocketError == nil
        canOpenWebSocketError = _canOpenWebSocketError()
        let canOpenNow = canOpenWebSocketError == nil

        guard couldOpenBefore != canOpenNow else {
            return
        }
        _applyDesiredSocketState()
    }

    /// Returns the error that currently blocks opening a web socket, or `nil`
    /// if nothing blocks it. This implementation only checks app expiry.
    ///
    /// May be overridden.
    fileprivate func _canOpenWebSocketError() -> (any Error)? {
        if appExpiry.isExpired(now: Date()) {
            return AppExpiredError()
        }
        return nil
    }

    /// Handle returned by `requestConnection()`. The request stays registered
    /// until `releaseConnection()` is called; letting the token deallocate
    /// without releasing it trips a debug assertion.
    public final class ConnectionToken {
        private let tokenId: Int
        private weak var chatConnection: OWSChatConnection?

        fileprivate init(tokenId: Int, chatConnection: OWSChatConnection) {
            self.chatConnection = chatConnection
            self.tokenId = tokenId
        }

        deinit {
            if let chatConnection {
                // Release as a fallback, but if this actually removed the
                // token then the owner never released it explicitly.
                let didRelease = chatConnection.releaseConnection(tokenId)
                owsAssertDebug(!didRelease, "You must explicitly call releaseConnection().")
            }
        }

        /// Withdraws this token's request. Does nothing if the owning
        /// connection has already been deallocated.
        public func releaseConnection() {
            if let chatConnection {
                let didRelease = chatConnection.releaseConnection(tokenId)
                owsAssertDebug(didRelease, "You can't call releaseConnection() multiple times.")
            }
        }
    }

    /// Bookkeeping for outstanding connection tokens.
    struct ConnectionTokenState {
        /// The most recently issued token id.
        var tokenId = 0

        /// Ids of tokens that have been issued and not yet released.
        var activeTokenIds: Set<Int> = []

        /// The socket is wanted for as long as at least one token is outstanding.
        func shouldSocketBeOpen() -> Bool {
            let hasNoActiveTokens = activeTokenIds.isEmpty
            return !hasNoActiveTokens
        }
    }

    // Lock-protected token bookkeeping. It is read and mutated by
    // `requestConnection()`, `releaseConnection(_:)` and `shouldSocketBeOpen()`.
    private let connectionTokenState = AtomicValue(ConnectionTokenState(), lock: .init())

    /// Registers interest in the socket being open and returns a token that
    /// the caller must later release with `ConnectionToken.releaseConnection()`.
    public func requestConnection() -> ConnectionToken {
        let (connectionToken, wasIdle) = connectionTokenState.update { state in
            // Sample "is anyone already holding a token?" before inserting
            // ours, so only the first outstanding token triggers a connect.
            let wasIdle = !state.shouldSocketBeOpen()
            state.tokenId += 1
            state.activeTokenIds.insert(state.tokenId)
            let issuedToken = ConnectionToken(tokenId: state.tokenId, chatConnection: self)
            return (issuedToken, wasIdle)
        }
        if wasIdle {
            applyDesiredSocketState()
        }
        return connectionToken
    }

    /// Removes `tokenId` from the active set and re-applies the desired socket
    /// state if no tokens remain.
    ///
    /// - Returns: `true` if the token was active before this call.
    private func releaseConnection(_ tokenId: Int) -> Bool {
        let outcome = connectionTokenState.update { state -> (removed: Bool, nowIdle: Bool) in
            let removed = state.activeTokenIds.remove(tokenId) != nil
            return (removed, !state.shouldSocketBeOpen())
        }
        if outcome.nowIdle {
            applyDesiredSocketState()
        }
        return outcome.removed
    }

    // This method aligns the socket state with the "desired" socket state.
    //
    // This method is thread-safe: it only enqueues `_applyDesiredSocketState()`
    // on `serialQueue`, where the actual work happens.
    fileprivate final func applyDesiredSocketState() {
        serialQueue.async(self._applyDesiredSocketState)
    }

    /// The socket should be open when no fatal error blocks it and at least
    /// one connection token is outstanding. Must be called on `serialQueue`.
    private func shouldSocketBeOpen() -> Bool {
        assertOnQueue(serialQueue)

        guard canOpenWebSocketError == nil else {
            return false
        }
        // Read-only snapshot of the token bookkeeping.
        return connectionTokenState.get().shouldSocketBeOpen()
    }

    /// Brings the socket in line with `shouldSocketBeOpen()`: connects when it
    /// should be open, otherwise disconnects and notifies the
    /// should-be-closed waiters. Must be called on `serialQueue`.
    fileprivate final func _applyDesiredSocketState() {
        assertOnQueue(serialQueue)

        guard shouldSocketBeOpen() else {
            disconnectIfNeeded()
            notifySocketShouldBeClosedIfNeeded()
            return
        }

        owsPrecondition(appReadiness.isAppReady)
        ensureWebsocketExists()
    }

    // Satisfied whenever `shouldSocketBeOpen()` is false. Awaited by
    // `waitUntilSocketShouldBeClosed()`.
    private let shouldBeClosedCondition = Monitor.Condition<OWSChatConnection>(
        isSatisfied: { !$0.shouldSocketBeOpen() },
        waiters: \.onSocketShouldBeClosed,
    )

    // Notifies `shouldBeClosedCondition` through `Monitor`.
    // NOTE(review): presumably resumes waiters only if the condition is currently satisfied — confirm against Monitor.
    private func notifySocketShouldBeClosedIfNeeded() {
        assertOnQueue(serialQueue)
        Monitor.notifyOnQueue(serialQueue, state: self, conditions: shouldBeClosedCondition)
    }

    // This method must be thread-safe. It only enqueues `_cycleSocket()` on
    // `serialQueue`.
    fileprivate func cycleSocket() {
        serialQueue.async(self._cycleSocket)
    }

    // Drops the current connection (if any) and then re-applies the desired
    // state, which reconnects if the socket should still be open.
    fileprivate func _cycleSocket() {
        assertOnQueue(serialQueue)

        disconnectIfNeeded()
        _applyDesiredSocketState()
    }

    // Abstract hook called from `_applyDesiredSocketState()` when the socket
    // should be open. The base implementation only raises a debug failure.
    fileprivate func ensureWebsocketExists() {
        assertOnQueue(serialQueue)
        owsFailDebug("should be using a concrete subclass")
    }

    // Abstract hook called when the socket should be closed or cycled. The
    // base implementation only raises a debug failure.
    fileprivate func disconnectIfNeeded() {
        assertOnQueue(serialQueue)
        owsFailDebug("should be using a concrete subclass")
    }

    // MARK: - Notifications

    // Handler for `.isCensorshipCircumventionActiveDidChange` (registered in
    // `appDidBecomeReady()`): cycles the socket.
    @objc
    private func isCensorshipCircumventionActiveDidChange(_ notification: NSNotification) {
        AssertIsOnMainThread()

        cycleSocket()
    }

    // Handler for `AppExpiry.AppExpiryDidChange` (registered in
    // `appDidBecomeReady()`): re-evaluates whether the socket may be opened.
    @objc
    private func appExpiryDidChange(_ notification: NSNotification) {
        AssertIsOnMainThread()

        updateCanOpenWebSocket()
    }

    // MARK: - Message Sending

    /// Sends `request` over the chat connection once it is ready, logging the
    /// attempt and its outcome under a random per-request id.
    ///
    /// - Throws: Rethrows whatever the wait or the request itself throws,
    ///   after logging it.
    func makeRequest(_ request: TSRequest) async throws -> HTTPResponse {
        let requestId = UInt64.random(in: .min ... .max)
        let requestDescription = "\(request) [\(requestId)]"
        request.logger.info("Sending… -> \(requestDescription)")

        do {
            return try await waitUntilReadyAndPerformRequest {
                let response = try await self.makeRequestInternal(request, requestId: requestId)
                request.logger.info("HTTP \(response.responseStatusCode) <- \(requestDescription)")
                return response
            }
        } catch {
            // Prefer the compact status-code form when the error carries one.
            guard let statusCode = error.httpStatusCode else {
                request.logger.warn("Failure. <- \(requestDescription): \(error)")
                throw error
            }
            request.logger.warn("HTTP \(statusCode) <- \(requestDescription)")
            throw error
        }
    }

    // Abstract hook that performs the actual send. The base implementation
    // always fails via `owsFail`, so subclasses must override it.
    fileprivate func makeRequestInternal(_ request: TSRequest, requestId: UInt64) async throws -> HTTPResponse {
        owsFail("must be using a concrete subclass")
    }

    /// Converts a raw response into an `HTTPResponse`, or throws for any
    /// status outside 200...299.
    ///
    /// - Throws: The `OWSHTTPError` produced by
    ///   `HTTPUtils.preprocessMainServiceHTTPError` for non-2xx statuses.
    fileprivate final func handleRequestResponse(
        requestUrl: URL,
        responseStatus: Int,
        responseHeaders: HttpHeaders,
        responseData: Data?,
    ) async throws(OWSHTTPError) -> HTTPResponse {
        let isSuccess = (200...299).contains(responseStatus)
        guard isSuccess else {
            throw await HTTPUtils.preprocessMainServiceHTTPError(
                requestUrl: requestUrl,
                responseStatus: responseStatus,
                responseHeaders: responseHeaders,
                responseData: responseData,
            )
        }
        return HTTPResponse(
            requestUrl: requestUrl,
            status: responseStatus,
            headers: responseHeaders,
            bodyData: responseData,
        )
    }

    // MARK: - Reconnect

    // Base reconnect delay, in seconds. `reconnectAfterFailure()` uses it as
    // the maximum average backoff, and uses twice this value as the quiet
    // period after which the failure counter is reset.
    static let socketReconnectDelay: TimeInterval = 5
}

// MARK: -

class OWSChatConnectionUsingLibSignal<Connection: ChatConnection & Sendable>: OWSChatConnection, ConnectionEventsListener {
    // libsignal `Net` instance supplied at init. Subclasses use it to
    // establish the underlying connection.
    fileprivate let libsignalNet: Net

    /// Internal connection state, carrying the tasks and handles needed to
    /// move between states.
    fileprivate enum ConnectionState {
        /// Closed. `task`, if present, is the in-flight disconnect of the
        /// previous connection.
        case closed(task: Task<Void, Never>?)
        /// A connect attempt identified by `token` is running in `task`.
        case connecting(token: NSObject, task: Task<Connection?, Never>)
        /// Connected.
        case open(Connection)

        /// The coarse state exposed outside this class.
        var asExternalState: OWSChatConnectionState {
            switch self {
            case .closed:
                return .closed
            case .connecting:
                return .connecting
            case .open:
                return .open
            }
        }

        /// Whether the state is `.connecting` for exactly this attempt's token (identity comparison).
        func isCurrentlyConnecting(_ token: NSObject) -> Bool {
            if case .connecting(token: let activeToken, task: _) = self {
                return activeToken === token
            }
            return false
        }

        /// Whether the state is `.open` with exactly this connection (identity comparison).
        func isActive(_ connection: Connection) -> Bool {
            if case .open(let activeConnection) = self {
                return activeConnection === connection
            }
            return false
        }

        /// Resolves to the open connection, waiting for an in-flight attempt
        /// if there is one. With `cancel: true` the in-flight attempt is
        /// cancelled before it is awaited.
        func waitToFinishConnecting(cancel: Bool = false) async -> Connection? {
            switch self {
            case .closed:
                return nil
            case .open(let connection):
                return connection
            case .connecting(token: _, task: let pendingAttempt):
                if cancel {
                    pendingAttempt.cancel()
                }
                return await pendingAttempt.value
            }
        }
    }

    // Backing storage for `connection`. Go through the accessor below.
    private var _connection: ConnectionState = .closed(task: nil)
    // Current connection state. Both accessors assert they run on
    // `serialQueue`, and every write also publishes the matching external
    // state through `notifyStatusChange(newState:)`.
    fileprivate var connection: ConnectionState {
        get {
            assertOnQueue(serialQueue)
            return _connection
        }
        set {
            assertOnQueue(serialQueue)
            _connection = newValue
            notifyStatusChange(newState: newValue.asExternalState)
        }
    }

    /// Reads the current state on `serialQueue` and returns the connection if
    /// it is open, or `nil` otherwise.
    fileprivate func getOpenConnectionAfterHavingWaited() async -> Connection? {
        // To improve: some callers might have already done a hop to serialQueue,
        // and now we're making another one (without priority donation, even).
        let snapshot: ConnectionState = await withCheckedContinuation { continuation in
            self.serialQueue.async {
                continuation.resume(returning: self.connection)
            }
        }

        // There is a race condition where we cycle the socket between
        // `waitForOpen` succeeding (see callers) and the code that runs here. If we
        // win the race, the request we send will be almost immediately canceled.
        // If we lose the race, we won't send the request at all. These outcomes
        // are essentially equivalent, and it's not necessary to support this race
        // condition where the socket cycles immediately after it opens.
        guard case .open(let service) = snapshot else {
            return nil
        }
        return service
    }

    /// Stores `libsignalNet` and forwards the remaining arguments to the
    /// base class initializer.
    init(
        libsignalNet: Net,
        type: OWSChatConnectionType,
        appExpiry: AppExpiry,
        appReadiness: AppReadiness,
        db: any DB,
    ) {
        self.libsignalNet = libsignalNet
        super.init(type: type, appExpiry: appExpiry, appReadiness: appReadiness, db: db)
    }

    // Adds the proxy-configuration observer on top of the observers
    // registered by the base class.
    override fileprivate func appDidBecomeReady() {
        super.appDidBecomeReady()

        NotificationCenter.default.addObserver(
            self,
            selector: #selector(signalProxyConfigDidChange),
            name: .signalProxyConfigDidChange,
            object: nil,
        )
    }

    /// Abstract hook that establishes the underlying libsignal connection for
    /// one connect attempt.
    ///
    /// - Parameter token: The object identifying the attempt, as created by
    ///   `ensureWebsocketExists()`.
    /// - Returns: The connected chat service.
    /// - Throws: Whatever the concrete connect call throws.
    ///
    /// The base implementation always traps. It uses `owsFail`, matching the
    /// other must-override stubs in this file, rather than a bare `fatalError`.
    fileprivate func connectChatService(token: NSObject) async throws -> Connection {
        owsFail("must be overridden by subclass")
    }

    // Handler for `.signalProxyConfigDidChange` (registered in
    // `appDidBecomeReady()`): logs the change and cycles the socket.
    @objc
    private func signalProxyConfigDidChange(_ notification: NSNotification) {
        // The libsignal connection needs to be recreated whether the proxy is going up,
        // changing, or going down.
        Logger.info("\(logPrefix) signal proxy config changed; cycling socket")
        cycleSocket()
    }

    /// Starts a connect attempt unless the connection is already open or connecting.
    ///
    /// The state moves to `.connecting` with a fresh token and a task that
    /// first awaits any in-flight disconnect of the previous connection and
    /// then calls `connectChatService(token:)`. On success the state becomes
    /// `.open`. On failure it becomes `.closed` and `reconnectAfterFailure()`
    /// is scheduled on `serialQueue`. In both cases the state is only written
    /// if this attempt's token is still the current one. Cancellation moves
    /// the state to `.closed` without scheduling a reconnect.
    override fileprivate func ensureWebsocketExists() {
        assertOnQueue(serialQueue)

        let disconnectTask: Task<Void, Never>?
        switch connection {
        case .open:
            return
        case .connecting:
            // The most recent transition was attempting to connect, and we have not yet observed a failure.
            // That's as good as we're going to get.
            return
        case .closed(let task):
            disconnectTask = task
        }

        // Unique while live.
        let token = NSObject()
        connection = .connecting(token: token, task: Task { [token] () -> Connection? in
            // We need to wait until the prior connection releases the connection lock
            // before we try to acquire it again. This happens as part of this Task.
            await disconnectTask?.value

            // Publishes the attempt's final `state` on `serialQueue`, but only
            // if this attempt is still the current one. Returns the connection
            // when the published state is `.open`, otherwise `nil`.
            func connectionAttemptCompleted(_ state: ConnectionState) async -> Connection? {
                // We're not done until self.connection has been updated.
                // (Otherwise, we might try to send requests before calling start(listener:).)
                return await withCheckedContinuation { continuation in
                    self.serialQueue.async {
                        guard self.connection.isCurrentlyConnecting(token) else {
                            // We finished connecting, but we've since been asked to disconnect
                            // (either because we should be offline, or because config has changed).
                            continuation.resume(returning: nil)
                            return
                        }

                        self.connection = state

                        if case .open(let connection) = state {
                            continuation.resume(returning: connection)
                        } else {
                            continuation.resume(returning: nil)
                        }
                    }
                }
            }

            do {
                Logger.info("\(self.logPrefix): starting chat connect (signalProxyEnabled: \(SignalProxy.isEnabled), signalProxyReady: \(SignalProxy.isEnabledAndReady))")
                let chatService = try await self.connectChatService(token: token)
                if type == .identified {
                    self.didConnectIdentified()
                }
                OutageDetection.shared.reportConnectionSuccess()
                return await connectionAttemptCompleted(.open(chatService))

            } catch is CancellationError {
                // We've been asked to disconnect, no other action necessary.
                // (We could even skip updating state, since the disconnect action should have already set it to "closed",
                // but just in case it's still on "connecting" we'll continue on to execute that cleanup.)
                return await connectionAttemptCompleted(.closed(task: nil))
            } catch SignalError.appExpired(_) {
                await appExpiry.setHasAppExpiredAtCurrentVersion(db: db)
            } catch SignalError.deviceDeregistered(_) {
                // Handled by the subclass; this isn't a connection failure.
            } catch {
                Logger.warn("\(self.logPrefix): failed to connect: \(error)")
                OutageDetection.shared.reportConnectionFailure()
            }
            // Every non-cancellation failure path falls through to here:
            // mark the attempt closed, then schedule a reconnect.
            let result = await connectionAttemptCompleted(.closed(task: nil))
            serialQueue.async {
                self.reconnectAfterFailure()
            }
            return result
        })
    }

    // Hook invoked from the connect task in `ensureWebsocketExists()` after a
    // successful connect when `type == .identified`. The base implementation
    // does nothing.
    fileprivate func didConnectIdentified() {
        // Overridden by subclass.
    }

    /// Moves an open or connecting connection to `.closed`, tearing the old
    /// one down in a background task that is stored in the new state.
    override fileprivate func disconnectIfNeeded() {
        assertOnQueue(serialQueue)

        let outgoing = connection
        switch outgoing {
        case .closed:
            // Either we are already disconnecting,
            // or we finished disconnecting,
            // or we were never connected to begin with.
            return
        case .connecting, .open:
            break
        }

        // Spin off a background task to disconnect the previous connection.
        let teardown: Task<Void, Never> = Task {
            do {
                try await outgoing.waitToFinishConnecting(cancel: true)?.disconnect()
            } catch {
                Logger.warn("\(self.logPrefix): error while disconnecting: \(error)")
            }
        }
        connection = .closed(task: teardown)
    }

    /// If the connection is currently closed, waits for its pending disconnect
    /// task (if any) to finish. Returns immediately when open or connecting.
    override func waitForDisconnectIfClosed() async {
        let snapshot: ConnectionState = await withCheckedContinuation { continuation in
            serialQueue.async {
                continuation.resume(returning: self.connection)
            }
        }
        guard case .closed(let disconnectTask) = snapshot else {
            return
        }
        await disconnectTask?.value
    }

    // Log prefix that also marks the line as coming from the libsignal-backed connection.
    override fileprivate var logPrefix: String {
        "[\(type): libsignal]"
    }

    // Lock-protected auth value that `makeRequestInternal` passes to
    // `request.applyAuth(to:socketAuth:)`. It starts as `.implicit()`.
    fileprivate let authOverride = AtomicValue<ChatServiceAuth>(.implicit(), lock: .init())

    /// Translates `request` into a libsignal chat request, sends it over the
    /// currently open connection, and converts the reply into an `HTTPResponse`.
    ///
    /// - Parameters:
    ///   - request: The request to send. Its URL is expected to be relative
    ///     (no scheme, no host, no leading "/").
    ///   - requestId: Identifier used only in log and failure messages.
    /// - Throws: `OWSHTTPError.invalidRequest` if the body cannot be encoded or
    ///   the method is empty; `OWSHTTPError.networkFailure` if there is no open
    ///   connection or the send fails; `CancellationError` unchanged; or the
    ///   error produced by `handleRequestResponse` for non-2xx statuses.
    override fileprivate func makeRequestInternal(_ request: TSRequest, requestId: UInt64) async throws -> HTTPResponse {
        var httpHeaders = request.headers
        try request.applyAuth(to: &httpHeaders, socketAuth: authOverride.get())

        let body: Data
        switch request.body {
        case .data(let bodyData):
            body = bodyData
        case .parameters(let bodyParameters):
            // TODO: Do we need body & headers for requests with no parameters?
            do {
                body = try TSRequest.Body.encodedParameters(bodyParameters)
            } catch {
                owsFailDebug("[\(requestId)]: \(error).", logger: request.logger)
                throw OWSHTTPError.invalidRequest
            }

            // If we're going to use the json serialized parameters as our body, we should overwrite
            // the Content-Type on the request.
            httpHeaders["Content-Type"] = "application/json"
        }

        let requestUrl = request.url
        owsAssertDebug(requestUrl.scheme == nil)
        owsAssertDebug(requestUrl.host == nil)
        owsAssertDebug(!requestUrl.path.hasPrefix("/"))

        guard let httpMethod = request.method.nilIfEmpty else {
            throw OWSHTTPError.invalidRequest
        }

        // The relative URL gets a leading "/" to form the path-and-query.
        let libsignalRequest = ChatConnection.Request(method: httpMethod, pathAndQuery: "/\(requestUrl.relativeString)", headers: httpHeaders.headers, body: body, timeout: request.timeoutInterval)

        let chatService = await getOpenConnectionAfterHavingWaited()
        guard let chatService else {
            throw OWSHTTPError.networkFailure(.genericFailure)
        }

        let connectionInfo: ConnectionInfo
        let response: ChatConnection.Response
        do {
            connectionInfo = chatService.info()
            response = try await chatService.send(libsignalRequest)
        } catch {
            // Map libsignal errors onto the OWSHTTPError values callers expect.
            switch error {
            case SignalError.connectionTimeoutError(_), SignalError.requestTimeoutError:
                throw handleRequestTimeout(usingChatService: chatService)
            case SignalError.webSocketError(_), SignalError.possibleCaptiveNetwork(_), SignalError.connectionFailed(_), SignalError.chatServiceInactive:
                throw OWSHTTPError.networkFailure(.genericFailure)
            case SignalError.connectionInvalidated:
                throw OWSHTTPError.networkFailure(.wrappedFailure(error))
            case is CancellationError:
                throw error
            default:
                owsFailDebug("[\(requestId)] failed with an unexpected error: \(error)", logger: request.logger)
                throw OWSHTTPError.networkFailure(.genericFailure)
            }
        }

        if DebugFlags.internalLogging {
            request.logger.info("received response for requestId: \(requestId), message: \(response.message), route: \(connectionInfo)")
        }

#if TESTABLE_BUILD
        if response.status / 100 != 2 {
            HTTPUtils.logCurl(for: request)
        }
#endif

        let headers = HttpHeaders(httpHeaders: response.headers, overwriteOnConflict: false)
        return try await handleRequestResponse(
            requestUrl: request.url,
            responseStatus: Int(response.status),
            responseHeaders: headers,
            responseData: response.body,
        )
    }

    /// Reacts to a request timing out on `chatService` and returns the error
    /// the caller should throw.
    ///
    /// Two pieces of work are enqueued on `serialQueue`, in this order: a
    /// disconnect (only if `chatService` is still the active connection), then
    /// a re-application of the desired socket state.
    ///
    /// - Returns: `OWSHTTPError.networkFailure(.genericTimeout)`.
    private func handleRequestTimeout(usingChatService chatService: Connection) -> OWSHTTPError {
        // cycleSocket(), but only if the chatService we just used is the one that's still connected.
        self.serialQueue.async { [weak chatService] in
            if let chatService, self.connection.isActive(chatService) {
                self.disconnectIfNeeded()
            }
        }
        applyDesiredSocketState()
        return OWSHTTPError.networkFailure(.genericTimeout)
    }

    /// Called when `service` was disconnected. If it is still the active
    /// connection, marks the connection closed, reports the failure, and
    /// schedules a reconnect. Otherwise it only logs the error, if any.
    func connectionWasInterrupted(_ service: Connection, error: Error?) {
        serialQueue.async { [self] in
            let isCurrentService = connection.isActive(service)
            if !isCurrentService {
                // Already done with this service.
                if let error {
                    Logger.warn("\(logPrefix) previous service was disconnected: \(error)")
                }
                return
            }

            switch error {
            case .some(let error):
                Logger.warn("\(logPrefix) disconnected: \(error)")
            case .none:
                owsFailDebug("\(logPrefix) libsignal disconnected us without being asked")
            }

            connection = .closed(task: nil)
            OutageDetection.shared.reportConnectionFailure()
            reconnectAfterFailure()
        }
    }

    // Backoff bookkeeping for `reconnectAfterFailure()`. It is only touched
    // from that method, which runs on `serialQueue`.
    private var mostRecentFailureDate: MonotonicDate?
    private var consecutiveFailureCount = 0

    /// Schedules a re-application of the desired socket state after a failure.
    ///
    /// The first failure in a streak retries immediately. Later failures use
    /// exponential backoff, with `socketReconnectDelay` as the maximum average
    /// backoff. Must be called on `serialQueue`.
    ///
    /// Log lines carry `logPrefix` so that backoff activity can be attributed
    /// to the right connection, like the other log lines in this class.
    private func reconnectAfterFailure() {
        assertOnQueue(serialQueue)

        let now = MonotonicDate()

        // If it's been "a while" since the most recent failure, reset the counter.
        // If "a while" is shorter than the maximum exponential backoff, the
        // counter may be inadvertently reset while trying to back off.
        if let mostRecentFailureDate, (now - mostRecentFailureDate).seconds > (2 * Self.socketReconnectDelay) {
            Logger.info("\(logPrefix): Resetting connection backoff state due to elapsed time.")
            self.consecutiveFailureCount = 0
        }
        self.mostRecentFailureDate = now

        let reconnectDelay: TimeInterval
        if self.consecutiveFailureCount == 0 {
            // Reconnect immediately after the first failure.
            reconnectDelay = 0
        } else {
            reconnectDelay = OWSOperation.retryIntervalForExponentialBackoff(
                failureCount: self.consecutiveFailureCount,
                maxAverageBackoff: Self.socketReconnectDelay,
            )
        }
        self.consecutiveFailureCount += 1

        let formattedReconnectDelay = String(format: "%.1f", reconnectDelay)
        Logger.info("\(logPrefix): Scheduling reconnect after \(formattedReconnectDelay)s")

        // Wait a few seconds before retrying to reduce server load.
        self.serialQueue.asyncAfter(deadline: .now() + reconnectDelay) { [weak self] in
            self?._applyDesiredSocketState()
        }
    }

    /// Runs `callback` with the open libsignal connection once the connection
    /// is ready.
    ///
    /// - Parameters:
    ///   - timeout: Seconds allowed for `callback`. Defaults to `.infinity`.
    ///   - callback: Work to perform with the open connection.
    /// - Returns: Whatever `callback` returns.
    /// - Throws: `SignalError.chatServiceInactive` if no connection is open
    ///   after waiting; `CancellationError` if the task was cancelled before
    ///   `callback` started; the error from `handleRequestTimeout` if
    ///   `callback` exceeds `timeout`; otherwise any error thrown by the wait
    ///   or by `callback`.
    func withLibsignalConnection<Output>(
        timeout: TimeInterval = .infinity,
        _ callback: @escaping (Connection) async throws -> Output,
    ) async throws -> Output {
        try await waitUntilReadyAndPerformRequest {
            guard let service = await getOpenConnectionAfterHavingWaited() else {
                throw SignalError.chatServiceInactive("no connection to chat server")
            }
            // Don't start the callback if the caller was cancelled while waiting.
            try Task.checkCancellation()
            do {
                return try await withCooperativeTimeout(seconds: timeout) {
                    return try await callback(service)
                }
            } catch is CooperativeTimeoutError {
                throw self.handleRequestTimeout(usingChatService: service)
            }
        }
    }
}

/// The `.unidentified` chat connection, backed by libsignal's
/// `UnauthenticatedChatConnection`.
class OWSUnauthConnectionUsingLibSignal: OWSChatConnectionUsingLibSignal<UnauthenticatedChatConnection> {
    init(libsignalNet: Net, appExpiry: AppExpiry, appReadiness: AppReadiness, db: any DB) {
        super.init(
            libsignalNet: libsignalNet,
            type: .unidentified,
            appExpiry: appExpiry,
            appReadiness: appReadiness,
            db: db
        )
    }

    /// Starts the service, with `self` as its listener, whenever the state becomes `.open`.
    override fileprivate var connection: ConnectionState {
        didSet {
            guard case .open(let service) = connection else {
                return
            }
            service.start(listener: self)
        }
    }

    /// Connects an unauthenticated chat, passing the top preferred languages.
    override func connectChatService(token: NSObject) async throws -> UnauthenticatedChatConnection {
        let languages = Array(HttpHeaders.topPreferredLanguages())
        return try await libsignalNet.connectUnauthenticatedChat(languages: languages)
    }
}

class OWSAuthConnectionUsingLibSignal: OWSChatConnectionUsingLibSignal<AuthenticatedChatConnection>, ChatConnectionListener {
    /// Backing flag for `hasEmptiedInitialQueue`. It is set once the server
    /// reports an empty queue and reset when the connection becomes closed.
    private var _hasEmptiedInitialQueue = false

    /// The current value of the flag, read by hopping onto `serialQueue`.
    override var hasEmptiedInitialQueue: Bool {
        get async {
            await withCheckedContinuation { (pending: CheckedContinuation<Bool, Never>) in
                self.serialQueue.async {
                    pending.resume(returning: self._hasEmptiedInitialQueue)
                }
            }
        }
    }

    /// Raw storage; use `keepaliveSenderTask` instead.
    private var _keepaliveSenderTask: Task<Void, Never>?

    /// The task that sends keepalives, if any.
    ///
    /// Both accessors assert that they run on `serialQueue`. Assigning a value
    /// (including `nil`) cancels the task that was stored before.
    private var keepaliveSenderTask: Task<Void, Never>? {
        get {
            assertOnQueue(serialQueue)
            return _keepaliveSenderTask
        }
        set(replacementTask) {
            assertOnQueue(serialQueue)
            // Cancel the outgoing task before it is dropped.
            _keepaliveSenderTask?.cancel()
            _keepaliveSenderTask = replacementTask
        }
    }

    /// Callback invoked inside a database write transaction when this
    /// connection learns something about the account's registration state.
    ///
    /// The Boolean argument is `true` when an authenticated connection attempt
    /// fails with `SignalError.deviceDeregistered`, and `false` when a
    /// connection opens while the account manager still reports the account
    /// as deregistered.
    var onRegistrationStateChange: ((_ isDelinkedOrDeregistered: Bool, _ tx: DBWriteTransaction) -> Void)?

    private let accountManager: TSAccountManager
    private let inactivePrimaryDeviceStore: InactivePrimaryDeviceStore

    /// Creates the `.identified` chat connection.
    ///
    /// Besides storing its dependencies, this builds the `ConnectionLock`
    /// backed by `chat-connection.lock` in the app's shared data directory.
    /// The lock priority depends on the app context type: 1 for the share
    /// extension, 2 for the main app, 3 for the NSE, out of 3 priorities.
    init(
        libsignalNet: Net,
        accountManager: TSAccountManager,
        appContext: any AppContext,
        appExpiry: AppExpiry,
        appReadiness: AppReadiness,
        db: any DB,
        inactivePrimaryDeviceStore: InactivePrimaryDeviceStore,
    ) {
        self.accountManager = accountManager
        self.inactivePrimaryDeviceStore = inactivePrimaryDeviceStore

        let priorityCount = 3
        let priority: Int = switch appContext.type {
        case .share: 1
        case .main: 2
        case .nse: 3
        }
        let lockFilePath = appContext.appSharedDataDirectoryPath().appendingPathComponent("chat-connection.lock")
        self.connectionLock = ConnectionLock(filePath: lockFilePath, priority: priority, of: priorityCount)

        super.init(
            libsignalNet: libsignalNet,
            type: .identified,
            appExpiry: appExpiry,
            appReadiness: appReadiness,
            db: db,
        )
    }

    /// Closes the connection lock when this object is deallocated.
    deinit {
        connectionLock.close()
    }

    /// After the superclass handling, subscribes this object to the
    /// registration-state and stories-enabled notifications.
    override fileprivate func appDidBecomeReady() {
        super.appDidBecomeReady()

        // Each entry pairs a notification with the handler that reacts to it.
        let observations: [(Selector, Notification.Name)] = [
            (#selector(registrationStateDidChange), .registrationStateDidChange),
            (#selector(storiesEnabledStateDidChange), .storiesEnabledStateDidChange),
        ]
        let center = NotificationCenter.default
        for (handler, notificationName) in observations {
            center.addObserver(self, selector: handler, name: notificationName, object: nil)
        }
    }

    /// Notification handler for registration state changes: re-evaluates
    /// whether the web socket may be open. Asserts it runs on the main thread.
    @objc
    private func registrationStateDidChange(_: NSNotification) {
        AssertIsOnMainThread()
        updateCanOpenWebSocket()
    }

    /// Notification handler for changes to the stories-enabled setting: cycles
    /// the socket. Asserts it runs on the main thread.
    @objc
    private func storiesEnabledStateDidChange(_: NSNotification) {
        AssertIsOnMainThread()
        cycleSocket()
    }

    /// Returns the reason the web socket must not be opened, or `nil` when
    /// opening is allowed.
    ///
    /// An error from the superclass takes precedence. Otherwise the account
    /// must be registered, or `registrationOverride` must be set; if neither
    /// holds, a `NotRegisteredError` is returned.
    override func _canOpenWebSocketError() -> (any Error)? {
        if let inheritedError = super._canOpenWebSocketError() {
            return inheritedError
        }

        let isRegistered = accountManager.registrationStateWithMaybeSneakyTransaction.isRegistered
        if isRegistered || registrationOverride {
            return nil
        }
        return NotRegisteredError()
    }

    /// When `true`, `_canOpenWebSocketError()` does not require the account to
    /// be registered.
    private var registrationOverride = false

    /// Installs `chatServiceAuth` as the auth override and enables the
    /// registration override. Returns once both have been applied on
    /// `serialQueue`.
    func setRegistrationOverride(_ chatServiceAuth: ChatServiceAuth) async {
        await withCheckedContinuation { (done: CheckedContinuation<Void, Never>) in
            self.serialQueue.async {
                // The auth override goes first so that it is already in place
                // if enabling the registration override starts a connection.
                self.authOverride.set(chatServiceAuth)
                self._setRegistrationOverride(true)
                done.resume()
            }
        }
    }

    /// Stores the new override value and re-evaluates whether the web socket
    /// may be open. Asserts it runs on `serialQueue`.
    fileprivate func _setRegistrationOverride(_ value: Bool) {
        assertOnQueue(serialQueue)
        registrationOverride = value
        _updateCanOpenWebSocket()
    }

    /// Disables the registration override and then resets the auth override
    /// back to `.implicit()`.
    func clearRegistrationOverride() async {
        await withCheckedContinuation { (done: CheckedContinuation<Void, Never>) in
            self.serialQueue.async {
                self._setRegistrationOverride(false)
                done.resume()
            }
        }

        // Usually the connection stays open and this wait is a no-op. If the
        // connection is being closed (likely because of an error), wait for
        // the close to finish first...
        await waitForDisconnectIfClosed()

        // ...so that authOverride is never cleared in the middle of a
        // connection attempt.
        authOverride.set(.implicit())
    }

    /// Acquires the connection lock and then opens an authenticated chat
    /// connection through `libsignalNet`.
    ///
    /// Credentials come from the auth override when it is explicit, otherwise
    /// from the values stored by the account manager. If connecting fails with
    /// `SignalError.deviceDeregistered` while `token` is still the current
    /// connection attempt, the registration override is cleared and
    /// `onRegistrationStateChange` is invoked with `true`. Every error is
    /// rethrown to the caller.
    override fileprivate func connectChatService(token: NSObject) async throws -> AuthenticatedChatConnection {
        try await acquireConnectionLock()

        let username: String?
        let password: String?
        switch authOverride.get().credentials {
        case .explicit(let overrideUsername, let overridePassword):
            username = overrideUsername
            password = overridePassword
        case .implicit:
            (username, password) = db.read { tx in
                (accountManager.storedServerUsername(tx: tx), accountManager.storedServerAuthToken(tx: tx))
            }
        }

        do {
            // We connect even without stored credentials (using empty
            // strings), so that an unregistered user gets a consistent error.
            return try await libsignalNet.connectAuthenticatedChat(
                username: username ?? "",
                password: password ?? "",
                receiveStories: StoryManager.areStoriesEnabled,
                languages: Array(HttpHeaders.topPreferredLanguages()),
            )
        } catch {
            if case SignalError.deviceDeregistered = error {
                serialQueue.async {
                    // Ignore the result of an attempt that is no longer current.
                    guard self.connection.isCurrentlyConnecting(token) else {
                        return
                    }
                    self._setRegistrationOverride(false)
                    self.db.write { tx in
                        self.onRegistrationStateChange?(true, tx)
                    }
                }
            }
            throw error
        }
    }

    /// The connection state, with extra bookkeeping for the authenticated
    /// socket performed on every assignment. The setter asserts it runs on
    /// `serialQueue`.
    override fileprivate var connection: ConnectionState {
        get {
            super.connection
        }
        set {
            assertOnQueue(serialQueue)

            // A closed state is re-wrapped so that its task also releases the
            // connection lock after the original task (if any) has finished.
            let stateToStore: ConnectionState
            switch newValue {
            case .closed(let pendingTask):
                stateToStore = .closed(task: Task {
                    await pendingTask?.value
                    self.releaseConnectionLock()
                })
            case .connecting, .open:
                stateToStore = newValue
            }
            super.connection = stateToStore

            switch stateToStore {
            case .closed:
                // Stop the keepalives and forget that the queue was emptied.
                keepaliveSenderTask = nil
                _hasEmptiedInitialQueue = false
            case .connecting:
                break
            case .open(let service):
                // Note that we don't get callbacks until this point.
                service.start(listener: self)
                if accountManager.registrationStateWithMaybeSneakyTransaction.isDeregistered {
                    db.write { tx in
                        self.onRegistrationStateChange?(false, tx)
                    }
                }
                keepaliveSenderTask = makeKeepaliveTask(service)
            }
        }
    }

    /// The lock created in `init`, backed by the `chat-connection.lock` file.
    private let connectionLock: ConnectionLock
    /// The lock currently held by this object, or `nil` when none is held.
    private let heldConnectionLock = AtomicValue<ConnectionLock.HeldLock?>(nil, lock: .init())

    /// Acquires the connection lock and records it in `heldConnectionLock`.
    ///
    /// No lock may be held when this is called. If the lock is interrupted,
    /// the socket is cycled on `serialQueue`.
    private func acquireConnectionLock() async throws {
        owsPrecondition(self.heldConnectionLock.get() == nil)

        let acquiredLock = try await self.connectionLock.lock(onInterrupt: (self.serialQueue, {
            Logger.warn("Cycling the socket because the connection lock was interrupted")
            self._cycleSocket()
        }))

        // Nothing may have been stored while we were waiting for the lock.
        let previousLock = self.heldConnectionLock.swap(acquiredLock)
        owsPrecondition(previousLock == nil)
    }

    /// Unlocks the held connection lock, if there is one, and clears
    /// `heldConnectionLock`.
    private func releaseConnectionLock() {
        // Acquisition may have been canceled before a lock was obtained; in
        // that case there is nothing to release.
        guard let heldLock = heldConnectionLock.swap(nil) else {
            return
        }
        connectionLock.unlock(heldLock)
    }

    /// Creates a low-priority task that calls `/v1/keepalive` at regular intervals, allowing the server to run consistency checks.
    ///
    /// These requests are in addition to the websocket pings libsignal already uses to keep connections alive.
    /// The task ends when it is cancelled, when `chat` has been deallocated, or when the chat service reports itself inactive.
    func makeKeepaliveTask(_ chat: AuthenticatedChatConnection) -> Task<Void, Never> {
        let sendInterval: TimeInterval = 30
        return Task(priority: .low) { [prefix = self.logPrefix, weak chat] in
            while true {
                do {
                    // Keepalives are not sent exactly "every 30 seconds": the next one goes out
                    // *at least 30 seconds* after the *response* to the previous one arrived.
                    try await Task.sleep(nanoseconds: sendInterval.clampedNanoseconds)
                    guard let chat else {
                        // The connection object is gone, so we've disconnected.
                        return
                    }

                    // Deliberately bypasses the full overhead of makeRequest(...); keepalives
                    // don't need to count as background activity or anything like that.
                    // The 30-second request timeout doesn't have to match the send interval,
                    // and nothing calls for an especially tight timeout either.
                    let keepaliveRequest = ChatConnection.Request(method: "GET", pathAndQuery: "/v1/keepalive", timeout: 30)
                    Logger.debug("\(prefix) Sending /v1/keepalive")
                    _ = try await chat.send(keepaliveRequest)
                } catch is CancellationError {
                    // Cancelled: we're done with this service.
                    return
                } catch SignalError.chatServiceInactive(_) {
                    // The service is no longer active: we're done with it.
                    return
                } catch SignalError.rateLimitedError(retryAfter: let retryAfter, message: _) {
                    // Unlikely, but handle it carefully if it happens.
                    guard retryAfter > sendInterval else {
                        continue
                    }
                    // Wait out only the part of the delay beyond 30s; the sleep at the top of the loop covers the rest.
                    // Cancellation is ignored here because the top of the loop checks it.
                    _ = try? await Task.sleep(nanoseconds: (retryAfter - sendInterval).clampedNanoseconds)
                } catch {
                    // No action needed here either. Log in case the failure is interesting, but keep
                    // relying on libsignal to report disconnects via delegate callback. Keepalives
                    // continue until we are disconnected, in case the failure was temporary.
                    Logger.info("\(prefix) /v1/keepalive failed: \(error)")
                }
            }
        }
    }

    /// Persists a changed inactive-primary-device alert value and announces
    /// the change. Does nothing when the stored value already matches.
    ///
    /// - Parameter newInactivePrimaryDevice: The value most recently reported
    ///   for the alert.
    private func handleInactivePrimaryDeviceAlert(newInactivePrimaryDevice: Bool) {
        let previouslyStored = db.read { tx in
            inactivePrimaryDeviceStore.valueForInactivePrimaryDeviceAlert(transaction: tx)
        }
        if previouslyStored == newInactivePrimaryDevice {
            return
        }

        Logger.info("Received new value for inactive primary device alert: \(newInactivePrimaryDevice)")

        db.write { tx in
            inactivePrimaryDeviceStore.setValueForInactivePrimaryDeviceAlert(
                value: newInactivePrimaryDevice,
                transaction: tx,
            )
        }

        // Megaphones might load before the OWSChatConnection is set up, so
        // the UI has to be told that the value changed.
        NotificationCenter.default.postOnMainThread(name: .inactivePrimaryDeviceChanged, object: nil)
    }

    /// Listener callback for server alerts. On `serialQueue`, handles the
    /// idle-primary-device alert and logs how many other alerts were ignored.
    func chatConnection(_ chat: AuthenticatedChatConnection, didReceiveAlerts alerts: [String]) {
        serialQueue.async { [self] in
            guard connection.isActive(chat) else {
                // The chat service instance that reported these alerts is no longer the active one.
                return
            }

            let idleAlert = AlertType.idlePrimaryDevice.rawValue
            let receivedAlerts = Set(alerts)

            handleInactivePrimaryDeviceAlert(newInactivePrimaryDevice: receivedAlerts.contains(idleAlert))

            // Everything other than the idle-primary-device alert is unhandled.
            let unhandledAlerts = receivedAlerts.subtracting([idleAlert])
            if !unhandledAlerts.isEmpty {
                Logger.warn("ignoring \(unhandledAlerts.count) alerts from the server")
            }
        }
    }

    /// Listener callback for an incoming envelope. Hands the envelope to the
    /// shared message processor together with a closure that sends the ack.
    // NOTE(review): the closure is presumably invoked once the envelope has
    // been enqueued — confirm against `enqueueReceivedEnvelopeData`.
    func chatConnection(
        _ chat: AuthenticatedChatConnection,
        didReceiveIncomingMessage envelope: Data,
        serverDeliveryTimestamp: UInt64,
        sendAck: @escaping () throws -> Void
    ) {
        SSKEnvironment.shared.messageProcessorRef.enqueueReceivedEnvelopeData(
            envelope,
            serverDeliveryTimestamp: serverDeliveryTimestamp,
            envelopeSource: .websocketIdentified,
        ) {
            do {
                // Sending the ack does not wait for a response.
                try sendAck()
            } catch {
                Logger.warn("Failed to ack message with serverTimestamp \(serverDeliveryTimestamp): \(error)")
            }
        }
    }

    /// Listener callback for the server's "queue empty" signal. Marks the
    /// initial queue as emptied and, the first time, notifies observers.
    func chatConnectionDidReceiveQueueEmpty(_ chat: AuthenticatedChatConnection) {
        // The enqueueing queue is flushed first so that every envelope
        // enqueued earlier (see the incoming-message callback) is queued for
        // processing before the queue is marked empty and anyone is notified.
        SSKEnvironment.shared.messageProcessorRef.flushEnqueuingQueue {
            self.serialQueue.async {
                // The socket may have closed and re-opened during the flush,
                // so only act if this chat service instance is still active.
                guard self.connection.isActive(chat) else {
                    return
                }
                if self._hasEmptiedInitialQueue {
                    // Already marked as emptied; nothing new to report.
                    return
                }
                self._hasEmptiedInitialQueue = true

                // This notification wakes up anything waiting for hasEmptiedInitialQueue.
                self.notifyStatusChange(newState: self._currentState)
            }
        }
    }
}