Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/Signal/Registration/RegistrationCoordinatorLoader.swift
1 views
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

public import LibSignalClient
public import SignalServiceKit

/// Decides which `RegistrationCoordinator` to build, and builds it.
///
/// This type is kept apart from the coordinator so that the state deciding whether a
/// coordinator is _needed_ never lives inside the coordinator itself. Creating a
/// coordinator therefore always means it's going to be used, and its whole lifecycle
/// can rely on that assumption and stay simpler.
public protocol RegistrationCoordinatorLoader {

    /// Returns the mode of a registration that was left in progress, if there is one.
    ///
    /// A non-nil result means the user has an unfinished registration that can be
    /// restored (and typically must be).
    func restoreLastMode(transaction: DBReadTransaction) -> RegistrationMode?

    /// Whether an in-progress change number means message processing should be paused.
    func hasPendingChangeNumber(transaction: DBReadTransaction) -> Bool

    /// Builds a coordinator, aiming for `forDesiredMode`.
    ///
    /// The mode the coordinator ends up in can differ from the desired one. For
    /// example, trying to change number in the middle of a re-registration is
    /// disallowed, and the coordinator falls back to re-registration instead.
    func coordinator(
        forDesiredMode: RegistrationMode,
        transaction: DBWriteTransaction,
        logger: PrefixedLogger,
    ) -> RegistrationCoordinator
}

/// Default `RegistrationCoordinatorLoader`.
///
/// It persists the active registration `Mode` in a `KeyValueStore`, so that an
/// in-progress registration (in particular a change number) can be restored later
/// via `restoreLastMode(transaction:)`.
public class RegistrationCoordinatorLoaderImpl: RegistrationCoordinatorLoader {

    /// The persisted, `Codable` counterpart of the public `RegistrationMode`.
    /// `changingNumber` additionally carries any pending PNI state.
    public enum Mode: Codable {
        case registering(RegisteringState)
        case reRegistering(ReRegisteringState)
        case changingNumber(ChangeNumberState)

        /// Initial registration carries no extra state.
        public struct RegisteringState: Codable, Equatable {
            fileprivate init() {}
        }

        /// The phone number and ACI being re-registered.
        public struct ReRegisteringState: Codable, Equatable {
            public let e164: E164
            @AciUuid public var aci: Aci

            fileprivate init(e164: E164, aci: Aci) {
                self.e164 = e164
                self._aci = aci.codableUuid
            }
        }

        /// The account's old number and credentials, plus `pniState`. `pniState` is
        /// nil until `CoordinatorDelegate.savePendingChangeNumber` stores one.
        public struct ChangeNumberState: Codable {
            public let oldE164: E164
            public let oldAuthToken: String
            @AciUuid public var localAci: Aci
            public let localDeviceId: DeviceId

            /// PNI identity material for the new number. The prekey records are
            /// `Result`s so that a record missing from the persisted data can be
            /// reported to `finalizePniIdentity` without failing the whole decode.
            public struct PendingPniState {
                public let newE164: E164
                public let pniIdentityKeyPair: ECKeyPair
                public let localDevicePniSignedPreKeyRecord: Result<LibSignalClient.SignedPreKeyRecord, DecodingError>
                public let localDevicePniPqLastResortPreKeyRecord: Result<LibSignalClient.KyberPreKeyRecord, DecodingError>
                public let localDevicePniRegistrationId: UInt32
            }

            public fileprivate(set) var pniState: PendingPniState?

            fileprivate init(
                oldE164: E164,
                oldAuthToken: String,
                localAci: Aci,
                localDeviceId: DeviceId,
                pniState: PendingPniState?,
            ) {
                self.oldE164 = oldE164
                self.oldAuthToken = oldAuthToken
                self._localAci = localAci.codableUuid
                self.localDeviceId = localDeviceId
                self.pniState = pniState
            }
        }

        /// True only when changing number *and* pending PNI state has been saved.
        var hasPendingChangeNumber: Bool {
            switch self {
            case .registering, .reRegistering:
                return false
            case .changingNumber(let state):
                return state.pniState != nil
            }
        }
    }

    // A plain `let` rather than a `lazy var`: Swift does not synchronize lazy
    // stored-property initialization. If two threads touched a lazy property for
    // the first time simultaneously, they would race.
    private let kvStore = KeyValueStore(collection: Constants.collectionName)

    private let deps: RegistrationCoordinatorDependencies

    public init(dependencies: RegistrationCoordinatorDependencies) {
        self.deps = dependencies
    }

    /// Returns the public form of the persisted mode, or nil if none is stored.
    public func restoreLastMode(transaction: DBReadTransaction) -> RegistrationMode? {
        return loadMode(transaction: transaction)?.asRegistrationMode()
    }

    /// Builds a coordinator, preferring any persisted mode over `desiredMode`.
    ///
    /// The chosen mode is (re)written to disk. If a change number is pending,
    /// message processing is suspended and the prekey manager is told about it.
    public func coordinator(
        forDesiredMode desiredMode: RegistrationMode,
        transaction: DBWriteTransaction,
        logger: PrefixedLogger,
    ) -> RegistrationCoordinator {
        // A mode persisted by an earlier, unfinished registration takes precedence.
        let mode = loadMode(transaction: transaction) ?? desiredMode.asInternalMode()
        do {
            try self.kvStore.setCodable(mode, key: Constants.modeKey, transaction: transaction)
        } catch {
            owsFailDebug("Failed to write registration mode to disk: \(error)")
        }
        if mode.hasPendingChangeNumber {
            // This should happen on app startup, but do it here too to be safe.
            deps.messagePipelineSupervisor.suspendMessageProcessingWithoutHandle(for: .pendingChangeNumber)
            deps.preKeyManager.setIsChangingNumber(true)
        }
        let delegate = CoordinatorDelegate(loader: self)
        // Log through the injected logger so this line carries the same prefix as
        // everything else the coordinator logs for this registration.
        logger.info("Starting registration, mode: \(mode.logString)")
        return RegistrationCoordinatorImpl(mode: mode, loader: delegate, dependencies: deps, logger: logger)
    }

    /// Whether the persisted mode has a pending change number; false if none is stored.
    public func hasPendingChangeNumber(transaction: DBReadTransaction) -> Bool {
        return loadMode(transaction: transaction)?.hasPendingChangeNumber ?? false
    }

    /// Reads the persisted mode. Returns nil when nothing is stored, and crashes via
    /// `owsFail` if a stored value cannot be decoded.
    private func loadMode(transaction: DBReadTransaction) -> Mode? {
        do {
            return try kvStore.getCodableValue(forKey: Constants.modeKey, transaction: transaction)
        } catch {
            // Failed to parse, even though we know there is something there.
            // This is BAD. We might've been in the middle of change number, which NEEDS to recover.
            owsFail("Unable to restore in-progress registration mode: \(error)")
        }
    }

    // Here we put methods from this loader impl class that we want to expose to
    // RegistrationCoordinatorImpl but not expose to anything else.
    // This misdirection only exists because we have one big package and `internal`
    // is meaningless; ideally RegistrationCoordinatorImpl and RegistrationCoordinatorLoaderImpl
    // would get to talk to each other in their own internal API and expose only public things
    // to the outside world.
    class CoordinatorDelegate: RegistrationCoordinatorLoaderDelegate {

        let loader: RegistrationCoordinatorLoaderImpl

        // It's important that this initializer be fileprivate; nothing outside of this
        // class should initialize one of these.
        fileprivate init(loader: RegistrationCoordinatorLoaderImpl) {
            self.loader = loader
        }

        /// Removes the persisted mode entirely.
        func clearPersistedMode(transaction: DBWriteTransaction) {
            loader.kvStore.removeValue(forKey: Constants.modeKey, transaction: transaction)
        }

        /// Persists `oldState` with its `pniState` replaced, then returns the new state.
        ///
        /// When the write transaction completes, message processing is suspended or
        /// unsuspended, and the prekey manager is updated, according to whether a
        /// change number is now pending.
        ///
        /// - Throws: if the new mode cannot be written to the key-value store.
        func savePendingChangeNumber(
            oldState: Mode.ChangeNumberState,
            pniState: Mode.ChangeNumberState.PendingPniState?,
            transaction: DBWriteTransaction,
        ) throws -> Mode.ChangeNumberState {
            var newState = oldState
            newState.pniState = pniState
            let newMode = Mode.changingNumber(newState)
            try loader.kvStore.setCodable(newMode, key: Constants.modeKey, transaction: transaction)
            // Computed up front so the completion captures an immutable Bool rather
            // than the mutable `newState` variable.
            let hasPendingChangeNumber = newMode.hasPendingChangeNumber
            let messagePipelineSupervisor = loader.deps.messagePipelineSupervisor
            let preKeyManager = loader.deps.preKeyManager
            transaction.addSyncCompletion {
                if hasPendingChangeNumber {
                    messagePipelineSupervisor.suspendMessageProcessingWithoutHandle(for: .pendingChangeNumber)
                } else {
                    messagePipelineSupervisor.unsuspendMessageProcessing(for: .pendingChangeNumber)
                }
                preKeyManager.setIsChangingNumber(hasPendingChangeNumber)
            }
            return newState
        }
    }

    enum Constants {
        static let collectionName = "RegistrationCoordinatorLoader"
        static let modeKey = "mode"
    }
}

// MARK: - Mode Transformers

extension RegistrationMode {

    /// Converts this public mode into the loader's persistable `Mode`.
    /// A freshly started change number has no pending PNI state yet.
    fileprivate func asInternalMode() -> RegistrationCoordinatorLoaderImpl.Mode {
        switch self {
        case .registering:
            .registering(.init())
        case let .reRegistering(reRegParams):
            .reRegistering(.init(e164: reRegParams.e164, aci: reRegParams.aci))
        case let .changingNumber(changeNumberParams):
            .changingNumber(.init(
                oldE164: changeNumberParams.oldE164,
                oldAuthToken: changeNumberParams.oldAuthToken,
                localAci: changeNumberParams.localAci,
                localDeviceId: changeNumberParams.localDeviceId,
                pniState: nil,
            ))
        }
    }
}

extension RegistrationCoordinatorLoaderImpl.Mode {

    /// Maps the persisted mode back onto the public `RegistrationMode`.
    /// Any pending PNI state held by `changingNumber` is not carried over.
    fileprivate func asRegistrationMode() -> RegistrationMode {
        switch self {
        case .registering:
            .registering
        case let .reRegistering(reRegState):
            .reRegistering(.init(e164: reRegState.e164, aci: reRegState.aci))
        case let .changingNumber(changeNumberState):
            .changingNumber(.init(
                oldE164: changeNumberState.oldE164,
                oldAuthToken: changeNumberState.oldAuthToken,
                localAci: changeNumberState.localAci,
                localDeviceId: changeNumberState.localDeviceId,
            ))
        }
    }

    /// A short human-readable description of the mode, used in log output.
    fileprivate var logString: String {
        switch self {
        case .registering:
            "initial registration"
        case let .reRegistering(state):
            "re-registration aci:\(state.aci) e164:\(state.e164.stringValue)"
        case let .changingNumber(state):
            "changing number: aci:\(state.localAci) old e164:\(state.oldE164.stringValue)"
        }
    }
}

// MARK: - PNI state transformer

extension ChangePhoneNumberPni.PendingState {
    /// Copies this pending PNI state into the loader's persistable form.
    /// Both prekey records are always wrapped as `.success`.
    func asRegPniState() -> RegistrationCoordinatorLoaderImpl.Mode.ChangeNumberState.PendingPniState {
        .init(
            newE164: self.newE164,
            pniIdentityKeyPair: self.pniIdentityKeyPair,
            localDevicePniSignedPreKeyRecord: .success(self.localDevicePniSignedPreKeyRecord),
            localDevicePniPqLastResortPreKeyRecord: .success(self.localDevicePniPqLastResortPreKeyRecord),
            localDevicePniRegistrationId: self.localDevicePniRegistrationId,
        )
    }
}

// MARK: - PNI state Codable

/// Hand-written `Codable` conformance for the pending PNI state.
///
/// The prekey records are stored as their serialized bytes, and the identity key pair
/// is stored as an `NSKeyedArchiver` archive. When `BuildFlags.decodeDeprecatedPreKeys`
/// is set, decoding also accepts the deprecated prekey formats. Encoding only ever
/// writes the modern keys.
extension RegistrationCoordinatorLoaderImpl.Mode.ChangeNumberState.PendingPniState: Codable {
    private enum CodingKeys: String, CodingKey {
        case newE164
        case pniIdentityKeyPair
        case localDevicePniSignedPreKeyRecord // deprecated: keyed archive of SignalServiceKit.SignedPreKeyRecord; read only
        case localDevicePniSignedPreKeyRecordData
        case localDevicePniPqLastResortPreKeyRecord // deprecated: `KyberRecordKeyData` shim; read only
        case localDevicePniPqLastResortPreKeyRecordData
        case localDevicePniRegistrationId
    }

    /// Decodes the pending PNI state.
    ///
    /// For each prekey record, the modern `...Data` key is tried first. The deprecated
    /// key is consulted only when `BuildFlags.decodeDeprecatedPreKeys` is set. A record
    /// absent under every key it is looked up by decodes as `.failure` instead of
    /// throwing, so the rest of the state can still be restored.
    ///
    /// - Throws: if `newE164`, `localDevicePniRegistrationId`, or `pniIdentityKeyPair`
    ///   is missing or malformed; if a record that *is* present cannot be parsed; or
    ///   if the identity key pair archive unarchives to nothing.
    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        // NOTE: fields are decoded in a fixed order, and the first failing `try` exits.
        // The order therefore decides which error is reported when several fields are bad.
        self.newE164 = try container.decode(E164.self, forKey: .newE164)
        self.localDevicePniRegistrationId = try container.decode(UInt32.self, forKey: .localDevicePniRegistrationId)

        // Last-resort Kyber prekey: modern raw bytes first, then the deprecated shim.
        if let modernValue = try container.decodeIfPresent(Data.self, forKey: .localDevicePniPqLastResortPreKeyRecordData) {
            self.localDevicePniPqLastResortPreKeyRecord = .success(try LibSignalClient.KyberPreKeyRecord(bytes: modernValue))
        } else if
            BuildFlags.decodeDeprecatedPreKeys,
            let deprecatedValue = try container.decodeIfPresent(KyberRecordKeyData.self, forKey: .localDevicePniPqLastResortPreKeyRecord)
        {
            self.localDevicePniPqLastResortPreKeyRecord = .success(try LibSignalClient.KyberPreKeyRecord(bytes: deprecatedValue.keyData))
        } else {
            // We don't want to fail the ENTIRE registration operation when this is
            // missing -- we can recover in this case, but we need to communicate the
            // failure to finalizePniIdentity.
            self.localDevicePniPqLastResortPreKeyRecord = .failure(.dataCorruptedError(
                forKey: .localDevicePniPqLastResortPreKeyRecordData,
                in: container,
                debugDescription: "last resort pre key is missing",
            ))
        }

        // Signed prekey: modern raw bytes first, then the deprecated keyed archive.
        if let modernValue = try container.decodeIfPresent(Data.self, forKey: .localDevicePniSignedPreKeyRecordData) {
            self.localDevicePniSignedPreKeyRecord = .success(try LibSignalClient.SignedPreKeyRecord(bytes: modernValue))
        } else if
            BuildFlags.decodeDeprecatedPreKeys,
            let deprecatedValue = try container.decodeIfPresent(Data.self, forKey: .localDevicePniSignedPreKeyRecord)
        {
            // Legacy format: an NSKeyedArchiver archive of SignalServiceKit's
            // SignedPreKeyRecord. It is rebuilt here as the LibSignalClient type.
            guard let signedPreKeyRecord = try NSKeyedUnarchiver.unarchivedObject(ofClass: SignalServiceKit.SignedPreKeyRecord.self, from: deprecatedValue) else {
                throw DecodingError.dataCorruptedError(forKey: .localDevicePniSignedPreKeyRecord, in: container, debugDescription: "")
            }
            self.localDevicePniSignedPreKeyRecord = .success(try LibSignalClient.SignedPreKeyRecord(
                id: UInt32(bitPattern: signedPreKeyRecord.id),
                timestamp: signedPreKeyRecord.generatedAt.ows_millisecondsSince1970,
                privateKey: signedPreKeyRecord.keyPair.keyPair.privateKey,
                signature: signedPreKeyRecord.signature,
            ))
        } else {
            // We don't want to fail the ENTIRE registration operation when this is
            // missing -- we can recover in this case, but we need to communicate the
            // failure to finalizePniIdentity.
            self.localDevicePniSignedPreKeyRecord = .failure(.dataCorruptedError(
                forKey: .localDevicePniSignedPreKeyRecordData,
                in: container,
                debugDescription: "signed pre key is missing",
            ))
        }

        // Unlike the prekey records, the identity key pair is required: a missing
        // key, or an archive that unarchives to nil, fails the whole decode.
        guard
            let pniIdentityKeyPair: ECKeyPair = try Self.decodeKeyedArchive(
                fromDecodingContainer: container,
                forKey: .pniIdentityKeyPair,
            )
        else {
            throw OWSAssertionError("Unable to deserialize NSKeyedArchiver fields!")
        }

        self.pniIdentityKeyPair = pniIdentityKeyPair
    }

    /// Encodes the state using only the modern keys.
    ///
    /// A prekey record held as `.failure` is skipped (`try?` plus `encodeIfPresent`).
    /// Because the deprecated keys are never written, that record decodes as
    /// `.failure` again.
    public func encode(to encoder: Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)

        try container.encode(newE164, forKey: .newE164)
        try container.encode(localDevicePniRegistrationId, forKey: .localDevicePniRegistrationId)
        try container.encodeIfPresent((try? localDevicePniSignedPreKeyRecord.get())?.serialize(), forKey: .localDevicePniSignedPreKeyRecordData)
        try container.encodeIfPresent((try? localDevicePniPqLastResortPreKeyRecord.get())?.serialize(), forKey: .localDevicePniPqLastResortPreKeyRecordData)

        try Self.encodeKeyedArchive(
            value: pniIdentityKeyPair,
            toEncodingContainer: &container,
            forKey: .pniIdentityKeyPair,
        )
    }

    /// A shim of the former KyberPreKeyRecord that contains what's necessary to
    /// maintain continuity with historically-encoded values.
    private struct KyberRecordKeyData: Codable {
        var keyData: Data
    }

    // MARK: NSKeyed[Un]Archiver

    /// Reads `Data` stored under `key` and unarchives it as a `T`.
    ///
    /// Returns nil if the unarchiver yields no object. Throws if the key is missing,
    /// is not `Data`, or unarchiving fails.
    private static func decodeKeyedArchive<T: NSObject & NSSecureCoding>(
        fromDecodingContainer decodingContainer: KeyedDecodingContainer<CodingKeys>,
        forKey key: CodingKeys,
    ) throws -> T? {
        let data = try decodingContainer.decode(Data.self, forKey: key)

        return try NSKeyedUnarchiver.unarchivedObject(ofClass: T.self, from: data)
    }

    /// Archives `value` with secure coding required and stores the bytes under `key`.
    private static func encodeKeyedArchive<T: NSObject & NSSecureCoding>(
        value: T,
        toEncodingContainer encodingContainer: inout KeyedEncodingContainer<CodingKeys>,
        forKey key: CodingKeys,
    ) throws {
        let data = try NSKeyedArchiver.archivedData(
            withRootObject: value,
            requiringSecureCoding: true,
        )

        try encodingContainer.encode(data, forKey: key)
    }
}