Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Account/PreKeys/PreKeyTaskManager.swift
1 views
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import LibSignalClient

/// Used by ``PreKeyManagerImpl`` to actually execute prekey tasks.
/// Stateless! All state exists within each task function. The only instance vars are dependencies.
///
/// A PreKey task is broken down into the following steps:
/// 1. Fetch the identity key.  If this is a create operation, create the key, otherwise error if missing
/// 2. If registered and not a create operation, check that message processing is idle before continuing
/// 3. Check the server for the number of remaining PreKeys (skip on create/force refresh)
/// 4. Run any logic to determine what requested operations are really necessary
/// 5. Generate the necessary keys for the resulting operations
/// 6. Upload these new keys to the server (except for registration/provisioning)
/// 7. Store the new keys and run any cleanup logic
struct PreKeyTaskManager {
    private let logger = PrefixedLogger(prefix: "[PreKey]")

    private let apiClient: PreKeyTaskAPIClient
    private let dateProvider: DateProvider
    private let db: any DB
    private let identityKeyMismatchManager: IdentityKeyMismatchManager
    private let identityManager: OWSIdentityManager
    private let messageProcessor: MessageProcessor
    private let protocolStoreManager: SignalProtocolStoreManager
    private let remoteConfigProvider: any RemoteConfigProvider
    private let tsAccountManager: TSAccountManager

    init(
        apiClient: PreKeyTaskAPIClient,
        dateProvider: @escaping DateProvider,
        db: any DB,
        identityKeyMismatchManager: IdentityKeyMismatchManager,
        identityManager: OWSIdentityManager,
        messageProcessor: MessageProcessor,
        protocolStoreManager: SignalProtocolStoreManager,
        remoteConfigProvider: any RemoteConfigProvider,
        tsAccountManager: TSAccountManager,
    ) {
        self.apiClient = apiClient
        self.dateProvider = dateProvider
        self.db = db
        self.identityKeyMismatchManager = identityKeyMismatchManager
        self.identityManager = identityManager
        self.messageProcessor = messageProcessor
        self.protocolStoreManager = protocolStoreManager
        self.remoteConfigProvider = remoteConfigProvider
        self.tsAccountManager = tsAccountManager
    }

    enum Constants {
        // We generate 100 one-time prekeys at a time.
        // Replenish whenever 10 or less remain
        static let EphemeralPreKeysMinimumCount: UInt = 10

        static let PqPreKeysMinimumCount: UInt = 10

        // Signed prekeys should be rotated every at least every 2 days
        static let SignedPreKeyRotationTime: TimeInterval = 2 * .day

        static let LastResortPqPreKeyRotationTime: TimeInterval = 2 * .day
    }

    enum Error: Swift.Error {
        case noIdentityKey
        case notRegistered
        case cancelled
    }

    // MARK: - API

    // MARK: Registration/Provisioning

    /// When we register, we create a new identity key and other keys. So this variant:
    /// CAN create a new identity key (or uses any existing one)
    /// ALWAYS changes the targeted keys (regardless of current key state)
    func createForRegistration() async -> RegistrationPreKeyUploadBundles {
        logger.info("Create for registration")
        let (aciBundle, pniBundle) = await db.awaitableWrite { tx in
            let aciBundle = self.generateKeysForRegistration(identity: .aci, tx: tx)
            let pniBundle = self.generateKeysForRegistration(identity: .pni, tx: tx)
            self.persistKeysPriorToUpload(bundle: aciBundle, tx: tx)
            self.persistKeysPriorToUpload(bundle: pniBundle, tx: tx)
            return (aciBundle, pniBundle)
        }
        return .init(aci: aciBundle, pni: pniBundle)
    }

    /// When we provision, we use the primary's identity key to create other keys. So this variant:
    /// NEVER creates an identity key
    /// ALWAYS changes the targeted keys (regardless of current key state)
    func createForProvisioning(
        aciIdentityKeyPair: ECKeyPair,
        pniIdentityKeyPair: ECKeyPair,
    ) async -> RegistrationPreKeyUploadBundles {
        logger.info("Create for provisioning")
        let (aciBundle, pniBundle) = await db.awaitableWrite { tx in
            let aciBundle = self.generateKeysForProvisioning(
                identity: .aci,
                identityKeyPair: aciIdentityKeyPair,
                tx: tx,
            )
            let pniBundle = self.generateKeysForProvisioning(
                identity: .pni,
                identityKeyPair: pniIdentityKeyPair,
                tx: tx,
            )
            self.persistKeysPriorToUpload(bundle: aciBundle, tx: tx)
            self.persistKeysPriorToUpload(bundle: pniBundle, tx: tx)
            return (aciBundle, pniBundle)
        }
        return .init(aci: aciBundle, pni: pniBundle)
    }

    func persistAfterRegistration(
        bundles: RegistrationPreKeyUploadBundles,
        uploadDidSucceed: Bool,
    ) async {
        logger.info("Persist after provisioning")
        await db.awaitableWrite { tx in
            if uploadDidSucceed {
                self.persistStateAfterUpload(bundle: bundles.aci, tx: tx)
                self.persistStateAfterUpload(bundle: bundles.pni, tx: tx)
            } else {
                // Wipe the keys.
                self.wipeKeysAfterFailedRegistration(bundle: bundles.aci, tx: tx)
                self.wipeKeysAfterFailedRegistration(bundle: bundles.pni, tx: tx)
            }
        }
    }

    // MARK: Standard Operations

    /// When we create refresh keys (happens periodically) we should never change
    /// our identity key, but may rotate other keys depending on expiration. So this variant:
    /// CANNOT create a new identity key
    /// SOMETIMES changes the targeted keys (dependent on current key state)
    /// In other words, this variant can potentially no-op.
    func refresh(
        identity: OWSIdentity,
        targets: PreKeyTargets,
        force: Bool = false,
        auth: ChatServiceAuth,
    ) async throws {
        try Task.checkCancellation()
        try await waitForMessageProcessing(identity: identity)
        try Task.checkCancellation()

        let filteredTargets: PreKeyTargets
        if force {
            filteredTargets = targets
        } else {
            let ecCount: Int?
            let pqCount: Int?
            if targets.contains(target: .oneTimePreKey) || targets.contains(target: .oneTimePqPreKey) {
                (ecCount, pqCount) = try await self.apiClient.getAvailablePreKeys(for: identity)
            } else {
                // No need to fetch prekey counts.
                (ecCount, pqCount) = (nil, nil)
            }
            try Task.checkCancellation()

            filteredTargets = self.filterToNecessaryTargets(
                identity: identity,
                unfilteredTargets: targets,
                ecPreKeyRecordCount: ecCount,
                pqPreKeyRecordCount: pqCount,
            )
        }

        if filteredTargets.isEmpty {
            return
        }

        logger.info("[\(identity)] Refresh: [\(filteredTargets)]")
        let bundle = try await db.awaitableWrite { tx in
            let identityKeyPair = try self.requireIdentityKeyPair(for: identity, tx: tx)
            return self.createAndPersistPartialBundle(
                identity: identity,
                identityKeyPair: identityKeyPair,
                targets: filteredTargets,
                tx: tx,
            )
        }

        try Task.checkCancellation()
        try await uploadAndPersistBundle(bundle, auth: auth)
    }

    func createOneTimePreKeys(
        identity: OWSIdentity,
        auth: ChatServiceAuth,
    ) async throws {
        logger.info("[\(identity)] Create one-time prekeys")
        try Task.checkCancellation()
        let bundle = try await db.awaitableWrite { tx in
            let identityKeyPair = try self.requireIdentityKeyPair(for: identity, tx: tx)
            return self.createAndPersistPartialBundle(
                identity: identity,
                identityKeyPair: identityKeyPair,
                targets: [.oneTimePreKey, .oneTimePqPreKey],
                tx: tx,
            )
        }

        try Task.checkCancellation()
        try await uploadAndPersistBundle(bundle, auth: auth)
    }

    // MARK: - Private helpers

    // MARK: Per-identity registration generators

    private func generateKeysForRegistration(
        identity: OWSIdentity,
        tx: DBWriteTransaction,
    ) -> RegistrationPreKeyUploadBundle {
        return generateKeysForProvisioning(
            identity: identity,
            identityKeyPair: getOrCreateIdentityKeyPair(identity: identity, tx: tx),
            tx: tx,
        )
    }

    private func generateKeysForProvisioning(
        identity: OWSIdentity,
        identityKeyPair: ECKeyPair,
        tx: DBWriteTransaction,
    ) -> RegistrationPreKeyUploadBundle {
        let identityKey = identityKeyPair.keyPair.privateKey
        let protocolStore = self.protocolStoreManager.signalProtocolStore(for: identity)

        let signedPreKeyStore = protocolStore.signedPreKeyStore
        let signedPreKey = SignedPreKeyStoreImpl.generateSignedPreKey(
            keyId: signedPreKeyStore.allocatePreKeyId(tx: tx),
            signedBy: identityKey,
        )

        let kyberPreKeyStore = protocolStore.kyberPreKeyStore
        let lastResortPreKey = kyberPreKeyStore.generatePreKeyRecords(
            forPreKeyIds: kyberPreKeyStore.allocatePreKeyIds(count: 1, tx: tx),
            signedBy: identityKey,
        ).first!

        return RegistrationPreKeyUploadBundle(
            identity: identity,
            identityKeyPair: identityKeyPair,
            signedPreKey: signedPreKey,
            lastResortPreKey: lastResortPreKey,
        )
    }

    // MARK: Identity Key

    private func getOrCreateIdentityKeyPair(
        identity: OWSIdentity,
        tx: DBWriteTransaction,
    ) -> ECKeyPair {
        let existingKeyPair = identityManager.identityKeyPair(for: identity, tx: tx)
        if let identityKeyPair = existingKeyPair {
            return identityKeyPair
        }
        let identityKeyPair = identityManager.generateNewIdentityKeyPair()
        identityManager.setIdentityKeyPair(
            identityKeyPair,
            for: identity,
            tx: tx,
        )
        return identityKeyPair
    }

    func requireIdentityKeyPair(
        for identity: OWSIdentity,
        tx: DBReadTransaction,
    ) throws -> ECKeyPair {
        guard let identityKey = identityManager.identityKeyPair(for: identity, tx: tx) else {
            logger.warn("cannot perform operation for \(identity); missing identity key")
            throw Error.noIdentityKey
        }
        return identityKey
    }

    // MARK: Bundle construction

    private func createAndPersistPartialBundle(
        identity: OWSIdentity,
        identityKeyPair: ECKeyPair,
        targets: PreKeyTargets,
        tx: DBWriteTransaction,
    ) -> PartialPreKeyUploadBundle {
        let protocolStore = self.protocolStoreManager.signalProtocolStore(
            for: identity,
        )

        // Map the keys to the requested operation.  Create the necessary keys and
        // pass them along to be uploaded to the service/stored/accepted
        var signedPreKey: LibSignalClient.SignedPreKeyRecord?
        var preKeyRecords: [LibSignalClient.PreKeyRecord]?
        var lastResortPreKey: LibSignalClient.KyberPreKeyRecord?
        var pqPreKeyRecords: [LibSignalClient.KyberPreKeyRecord]?

        let identityKey = identityKeyPair.keyPair.privateKey

        targets.targets.forEach { target in
            switch target {
            case .oneTimePreKey:
                let preKeyIds = protocolStore.preKeyStore.allocatePreKeyIds(tx: tx)
                preKeyRecords = PreKeyStoreImpl.generatePreKeyRecords(forPreKeyIds: preKeyIds)
            case .signedPreKey:
                let preKeyId = protocolStore.signedPreKeyStore.allocatePreKeyId(tx: tx)
                signedPreKey = SignedPreKeyStoreImpl.generateSignedPreKey(keyId: preKeyId, signedBy: identityKey)
            case .oneTimePqPreKey:
                let preKeyIds = protocolStore.kyberPreKeyStore.allocatePreKeyIds(count: 100, tx: tx)
                pqPreKeyRecords = protocolStore.kyberPreKeyStore.generatePreKeyRecords(forPreKeyIds: preKeyIds, signedBy: identityKey)
            case .lastResortPqPreKey:
                let preKeyIds = protocolStore.kyberPreKeyStore.allocatePreKeyIds(count: 1, tx: tx)
                lastResortPreKey = protocolStore.kyberPreKeyStore.generatePreKeyRecords(forPreKeyIds: preKeyIds, signedBy: identityKey).first!
            }
        }
        let result = PartialPreKeyUploadBundle(
            identity: identity,
            signedPreKey: signedPreKey,
            preKeyRecords: preKeyRecords,
            lastResortPreKey: lastResortPreKey,
            pqPreKeyRecords: pqPreKeyRecords,
        )
        persistKeysPriorToUpload(bundle: result, tx: tx)
        return result
    }

    // MARK: Filtering (based on fetched prekey results)

    private func filterToNecessaryTargets(
        identity: OWSIdentity,
        unfilteredTargets: PreKeyTargets,
        ecPreKeyRecordCount: Int?,
        pqPreKeyRecordCount: Int?,
    ) -> PreKeyTargets {
        let protocolStore = self.protocolStoreManager.signalProtocolStore(for: identity)
        let (lastSuccessfulRotation, lastKyberSuccessfulRotation) = db.read { tx in
            let lastSuccessfulRotation = protocolStore.signedPreKeyStore.getLastSuccessfulRotationDate(tx: tx)
            let lastKyberSuccessfulRotation = protocolStore.kyberPreKeyStore.getLastSuccessfulRotationDate(tx: tx)
            return (lastSuccessfulRotation, lastKyberSuccessfulRotation)
        }

        // Take the gathered PreKeyState information and run it through
        // logic to determine what really needs to be updated.
        return unfilteredTargets.targets.reduce(into: []) { value, target in
            switch target {
            case .oneTimePreKey:
                guard let ecPreKeyRecordCount else {
                    logger.warn("Did not fetch prekey count, aborting.")
                    return
                }
                if ecPreKeyRecordCount < Constants.EphemeralPreKeysMinimumCount {
                    value.insert(target: target)
                }
            case .oneTimePqPreKey:
                guard let pqPreKeyRecordCount else {
                    logger.warn("Did not fetch pq prekey count, aborting.")
                    return
                }
                if pqPreKeyRecordCount < Constants.PqPreKeysMinimumCount {
                    value.insert(target: target)
                }
            case .signedPreKey:
                if
                    let lastSuccessfulRotation,
                    dateProvider().timeIntervalSince(lastSuccessfulRotation) < Constants.SignedPreKeyRotationTime
                {
                    // it's recent enough
                } else {
                    value.insert(target: target)
                }
            case .lastResortPqPreKey:
                if
                    let lastKyberSuccessfulRotation,
                    dateProvider().timeIntervalSince(lastKyberSuccessfulRotation) < Constants.LastResortPqPreKeyRotationTime
                {
                    // it's recent enough
                } else {
                    value.insert(target: target)
                }
            }
        }
    }

    // MARK: Message Processing

    /// Waits (potentially forever) for message processing. Supports cancellation.
    private func waitForMessageProcessing(identity: OWSIdentity) async throws(CancellationError) {
        switch identity {
        case .aci:
            // We can't change our ACI via a message, so there's no need to wait.
            return
        case .pni:
            // Our PNI might change via a change number message, so wait.
            break
        }

        try await messageProcessor.waitForFetchingAndProcessing()
    }

    // MARK: Persist

    private func persistKeysPriorToUpload(
        bundle: PreKeyUploadBundle,
        tx: DBWriteTransaction,
    ) {
        let protocolStore = protocolStoreManager.signalProtocolStore(for: bundle.identity)
        if let signedPreKeyRecord = bundle.getSignedPreKey() {
            protocolStore.signedPreKeyStore.storeSignedPreKey(signedPreKeyRecord, tx: tx)
        }
        if let lastResortPreKey = bundle.getLastResortPreKey() {
            protocolStore.kyberPreKeyStore.storePreKeyRecords([lastResortPreKey], isLastResort: true, tx: tx)
        }
        if let newPreKeyRecords = bundle.getPreKeyRecords() {
            protocolStore.preKeyStore.storePreKeyRecords(newPreKeyRecords, tx: tx)
        }
        if let pqPreKeyRecords = bundle.getPqPreKeyRecords() {
            protocolStore.kyberPreKeyStore.storePreKeyRecords(pqPreKeyRecords, isLastResort: false, tx: tx)
        }
    }

    private func persistStateAfterUpload(
        bundle: PreKeyUploadBundle,
        tx: DBWriteTransaction,
    ) {
        let protocolStore = protocolStoreManager.signalProtocolStore(for: bundle.identity)

        if let signedPreKeyRecord = bundle.getSignedPreKey() {
            protocolStore.signedPreKeyStore.setLastSuccessfulRotationDate(self.dateProvider(), tx: tx)
            protocolStore.signedPreKeyStore.setReplacedAtToNowIfNil(exceptFor: signedPreKeyRecord.id, tx: tx)
        }

        if let lastResortPreKey = bundle.getLastResortPreKey() {
            // Register a successful key rotation
            protocolStore.kyberPreKeyStore.setLastSuccessfulRotationDate(self.dateProvider(), tx: tx)
            protocolStore.kyberPreKeyStore.setReplacedAtToNowIfNil(exceptFor: [lastResortPreKey.id], isLastResort: true, tx: tx)
        }

        if let preKeyRecords = bundle.getPreKeyRecords() {
            protocolStore.preKeyStore.setReplacedAtToNowIfNil(exceptFor: preKeyRecords.map(\.id), tx: tx)
        }

        if let oneTimePreKeys = bundle.getPqPreKeyRecords() {
            protocolStore.kyberPreKeyStore.setReplacedAtToNowIfNil(exceptFor: oneTimePreKeys.map(\.id), isLastResort: false, tx: tx)
        }

        protocolStoreManager.preKeyStore.cullPreKeys(gracePeriod: gracePeriodBeforeMessageProcessing(), tx: tx)
    }

    /// The "grace period" to use when culling pre keys before we've finished
    /// processing messages. After we rotate pre keys, there might still be
    /// not-yet-received messages that we're about to receive that reference
    /// obsolete pre keys. We defer culling pre keys in this "grace period"
    /// until `cullStateAfterMessageProcessing` (which is typically called in
    /// quick succession but may take longer in pathological cases).
    private func gracePeriodBeforeMessageProcessing() -> TimeInterval {
        let messageQueueTime = remoteConfigProvider.currentConfig().messageQueueTime
        owsAssertDebug(.day <= messageQueueTime && messageQueueTime <= 90 * .day)
        return messageQueueTime.clamp(.day, 90 * .day)
    }

    /// Called after we've finished processing messages to cull any pre keys in
    /// the "grace period".
    private func cullStateAfterMessageProcessing(tx: DBWriteTransaction) {
        protocolStoreManager.preKeyStore.cullPreKeys(gracePeriod: 0, tx: tx)
    }

    private func wipeKeysAfterFailedRegistration(
        bundle: RegistrationPreKeyUploadBundle,
        tx: DBWriteTransaction,
    ) {
        let preKeyStore = protocolStoreManager.preKeyStore.forIdentity(bundle.identity)
        preKeyStore.removePreKey(in: .signed, keyId: bundle.signedPreKey.id, tx: tx)
        preKeyStore.removePreKey(in: .kyber, keyId: bundle.lastResortPreKey.id, tx: tx)
    }

    // MARK: Upload

    private func uploadAndPersistBundle(
        _ bundle: PreKeyUploadBundle,
        auth: ChatServiceAuth,
    ) async throws {
        let identity = bundle.identity
        let uploadResult = await upload(bundle: bundle, auth: auth)

        switch uploadResult {
        case .skipped:
            break
        case .success:
            logger.info("[\(identity)] Successfully uploaded prekeys")
            await db.awaitableWrite { tx in
                self.persistStateAfterUpload(bundle: bundle, tx: tx)
            }
            Task {
                try await self.messageProcessor.waitForFetchingAndProcessing()
                await self.db.awaitableWrite { tx in self.cullStateAfterMessageProcessing(tx: tx) }
            }
        case let .failure(error) where error.httpStatusCode == 422:
            // We think we might have an incorrect identity key -- check it and
            // deregister if it's wrong. We always eat this error because we want the
            // caller to see the `error` from `uploadResult`.
            try? await self.identityKeyMismatchManager.validateIdentityKey(for: identity)
            fallthrough
        case let .failure(error):
            logger.info("[\(identity)] Failed to upload prekeys")
            throw error
        }
    }

    private enum UploadResult {
        case success
        case skipped
        case failure(Swift.Error)
    }

    private func upload(
        bundle: PreKeyUploadBundle,
        auth: ChatServiceAuth,
    ) async -> UploadResult {
        // If there is nothing to update, skip this step.
        guard !bundle.isEmpty() else { return .skipped }

        logger.info("[\(bundle.identity)] uploading prekeys")

        do {
            try await self.apiClient.registerPreKeys(
                for: bundle.identity,
                signedPreKeyRecord: bundle.getSignedPreKey(),
                preKeyRecords: bundle.getPreKeyRecords(),
                pqLastResortPreKeyRecord: bundle.getLastResortPreKey(),
                pqPreKeyRecords: bundle.getPqPreKeyRecords(),
                auth: auth,
            )
            return .success
        } catch let error {
            return .failure(error)
        }
    }
}