Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Storage/AxolotlStore/SessionStore.swift
1 views
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB
public import LibSignalClient

/// A single row of the `Session` table, readable and writable through GRDB.
struct SessionRecord: Codable, FetchableRecord, PersistableRecord {
    /// Name of the table backing this record type.
    static let databaseTableName = "Session"

    /// Row identifier (presumably the table's primary key -- confirm against the schema).
    let id: Int64
    /// The recipient this session belongs to. Mutable so that a row can be
    /// re-pointed at a different recipient.
    var recipientId: SignalRecipient.RowId
    /// The local identity this session was established under.
    let localIdentity: OWSIdentity
    /// The remote device this session is with.
    let deviceId: DeviceId
    /// The serialized session state. May be nil if there was a legacy session.
    var serializedRecord: Data?

    /// Coding keys; each key's raw value doubles as the column name.
    enum CodingKeys: String, CodingKey {
        case id, recipientId, localIdentity, deviceId, serializedRecord
    }

    /// Typed column references for building queries against the table.
    enum Columns {
        static let recipientId = Column(CodingKeys.recipientId.rawValue)
        static let localIdentity = Column(CodingKeys.localIdentity.rawValue)
        static let deviceId = Column(CodingKeys.deviceId.rawValue)
        static let serializedRecord = Column(CodingKeys.serializedRecord.rawValue)
    }
}

/// Reads and writes `SessionRecord` rows, addressed by recipient row ID,
/// local identity and (optionally) device ID.
struct SessionStore {
    /// Returns whether at least one session row exists for the given recipient
    /// and local identity, regardless of device.
    ///
    /// - Parameters:
    ///   - recipientId: The recipient whose sessions are checked.
    ///   - localIdentity: The local identity the sessions must belong to.
    ///   - tx: The transaction used to query the database.
    /// - Returns: `true` if any matching row exists.
    func hasSessionRecords(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        tx: DBReadTransaction,
    ) -> Bool {
        // Ask the database whether a matching row exists instead of fetching
        // and decoding every row (including its serialized blob) just to
        // check for emptiness.
        return failIfThrows {
            return try !buildQuery(
                recipientId: recipientId,
                localIdentity: localIdentity,
            ).isEmpty(tx.database)
        }
    }

    /// Merges the sessions of `recipientId` into `targetRecipientId` for the
    /// given local identity.
    ///
    /// If the target already has any sessions, the source's sessions are
    /// deleted; otherwise every source row is re-pointed at the target.
    ///
    /// - Parameters:
    ///   - recipientId: The recipient whose sessions are being merged away.
    ///   - targetRecipientId: The recipient that should own the sessions afterwards.
    ///   - localIdentity: The local identity whose sessions are affected.
    ///   - tx: The write transaction used for all reads and writes.
    func mergeRecipientId(
        _ recipientId: SignalRecipient.RowId,
        into targetRecipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        tx: DBWriteTransaction,
    ) {
        if hasSessionRecords(forRecipientId: targetRecipientId, localIdentity: localIdentity, tx: tx) {
            // The target already has sessions -- prefer those over ours.
            deleteSessions(forRecipientId: recipientId, localIdentity: localIdentity, tx: tx)
        } else {
            // The target has no sessions -- move ours over and reuse them.
            let sessionRecords = fetchSessionRecords(
                forRecipientId: recipientId,
                localIdentity: localIdentity,
                tx: tx,
            )
            for var sessionRecord in sessionRecords {
                sessionRecord.recipientId = targetRecipientId
                failIfThrows { try sessionRecord.update(tx.database) }
            }
        }
    }

    /// Builds the request matching rows for a recipient and local identity,
    /// additionally restricted to `deviceId` when one is provided.
    private func buildQuery(
        recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        deviceId: DeviceId? = nil,
    ) -> QueryInterfaceRequest<SessionRecord> {
        var result = SessionRecord.filter(SessionRecord.Columns.recipientId == recipientId)
        result = result.filter(SessionRecord.Columns.localIdentity == localIdentity.rawValue)
        if let deviceId {
            result = result.filter(SessionRecord.Columns.deviceId == deviceId.rawValue)
        }
        return result
    }

    /// Fetches every row matching the recipient, local identity and (if given)
    /// device. Database errors are routed through `failIfThrows`.
    private func fetchSessionRecords(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        deviceId: DeviceId? = nil,
        tx: DBReadTransaction,
    ) -> [SessionRecord] {
        return failIfThrows {
            return try buildQuery(
                recipientId: recipientId,
                localIdentity: localIdentity,
                deviceId: deviceId,
            ).fetchAll(tx.database)
        }
    }

    /// Loads and decodes the session for a specific recipient device.
    ///
    /// - Parameters:
    ///   - recipientId: The recipient the session belongs to.
    ///   - localIdentity: The local identity the session was established under.
    ///   - deviceId: The remote device the session is with.
    ///   - tx: The transaction used to query the database.
    /// - Returns: The decoded session, or `nil` if no row matches, the first
    ///   matching row has no serialized data, or the data can't be decoded.
    func fetchSession(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        deviceId: DeviceId,
        tx: DBReadTransaction,
    ) -> LibSignalClient.SessionRecord? {
        let sessionRecords = fetchSessionRecords(
            forRecipientId: recipientId,
            localIdentity: localIdentity,
            deviceId: deviceId,
            tx: tx,
        )
        do {
            return try (sessionRecords.first?.serializedRecord).map(LibSignalClient.SessionRecord.init(bytes:))
        } catch {
            // If we can't decode the session, it's likely due to database corruption,
            // and we continue as if it doesn't exist (to create a new one).
            Logger.warn("couldn't decode session, continuing without it: \(error)")
            return nil
        }
    }

    /// Archives the current state of every session for the given recipient
    /// and local identity, across all devices.
    func archiveSessions(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        tx: DBWriteTransaction,
    ) {
        _archiveSessions(forRecipientId: recipientId, localIdentity: localIdentity, deviceId: nil, tx: tx)
    }

    /// Archives the current state of each matching session and writes the
    /// re-serialized record back to its row.
    ///
    /// Rows without serialized data are skipped with a warning; rows whose
    /// data can't be decoded are reported via `owsFailDebug` and skipped.
    ///
    /// - Parameters:
    ///   - recipientId: The recipient whose sessions are archived.
    ///   - localIdentity: The local identity the sessions belong to.
    ///   - deviceId: Restricts archiving to one device; `nil` means all devices.
    ///   - tx: The write transaction used for all reads and writes.
    fileprivate func _archiveSessions(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        deviceId: DeviceId?,
        tx: DBWriteTransaction,
    ) {
        let sessionRecords = fetchSessionRecords(
            forRecipientId: recipientId,
            localIdentity: localIdentity,
            deviceId: deviceId,
            tx: tx,
        )
        for var sessionRecord in sessionRecords {
            guard let serializedRecord = sessionRecord.serializedRecord else {
                Logger.warn("couldn't decode legacy session to archive it; leaving it as-is")
                continue
            }
            let libSignalSessionRecord: LibSignalClient.SessionRecord
            do {
                libSignalSessionRecord = try LibSignalClient.SessionRecord(bytes: serializedRecord)
            } catch {
                owsFailDebug("couldn't decode session to archive it: \(error)")
                continue
            }
            libSignalSessionRecord.archiveCurrentState()
            sessionRecord.serializedRecord = libSignalSessionRecord.serialize()
            failIfThrows { try sessionRecord.update(tx.database) }
        }
    }

    /// Deletes every session row for the given recipient and local identity,
    /// across all devices.
    func deleteSessions(
        forRecipientId recipientId: SignalRecipient.RowId,
        localIdentity: OWSIdentity,
        tx: DBWriteTransaction,
    ) {
        failIfThrows {
            _ = try buildQuery(
                recipientId: recipientId,
                localIdentity: localIdentity,
                deviceId: nil,
            ).deleteAll(tx.database)
        }
    }

    /// Writes the serialized session for a recipient device using
    /// `INSERT OR REPLACE`.
    ///
    /// NOTE(review): replacing (rather than duplicating) an existing session
    /// relies on a uniqueness constraint over these columns that is defined
    /// in the schema, not in this file -- confirm.
    ///
    /// - Parameters:
    ///   - recipientId: The recipient the session belongs to.
    ///   - deviceId: The remote device the session is with.
    ///   - localIdentity: The local identity the session was established under.
    ///   - recordData: The serialized session to store.
    ///   - tx: The write transaction used to execute the statement.
    func upsertSession(
        forRecipientId recipientId: SignalRecipient.RowId,
        deviceId: DeviceId,
        localIdentity: OWSIdentity,
        recordData: Data,
        tx: DBWriteTransaction,
    ) {
        failIfThrows {
            // Only compile-time table/column names are interpolated; all
            // values are bound as statement arguments.
            try tx.database.execute(
                sql: """
                INSERT OR REPLACE INTO \(SessionRecord.databaseTableName) (
                    \(SessionRecord.Columns.recipientId.name),
                    \(SessionRecord.Columns.deviceId.name),
                    \(SessionRecord.Columns.localIdentity.name),
                    \(SessionRecord.Columns.serializedRecord.name)
                ) VALUES (?, ?, ?, ?)
                """,
                arguments: [
                    recipientId,
                    deviceId.rawValue,
                    localIdentity.rawValue,
                    recordData,
                ],
            )
        }
    }

    /// Deletes every row in the session table.
    func deleteAllSessions(tx: DBWriteTransaction) {
        failIfThrows { _ = try SessionRecord.deleteAll(tx.database) }
    }
}

/// Exposes a `SessionStore` as a `LibSignalClient.SessionStore` for one local
/// identity, translating service IDs and addresses into recipient row IDs
/// through a `RecipientIdFinder`.
public class SessionManagerForIdentity: LibSignalClient.SessionStore {
    private let identity: OWSIdentity
    private let recipientIdFinder: RecipientIdFinder
    private let sessionStore: SessionStore

    /// Creates a manager bound to `identity` that resolves recipients with
    /// `recipientIdFinder` and persists sessions in `sessionStore`.
    init(
        identity: OWSIdentity,
        recipientIdFinder: RecipientIdFinder,
        sessionStore: SessionStore,
    ) {
        self.identity = identity
        self.recipientIdFinder = recipientIdFinder
        self.sessionStore = sessionStore
    }

    /// Archives the session with a single device of `serviceId`.
    func archiveSession(forServiceId serviceId: ServiceId, deviceId: DeviceId, tx: DBWriteTransaction) {
        Logger.info("archiving session for \(serviceId).\(deviceId)")
        let lookup = recipientIdFinder.recipientId(for: serviceId, tx: tx)
        _archiveSessions(recipientIdResult: lookup, deviceId: deviceId, tx: tx)
    }

    /// Archives the sessions with every device of `serviceId`.
    public func archiveSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
        Logger.info("archiving all sessions for \(serviceId)")
        let lookup = recipientIdFinder.recipientId(for: serviceId, tx: tx)
        _archiveSessions(recipientIdResult: lookup, deviceId: nil, tx: tx)
    }

    /// Archives the sessions with every device of `address`.
    func archiveSessions(forAddress address: SignalServiceAddress, tx: DBWriteTransaction) {
        Logger.info("archiving all sessions for \(address)")
        let lookup = recipientIdFinder.recipientId(for: address, tx: tx)
        _archiveSessions(recipientIdResult: lookup, deviceId: nil, tx: tx)
    }

    /// Shared implementation of the archive entry points: archives only when
    /// the recipient lookup succeeded, and does nothing otherwise.
    private func _archiveSessions(
        recipientIdResult: Result<SignalRecipient.RowId, RecipientIdError>?,
        deviceId: DeviceId?,
        tx: DBWriteTransaction,
    ) {
        switch recipientIdResult {
        case .success(let recipientId)?:
            sessionStore._archiveSessions(
                forRecipientId: recipientId,
                localIdentity: identity,
                deviceId: deviceId,
                tx: tx,
            )
        case nil, .failure(.mustNotUsePniBecauseAciExists)?:
            // There can't possibly be any sessions that need to be archived.
            return
        }
    }

    /// Deletes the sessions with every device of `serviceId`; a no-op when
    /// the recipient lookup doesn't yield a recipient ID.
    public func deleteSessions(forServiceId serviceId: ServiceId, tx: DBWriteTransaction) {
        let lookup = recipientIdFinder.recipientId(for: serviceId, tx: tx)
        switch lookup {
        case .success(let recipientId)?:
            sessionStore.deleteSessions(forRecipientId: recipientId, localIdentity: identity, tx: tx)
        case nil, .failure(.mustNotUsePniBecauseAciExists)?:
            // There can't possibly be any sessions that need to be deleted.
            return
        }
    }

    /// Loads the session with a specific device of `serviceId`.
    ///
    /// - Returns: The stored session, or `nil` if the recipient is unknown or
    ///   no usable session is stored.
    /// - Throws: The `RecipientIdError` reported by the recipient lookup.
    func loadSession(
        forServiceId serviceId: ServiceId,
        deviceId: DeviceId,
        tx: DBReadTransaction,
    ) throws -> LibSignalClient.SessionRecord? {
        let lookup = recipientIdFinder.recipientId(for: serviceId, tx: tx)
        switch lookup {
        case nil:
            return nil
        case .failure(let error)?:
            // Kept as an exhaustive switch so a new error case forces a decision here.
            switch error {
            case .mustNotUsePniBecauseAciExists:
                throw error
            }
        case .success(let recipientId)?:
            return sessionStore.fetchSession(
                forRecipientId: recipientId,
                localIdentity: identity,
                deviceId: deviceId,
                tx: tx,
            )
        }
    }

    /// Loads the session for a protocol address by forwarding to
    /// `loadSession(forServiceId:deviceId:tx:)`.
    public func loadSession(
        for address: LibSignalClient.ProtocolAddress,
        context: any LibSignalClient.StoreContext,
    ) throws -> LibSignalClient.SessionRecord? {
        return try loadSession(
            forServiceId: address.serviceId,
            deviceId: address.deviceIdObj,
            tx: context.asTransaction,
        )
    }

    /// Loads the sessions for all `addresses`, in order.
    ///
    /// - Throws: `SignalError.sessionNotFound` for the first address without a
    ///   session, or any error thrown while loading a session.
    public func loadExistingSessions(
        for addresses: [LibSignalClient.ProtocolAddress],
        context: any LibSignalClient.StoreContext,
    ) throws -> [LibSignalClient.SessionRecord] {
        var sessions: [LibSignalClient.SessionRecord] = []
        sessions.reserveCapacity(addresses.count)
        for address in addresses {
            guard let session = try loadSession(for: address, context: context) else {
                throw SignalError.sessionNotFound("\(address)")
            }
            sessions.append(session)
        }
        return sessions
    }

    /// Serializes `record` and stores it as the session for `address`.
    ///
    /// - Throws: The `RecipientIdError` reported when a recipient ID can't be
    ///   ensured for the address.
    public func storeSession(
        _ record: LibSignalClient.SessionRecord,
        for address: LibSignalClient.ProtocolAddress,
        context: any LibSignalClient.StoreContext,
    ) throws {
        let tx = context.asTransaction
        switch recipientIdFinder.ensureRecipientId(for: address.serviceId, tx: tx) {
        case .failure(let error):
            // Kept as an exhaustive switch so a new error case forces a decision here.
            switch error {
            case .mustNotUsePniBecauseAciExists:
                throw error
            }
        case .success(let recipientId):
            sessionStore.upsertSession(
                forRecipientId: recipientId,
                deviceId: address.deviceIdObj,
                localIdentity: identity,
                recordData: record.serialize(),
                tx: tx,
            )
        }
    }
}