Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/Signal/Provisioning/ProvisioningSocketManager.swift
1 views
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import LibSignalClient
public import SignalServiceKit

// MARK: - DecryptableProvisioningMessage

/// Protocol describing a message that can be sent and received through
/// the ProvisioningSocket.
///
/// A conforming type names the `ProvisioningEnvelope` type that carries it,
/// and can be constructed from the plaintext obtained by decrypting that
/// envelope's `body`.
public protocol DecryptableProvisioningMessage {

    /// Alias for the proto type that wraps the message.
    associatedtype Envelope: ProvisioningEnvelope

    /// Creates the message from decrypted payload bytes.
    ///
    /// - Parameter plaintext: The decrypted contents of the envelope body.
    /// - Throws: If a message cannot be constructed from `plaintext`.
    init(plaintext: Data) throws
}

/// Protocol describing the envelope that contains a `DecryptableProvisioningMessage`
/// and can be sent and received through the ProvisioningSocket.
public protocol ProvisioningEnvelope {

    /// Creates the envelope from its serialized form.
    ///
    /// - Parameter serializedData: The serialized envelope bytes.
    /// - Throws: If `serializedData` cannot be parsed as this envelope type.
    init(serializedData: Data) throws

    /// Encrypted payload containing the `DecryptableProvisioningMessage`
    var body: Data { get }

    /// The public key that accompanies the payload. It is handed to the
    /// cipher as `theirPublicKey` when `body` is decrypted.
    var publicKey: Data { get }
}

// MARK: - DecryptableProvisioningMessage conformance

// Linking: `LinkingProvisioningMessage` is wrapped by
// `ProvisioningProtoProvisionEnvelope`. The conformance extensions are empty,
// so the protocol requirements are satisfied by the types' own declarations.
extension ProvisioningProtoProvisionEnvelope: ProvisioningEnvelope {}
extension LinkingProvisioningMessage: DecryptableProvisioningMessage {
    public typealias Envelope = ProvisioningProtoProvisionEnvelope
}

// Registration: `RegistrationProvisioningMessage` is wrapped by
// `RegistrationProvisioningEnvelope`.
extension RegistrationProvisioningEnvelope: ProvisioningEnvelope {}
extension RegistrationProvisioningMessage: DecryptableProvisioningMessage {
    public typealias Envelope = RegistrationProvisioningEnvelope
}

// MARK: - ProvisioningSocketManager

/// Main-actor delegate through which a `ProvisioningSocketManager` reports
/// changes that the UI presenting the provisioning URL should react to.
@MainActor
public protocol ProvisioningSocketManagerUIDelegate: AnyObject {
    /// Called after the manager has opened a new provisioning socket and built
    /// the provisioning URL for it.
    ///
    /// - Parameters:
    ///   - provisioningSocketManager: The manager reporting the update.
    ///   - url: The provisioning URL for the newly opened socket.
    func provisioningSocketManager(
        _ provisioningSocketManager: ProvisioningSocketManager,
        didUpdateProvisioningURL url: URL,
    )

    /// Called when the manager has stopped rotating provisioning sockets on
    /// its own. The manager's comments describe the intended UI response as
    /// falling back to a manual "refresh" button.
    ///
    /// - Parameter provisioningSocketManager: The manager that paused rotation.
    func provisioningSocketManagerDidPauseQRRotation(
        _ provisioningSocketManager: ProvisioningSocketManager,
    )
}

/// Opens and rotates provisioning sockets, reports the provisioning URL for
/// each newly opened socket to a UI delegate, and hands the provisioning
/// message that arrives on any of those sockets to the caller awaiting it.
///
/// Typical flow, as implemented in this class:
/// 1. `start()` begins rotating sockets; the delegate receives a URL per socket.
/// 2. `waitForMessage()` suspends until an envelope arrives on any open socket,
///    then decrypts it with the cipher belonging to that socket.
public class ProvisioningSocketManager: ProvisioningSocketDelegate {
    /// The values needed to build a provisioning URL for a single socket.
    private struct ProvisioningUrlParams {
        /// The provisioning UUID reported by the socket.
        let uuid: String
        /// The cipher whose public key is embedded in the provisioning URL.
        let cipher: ProvisioningCipher
    }

    /// A still-serialized envelope paired with the cipher of the socket it
    /// arrived on, so decryption can be deferred to the awaiting caller (who
    /// knows the concrete message type).
    private struct DecryptableProvisionEnvelope {
        private let cipher: ProvisioningCipher
        private let encryptedEnvelope: Data

        init(cipher: ProvisioningCipher, data: Data) {
            self.cipher = cipher
            self.encryptedEnvelope = data
        }

        /// Parses the envelope, decrypts its body using the envelope's public
        /// key, and builds the requested message type from the plaintext.
        ///
        /// - Throws: Any error from envelope parsing, public key construction,
        ///   decryption, or message construction.
        func decrypt<ProvisioningMessage: DecryptableProvisioningMessage>() throws -> ProvisioningMessage {
            let envelope = try ProvisioningMessage.Envelope(serializedData: encryptedEnvelope)
            let data = try cipher.decrypt(data: envelope.body, theirPublicKey: try PublicKey(envelope.publicKey))
            return try ProvisioningMessage(plaintext: data)
        }
    }

    /// Represents an attempt to communicate with the primary.
    private struct ProvisioningUrlCommunicationAttempt {
        /// The socket from which we hope to receive a provisioning envelope
        /// from a primary.
        let socket: ProvisioningSocket
        /// The cipher to be used in encrypting the provisioning envelope.
        let cipher: ProvisioningCipher
        /// A continuation waiting for us to fetch the parameters necessary for
        /// us to construct a provisioning URL, which we will present to the
        /// primary via QR code. The provisioning URL will contain the necessary
        /// data for the primary to send us a provisioning envelope over our
        /// provisioning socket, via the server.
        ///
        /// Set to `nil` once it has been resumed, so that it is resumed at
        /// most once.
        var fetchProvisioningUrlParamsContinuation: CheckedContinuation<ProvisioningUrlParams, Error>?
    }

    /// Every attempt opened so far. Attempts are appended by
    /// `openNewProvisioningSocket()` and are not removed within this class.
    private var urlCommunicationAttempts: AtomicValue<[ProvisioningUrlCommunicationAttempt]> = AtomicValue([], lock: .init())
    /// The continuation of the (single) in-flight `waitForMessage()` call, if any.
    private var awaitProvisionEnvelopeContinuation: AtomicValue<CheckedContinuation<DecryptableProvisionEnvelope, Error>?> = AtomicValue(nil, lock: .init())

    /// Receives provisioning URL updates and the "rotation paused" signal.
    ///
    /// NOTE(review): this is a strong reference. If the delegate also owns
    /// this manager, that is a retain cycle — confirm ownership at call sites
    /// before changing it to `weak`.
    public var delegate: ProvisioningSocketManagerUIDelegate?

    /// The link type embedded in every provisioning URL this manager builds.
    private let linkType: DeviceProvisioningURL.LinkType

    /// - Parameter linkType: The link type to embed in built provisioning URLs.
    public init(linkType: DeviceProvisioningURL.LinkType) {
        self.linkType = linkType
    }

    /// Starts rotating provisioning sockets.
    ///
    /// The delegate is called back with a provisioning URL each time a new
    /// socket is opened, and is told when rotation has paused. Any rotation
    /// already in progress is cancelled first.
    public func start() {
        rotate()
    }

    /// Stops the current rotation, if any, and starts a fresh one.
    public func reset() {
        stop()
        start()
    }

    /// Cancels the rotation task. Sockets that are already open are not
    /// disconnected here.
    public func stop() {
        rotationTask?.cancel()
        rotationTask = nil
    }

    // MARK: ProvisioningSocketDelegate

    /// Resumes the pending "fetch URL params" continuation for the socket that
    /// just reported its provisioning UUID.
    ///
    /// - Parameters:
    ///   - provisioningSocket: The socket that received the UUID.
    ///   - provisioningUuid: The UUID to embed in the provisioning URL.
    public func provisioningSocket(
        _ provisioningSocket: ProvisioningSocket,
        didReceiveProvisioningUuid provisioningUuid: String,
    ) {
        urlCommunicationAttempts.update { attempts in
            let matchingAttemptIndex = attempts.firstIndex {
                $0.socket.id == provisioningSocket.id
            }

            guard
                let matchingAttemptIndex,
                let fetchParamsContinuation = attempts[matchingAttemptIndex].fetchProvisioningUrlParamsContinuation
            else {
                owsFailDebug("Got provisioning UUID for unknown socket!")
                return
            }

            // Clear before resuming so the continuation cannot be resumed twice.
            attempts[matchingAttemptIndex].fetchProvisioningUrlParamsContinuation = nil

            fetchParamsContinuation.resume(
                returning: ProvisioningUrlParams(
                    uuid: provisioningUuid,
                    cipher: attempts[matchingAttemptIndex].cipher,
                ),
            )
        }
    }

    /// Handles an envelope arriving on one of our sockets: disconnects every
    /// socket, stops rotation, and resumes the caller suspended in
    /// `waitForMessage()` with the envelope and the matching socket's cipher.
    ///
    /// - Parameters:
    ///   - provisioningSocket: The socket the envelope arrived on.
    ///   - data: The serialized (still encrypted) envelope.
    public func provisioningSocket(
        _ provisioningSocket: ProvisioningSocket,
        didReceiveEnvelopeData data: Data,
    ) {
        var cipherForSocket: ProvisioningCipher?

        /// We've gotten a provisioning message, from one of our attempts'
        /// sockets. (We don't care which one – it's whichever one the primary
        /// scanned and sent an envelope through!)
        for attempt in urlCommunicationAttempts.get() {
            /// After we get a provisioning message, we don't expect anything
            /// from this or any other socket.
            attempt.socket.disconnect(code: .normalClosure)

            if provisioningSocket.id == attempt.socket.id {
                owsAssertDebug(
                    cipherForSocket == nil,
                    "Extracting cipher, but unexpectedly already set from previous match!",
                )

                cipherForSocket = attempt.cipher
            }
        }

        guard let cipherForSocket else {
            owsFailDebug("Missing cipher for socket that received envelope!")
            return
        }

        awaitProvisionEnvelopeContinuation.update { existingContinuation in
            guard let continuation = existingContinuation else {
                owsFailDebug("Got provision envelope, but missing continuation or cipher!")
                return
            }

            // We have what we were waiting for; no further rotation is needed.
            stop()
            let envelope = DecryptableProvisionEnvelope(cipher: cipherForSocket, data: data)
            continuation.resume(returning: envelope)

            // Clear so a later envelope cannot resume the same continuation.
            existingContinuation = nil
        }
    }

    /// Logs a socket error and, if that socket's provisioning URL parameters
    /// are still being awaited, fails that wait with the error.
    ///
    /// - Parameters:
    ///   - provisioningSocket: The socket that reported the error.
    ///   - error: The reported error. Close errors are logged at info level.
    public func provisioningSocket(_ provisioningSocket: ProvisioningSocket, didError error: Error) {
        if
            let webSocketError = error as? WebSocketError,
            case .closeError = webSocketError
        {
            Logger.info("Provisioning socket closed...")
        } else {
            Logger.error("\(error)")
        }

        urlCommunicationAttempts.update { attempts in
            let matchingAttemptIndex = attempts.firstIndex {
                $0.socket.id == provisioningSocket.id
            }

            guard let matchingAttemptIndex else {
                owsFailDebug("Got error for unknown socket!")
                return
            }

            // A nil continuation means the UUID was already delivered (or an
            // earlier error was already reported); nothing is left to resume.
            attempts[matchingAttemptIndex].fetchProvisioningUrlParamsContinuation?.resume(throwing: error)
            attempts[matchingAttemptIndex].fetchProvisioningUrlParamsContinuation = nil
        }
    }

    // MARK: -

    /// Builds the provisioning URL for a socket.
    ///
    /// The `linknsync` capability is advertised only when the current
    /// registration state is `.unregistered`.
    ///
    /// - Parameters:
    ///   - type: The link type to embed in the URL.
    ///   - params: The socket's provisioning UUID and cipher.
    /// - Returns: The provisioning URL.
    /// - Throws: Any error from constructing `DeviceProvisioningURL` or from
    ///   `buildUrl()`.
    private static func buildProvisioningUrl(
        type: DeviceProvisioningURL.LinkType,
        params: ProvisioningUrlParams,
    ) throws -> URL {

        let shouldLinkAndSync: Bool = {
            switch DependenciesBridge.shared.tsAccountManager.registrationStateWithMaybeSneakyTransaction {
            case .unregistered:
                return true
            case .delinked, .relinking:
                // We don't allow relinking secondaries to link'n'sync.
                return false
            case .transferred:
                // Transferring back to a transferred device will result in hitting this.
                return false
            case
                .registered,
                .provisioned,
                .reregistering,
                .transferringIncoming,
                .transferringLinkedOutgoing,
                .transferringPrimaryOutgoing,
                .deregistered:
                owsFailDebug("How are we provisioning from this state?")
                return false
            }
        }()

        var capabilities = [DeviceProvisioningURL.Capability]()
        if shouldLinkAndSync {
            capabilities.append(DeviceProvisioningURL.Capability.linknsync)
        }

        return try DeviceProvisioningURL(
            type: type,
            ephemeralDeviceId: params.uuid,
            publicKey: params.cipher.ourPublicKey,
            capabilities: capabilities,
        ).buildUrl()
    }

    /// Opens a new provisioning socket. Note that the server closes
    /// provisioning sockets after 90s, so callers must ensure that they do not
    /// need the socket longer than that.
    ///
    /// Suspends until the socket reports its provisioning UUID or an error.
    ///
    /// - Returns: A provisioning URL containing information about the
    ///   now-opened provisioning socket.
    /// - Throws: The error reported by the socket before it delivered a UUID,
    ///   or any error from building the URL.
    func openNewProvisioningSocket() async throws -> URL {
        let provisioningUrlParams: ProvisioningUrlParams = try await withCheckedThrowingContinuation { paramsContinuation in
            let newAttempt = ProvisioningUrlCommunicationAttempt(
                socket: ProvisioningSocket(),
                cipher: ProvisioningCipher(),
                fetchProvisioningUrlParamsContinuation: paramsContinuation,
            )

            // Register the attempt before connecting so delegate callbacks
            // from the socket can find it.
            urlCommunicationAttempts.update { $0.append(newAttempt) }

            newAttempt.socket.delegate = self
            newAttempt.socket.connect()
        }

        return try Self.buildProvisioningUrl(type: linkType, params: provisioningUrlParams)
    }

    /// Suspends until a provisioning envelope arrives on any open socket, then
    /// decrypts it into the requested message type.
    ///
    /// Only one caller may wait at a time.
    ///
    /// - Returns: The decrypted provisioning message.
    /// - Throws: `OWSAssertionError` if another call is already waiting, or
    ///   any error from decrypting the received envelope.
    public func waitForMessage<ProvisioningMessage: DecryptableProvisioningMessage>() async throws -> ProvisioningMessage {
        let decryptableProvisionEnvelope: DecryptableProvisionEnvelope = try await withCheckedThrowingContinuation { newContinuation in
            awaitProvisionEnvelopeContinuation.update { existingContinuation in
                guard existingContinuation == nil else {
                    newContinuation.resume(throwing: OWSAssertionError("Attempted to await provisioning multiple times!"))
                    return
                }
                existingContinuation = newContinuation
            }
        }
        return try decryptableProvisionEnvelope.decrypt()
    }

    /// The in-flight rotation started by `rotate()`, if any.
    private var rotationTask: Task<Void, Never>?

    /// Cancels any in-flight rotation and starts a new one.
    private func rotate() {
        rotationTask?.cancel()
        rotationTask = Task {
            /// Every 45s, five times, rotate the provisioning socket for which
            /// we're displaying a QR code. If we fail, or once we've exhausted
            /// the five rotations, fall back to showing a manual "refresh"
            /// button.
            ///
            /// Note that the server will close provisioning sockets after 90s,
            /// so hopefully rotating every 45s means no primary will ever end
            /// up trying to send into a closed socket.
            do {
                for _ in 0..<5 {
                    let provisioningUrl = try await self.openNewProvisioningSocket()

                    try Task.checkCancellation()

                    await delegate?.provisioningSocketManager(self, didUpdateProvisioningURL: provisioningUrl)

                    try await Task.sleep(nanoseconds: 45 * NSEC_PER_SEC)

                    try Task.checkCancellation()
                }
            } catch is CancellationError {
                // We've been canceled; bail! It's the canceler's responsibility
                // to make sure the UI is updated.
                return
            } catch {
                Logger.error("Failed to rotate provisioning socket: \(error)")

                // A failure that races with cancellation is still the
                // canceler's responsibility.
                guard !Task.isCancelled else { return }

                // Otherwise, fall through as if we'd exhausted our rotations.
            }

            // Reached both after the final rotation and after a failure, so
            // the UI always gets the chance to offer a manual refresh.
            await delegate?.provisioningSocketManagerDidPauseQRRotation(self)
        }
    }
}