// Source: GitHub repository signalapp/Signal-iOS
// Path: blob/main/SignalServiceKit/Messages/Edit/EditMessageStore.swift
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB
public import LibSignalClient

/// The message that an edit applies to, tagged by whether it is an outgoing
/// or an incoming message.
public enum EditMessageTarget {
    case outgoingMessage(OutgoingEditMessageWrapper)
    case incomingMessage(IncomingEditMessageWrapper)

    /// The associated wrapper, type-erased to `EditMessageWrapper`.
    var wrapper: any EditMessageWrapper {
        switch self {
        case let .outgoingMessage(outgoing):
            return outgoing
        case let .incomingMessage(incoming):
            return incoming
        }
    }
}

/// Database access for `EditRecord` rows and the message revisions they link.
public struct EditMessageStore {

    public init() {}

    // MARK: - Reads

    /// Finds the message that an edit should be applied to.
    ///
    /// The lookup matches a single interaction on `timestamp`, author and
    /// thread.
    ///
    /// - Parameters:
    ///   - timestamp: Timestamp of the targeted message. Must fit in an
    ///     `Int64`; otherwise the lookup fails with a debug assertion.
    ///   - authorAci: The author of the targeted message. `nil` matches rows
    ///     with no author UUID and is only accepted for outgoing messages.
    ///   - threadUniqueId: Unique id of the thread containing the message.
    ///   - tx: The read transaction to query with.
    /// - Returns: `.outgoingMessage` when an outgoing message is found and
    ///   `authorAci` is `nil`; `.incomingMessage` when an incoming message is
    ///   found and `authorAci` is non-`nil`; `nil` if nothing matched, the
    ///   matched interaction has any other type/author combination, or the
    ///   message's thread could not be loaded.
    public func editTarget(
        timestamp: UInt64,
        authorAci: Aci?,
        threadUniqueId: String,
        tx: DBReadTransaction,
    ) -> EditMessageTarget? {
        guard SDS.fitsInInt64(timestamp) else {
            owsFailDebug("Received invalid timestamp!")
            return nil
        }
        // `IS ?` (rather than `= ?`) so that a NULL argument matches NULL
        // column values.
        let sql = """
        SELECT *
        FROM \(InteractionRecord.databaseTableName)
        \(DEBUG_INDEXED_BY("Interaction_timestamp", or: "index_interactions_on_timestamp_sourceDeviceId_and_authorPhoneNumber"))
        WHERE \(interactionColumn: .timestamp) = ?
        AND \(interactionColumn: .authorUUID) IS ?
        AND \(interactionColumn: .threadUniqueId) IS ?
        LIMIT 1
        """
        let interaction = TSInteraction.grdbFetchOne(
            sql: sql,
            arguments: [timestamp, authorAci?.serviceIdUppercaseString, threadUniqueId],
            transaction: tx,
        )
        switch (interaction, authorAci) {
        case (let outgoingMessage as TSOutgoingMessage, nil):
            guard let thread = outgoingMessage.thread(tx: tx) else {
                Logger.warn("No thread for message")
                return nil
            }
            return .outgoingMessage(OutgoingEditMessageWrapper(
                message: outgoingMessage,
                thread: thread,
            ))
        case (let incomingMessage as TSIncomingMessage, let authorAci?):
            guard let thread = incomingMessage.thread(tx: tx) else {
                Logger.warn("No thread for message")
                return nil
            }
            return .incomingMessage(IncomingEditMessageWrapper(
                message: incomingMessage,
                thread: thread,
                authorAci: authorAci,
            ))
        case (.some, _):
            // Something matched, but not a message type/author combination
            // that can be edited.
            Logger.warn("Unexpected message type found for edit")
            return nil
        case (.none, _):
            return nil
        }
    }

    /// Finds the latest revision of the message that `edit` is a past
    /// revision of.
    ///
    /// - Parameters:
    ///   - edit: A past revision of some message.
    ///   - tx: The read transaction to query with.
    /// - Returns: The message whose row id is recorded as `latestRevisionId`
    ///   for an `EditRecord` whose `pastRevisionId` is `edit`'s row id, or
    ///   `nil` if there is no such record or the row is not a `TSMessage`.
    public func findMessage(
        fromEdit edit: TSMessage,
        tx: DBReadTransaction,
    ) -> TSMessage? {
        let sql = """
            SELECT * FROM \(InteractionRecord.databaseTableName) AS interaction
            INNER JOIN \(EditRecord.databaseTableName) AS editRecord
            ON interaction.\(interactionColumn: .id) = editRecord.latestRevisionId
            WHERE editRecord.pastRevisionId = ?
            LIMIT 1
        """

        let arguments: StatementArguments = [edit.grdbId]
        return TSMessage.grdbFetchOne(
            sql: sql,
            arguments: arguments,
            transaction: tx,
        ) as? TSMessage
    }

    /// Counts the `EditRecord` rows whose `latestRevisionId` is `message`'s
    /// row id.
    ///
    /// - Parameters:
    ///   - message: The message to count edits for.
    ///   - tx: The read transaction to query with.
    /// - Returns: The number of matching rows, or `0` if the query yields no
    ///   value.
    public func numberOfEdits(
        for message: TSMessage,
        tx: DBReadTransaction,
    ) -> Int {
        // The column is left unqualified: this query declares no table alias,
        // so a qualifier would only resolve if it happened to match the
        // interpolated table name.
        let sql = """
            SELECT COUNT(*)
            FROM \(EditRecord.databaseTableName)
            WHERE latestRevisionId = ?
        """

        let arguments: StatementArguments = [message.grdbId]

        return failIfThrows {
            return try Int.fetchOne(
                tx.database,
                sql: sql,
                arguments: arguments,
            ) ?? 0
        }
    }

    /// Fetches all past revisions for the given most-recent-revision message.
    ///
    /// - Parameters:
    ///   - message: The most recent revision of a message.
    ///   - tx: The read transaction to query with.
    /// - Returns: An edit record and message instance (if one is found) for
    ///   each past revision, from newest to oldest.
    /// - Throws: Any error thrown while fetching the edit records.
    public func findEditHistory<MessageType: TSMessage>(
        forMostRecentRevision message: MessageType,
        tx: DBReadTransaction,
    ) throws -> [(record: EditRecord, message: MessageType?)] {
        /// By ordering DESC on `pastRevisionId`, we end up ordering edits
        /// newest-to-oldest. That's because the highest `pastRevisionId` refers
        /// to the most-recently-inserted revision, or newest edit.
        let recordSQL = """
            SELECT * FROM \(EditRecord.databaseTableName)
            WHERE latestRevisionId = ?
            ORDER BY pastRevisionId DESC
        """

        let arguments: StatementArguments = [message.grdbId]

        let records = try EditRecord.fetchAll(
            tx.database,
            sql: recordSQL,
            arguments: arguments,
        )

        return records.map { record -> (EditRecord, MessageType?) in
            let interaction = InteractionFinder.fetch(
                rowId: record.pastRevisionId,
                transaction: tx,
            )
            // Keep the record even when the past revision can't be loaded as
            // the expected type, so callers still see every edit record.
            guard let pastRevision = interaction as? MessageType else {
                owsFailDebug("Interaction has unexpected type: \(type(of: interaction))")
                return (record, nil)
            }
            return (record, pastRevision)
        }
    }

    /// Fetches all EditRecords related to `message`.
    ///
    /// The `message` may be the latest revision or a past revision.
    ///
    /// The EditRecords are fetched "recursively", meaning that every EditRecord
    /// that references a message ID which is referenced by any element of the
    /// result will be returned. This is useful when deleting messages because
    /// it allows us to maintain invariants required by FOREIGN KEY constraints.
    ///
    /// For example, if the revision "graph" is well-formed, we'll return
    /// EditRecords with distinct pastRevisionIds (e.g., 102, 103) which all
    /// refer to the same latestRevisionId (e.g., 101), and we'll return this
    /// exact same result regardless of whether `message` refers to a past
    /// revision (e.g., 102) or the latest revision (e.g., 101).
    ///
    /// If the revision "graph" isn't well-formed, we must fetch extra
    /// EditRecords to ensure we delete all the EditRecords that reference the
    /// messages that are about to be deleted.
    ///
    /// - Parameters:
    ///   - message: The message to start the search from.
    ///   - tx: The read transaction to query with.
    /// - Returns: Every related EditRecord, with duplicates (by `id`) removed.
    /// - Throws: Any error thrown while fetching edit records.
    public func findEditRecords(
        relatedTo message: TSMessage,
        tx: DBReadTransaction,
    ) throws -> [EditRecord] {
        // We need to fetch every EditRecord that references the message's row
        // id or anything that those EditRecords reference, recursively.

        // Work queue of row ids still to be expanded; empty if the message has
        // no row id.
        var revisionIdsToCheck = [message.sqliteRowId].compacted()
        var alreadyCheckedRevisionIds = Set<Int64>()

        var editRecords = [EditRecord]()
        while !revisionIdsToCheck.isEmpty {
            let revisionId = revisionIdsToCheck.removeFirst()
            // Skip ids we've already expanded; this also guarantees
            // termination if the graph contains cycles.
            guard alreadyCheckedRevisionIds.insert(revisionId).inserted else {
                continue
            }
            let records = try EditRecord.filter(
                Column(EditRecord.CodingKeys.latestRevisionId) == revisionId
                    || Column(EditRecord.CodingKeys.pastRevisionId) == revisionId,
            ).fetchAll(tx.database)
            revisionIdsToCheck.append(contentsOf: records.map(\.latestRevisionId))
            revisionIdsToCheck.append(contentsOf: records.map(\.pastRevisionId))
            editRecords.append(contentsOf: records)
        }

        // We'll have duplicates because some will be fetched repeatedly.
        // NOTE(review): the force-unwrap presumes records fetched from the
        // database always carry an `id` -- confirm against EditRecord.
        return editRecords.removingDuplicates(uniquingElementsBy: { $0.id! })
    }

    // MARK: - Writes

    /// Inserts `editRecord` into the database.
    ///
    /// - Throws: Any error thrown by the insert.
    public func insert(
        _ editRecord: EditRecord,
        tx: DBWriteTransaction,
    ) throws {
        try editRecord.insert(tx.database)
    }

    /// Updates the existing database row for `editRecord`.
    ///
    /// - Throws: Any error thrown by the update.
    public func update(
        _ editRecord: EditRecord,
        tx: DBWriteTransaction,
    ) throws {
        try editRecord.update(tx.database)
    }
}