Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Messages/Interactions/PinnedMessages/PinnedMessageManager.swift
1 views
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB
public import LibSignalClient

/// Persisted pin metadata for a single message, as exposed for backups.
public struct PinMessageDetails {
    /// Millisecond timestamp at which the pin was recorded locally.
    let pinnedAtTimestamp: UInt64

    /// Millisecond timestamp at which the pin expires; `nil` means the pin never expires.
    let expiresAtTimestamp: UInt64?

    init(pinnedAtTimestamp: UInt64, expiresAtTimestamp: UInt64?) {
        self.pinnedAtTimestamp = pinnedAtTimestamp
        self.expiresAtTimestamp = expiresAtTimestamp
    }
}

public class PinnedMessageManager {
    private let disappearingMessagesConfigurationStore: DisappearingMessagesConfigurationStore
    private let accountManager: TSAccountManager
    private let interactionStore: InteractionStore
    private let keyValueStore: NewKeyValueStore
    private let db: DB
    private let threadStore: ThreadStore
    private let dateProvider: DateProvider
    private let expirationJob: PinnedMessageExpirationJob

    /// Key for an `Int64` count of how many times the disappearing-message warning has been shown.
    private static let disappearingMessageWarningShownKey = "disappearingMessageWarningShownKey"

    /// Once the warning has been shown this many times (or has been dismissed), it is not shown again.
    private static let maxDisappearingMessageWarningCount: Int64 = 3

    init(
        disappearingMessagesConfigurationStore: DisappearingMessagesConfigurationStore,
        interactionStore: InteractionStore,
        accountManager: TSAccountManager,
        db: DB,
        threadStore: ThreadStore,
        dateProvider: @escaping DateProvider,
        expirationJob: PinnedMessageExpirationJob,
    ) {
        self.disappearingMessagesConfigurationStore = disappearingMessagesConfigurationStore
        self.interactionStore = interactionStore
        self.accountManager = accountManager
        self.db = db
        self.threadStore = threadStore
        self.keyValueStore = NewKeyValueStore(collection: "PinnedMessage")
        self.dateProvider = dateProvider
        self.expirationJob = expirationJob
    }

    /// Returns the messages pinned in the given thread, most recently received pin first.
    ///
    /// Pinned interactions that are not `TSMessage`s are dropped from the result.
    public func fetchPinnedMessagesForThread(
        threadId: Int64,
        tx: DBReadTransaction,
    ) -> [TSMessage] {
        return failIfThrows {
            // Every column reference is qualified so the JOIN can never produce an
            // ambiguous-column error if the interaction table gains a same-named column.
            return try InteractionRecord.fetchAll(
                tx.database,
                sql: """
                    SELECT m.* FROM \(InteractionRecord.databaseTableName) as m
                    JOIN \(PinnedMessageRecord.databaseTableName) as p
                    ON p.\(PinnedMessageRecord.CodingKeys.interactionId.rawValue) = m.\(InteractionRecord.CodingKeys.id.rawValue)
                    WHERE p.\(PinnedMessageRecord.CodingKeys.threadId.rawValue) = ?
                    ORDER BY p.\(PinnedMessageRecord.CodingKeys.receivedTimestamp.rawValue) DESC
                """,
                arguments: [threadId],
            ).compactMap { try TSInteraction.fromRecord($0) as? TSMessage }
        }
    }

    /// Validates the numeric fields of an incoming pin proto.
    ///
    /// - Throws: `OWSAssertionError` if the target timestamp does not fit in `Int64`
    ///   or the pin duration is not below `Int32.max` seconds.
    private func validateInputsForPinMessage(
        pinMessageProto: SSKProtoDataMessagePinMessage,
    ) throws {
        guard SDS.fitsInInt64(pinMessageProto.targetSentTimestamp) else {
            throw OWSAssertionError("Invalid timestamp.")
        }

        guard pinMessageProto.pinDurationSeconds < Int32.max else {
            throw OWSAssertionError("Invalid pin duration.")
        }
    }

    /// Returns the row id and timestamp that a pin/unpin for `targetMessage` should apply to.
    ///
    /// If `targetMessage` is a past edit revision, the latest revision is looked up and its
    /// row id and timestamp are returned, so the pin shows up on the current message.
    /// Otherwise `interactionId` and `targetMessage.timestamp` are returned unchanged.
    ///
    /// - Throws: `OWSAssertionError` if the latest revision cannot be found or has no row id.
    private func resolveLatestRevision(
        of targetMessage: TSMessage,
        interactionId: Int64,
        tx: DBWriteTransaction,
    ) throws -> (interactionId: Int64, timestamp: UInt64) {
        guard targetMessage.editState == .pastRevision else {
            return (interactionId, targetMessage.timestamp)
        }

        guard
            let latestEdit = DependenciesBridge.shared.editMessageStore.findMessage(
                fromEdit: targetMessage,
                tx: tx,
            )
        else {
            throw OWSAssertionError("Can't find latest edit for pinned message")
        }

        guard let latestEditInteractionId = latestEdit.sqliteRowId else {
            throw OWSAssertionError("Latest edit for pinned message has no row id")
        }

        return (latestEditInteractionId, latestEdit.timestamp)
    }

    /// Applies an incoming pin-message proto to `thread`.
    ///
    /// This pins the referenced message, replacing any existing pin for it so the expiry is
    /// refreshed, and prunes older pins to stay within the remote-config limit. It then
    /// inserts a pinned-message info message and restarts the expiration job.
    ///
    /// - Parameters:
    ///   - pinMessageProto: The received pin proto.
    ///   - pinAuthor: The ACI of the user who pinned the message.
    ///   - thread: The thread containing the target message.
    ///   - pinSentAtTimestamp: Sent timestamp of the pin message, stored on the record.
    ///   - expireTimer: Disappearing-message timer applied to the inserted info message.
    ///   - expireTimerVersion: Timer version applied to the inserted info message.
    ///   - transaction: The write transaction to use.
    /// - Throws: `OWSAssertionError` in any of these cases:
    ///   - the proto is invalid or has no duration;
    ///   - the local user is not registered;
    ///   - the target cannot be found, is a gift badge, or was remotely deleted;
    ///   - the latest edit revision cannot be resolved;
    ///   - the thread has no row id.
    public func pinMessage(
        pinMessageProto: SSKProtoDataMessagePinMessage,
        pinAuthor: Aci,
        thread: TSThread,
        pinSentAtTimestamp: UInt64,
        expireTimer: UInt32?,
        expireTimerVersion: UInt32?,
        transaction: DBWriteTransaction,
    ) throws {
        try validateInputsForPinMessage(pinMessageProto: pinMessageProto)

        let pinReceivedAtTimestamp = dateProvider().ows_millisecondsSince1970

        guard let localAci = accountManager.localIdentifiers(tx: transaction)?.aci else {
            throw OWSAssertionError("User not registered")
        }

        guard
            let targetAuthorAciBinary = pinMessageProto.targetAuthorAciBinary,
            let targetAuthorAci = try? Aci.parseFrom(serviceIdBinary: targetAuthorAciBinary)
        else {
            throw OWSAssertionError("Target author ACI not present")
        }

        guard
            let targetMessage = try interactionStore.fetchMessage(
                timestamp: pinMessageProto.targetSentTimestamp,
                // Messages authored by the local user are looked up with a nil incoming author.
                incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
                threadUniqueId: thread.uniqueId,
                transaction: transaction,
            ),
            let interactionId = targetMessage.grdbId?.int64Value,
            targetMessage.giftBadge == nil,
            !targetMessage.wasRemotelyDeleted
        else {
            throw OWSAssertionError("Invalid or missing target pinned message")
        }

        // The pin may target an old edit revision; apply it to the latest revision instead.
        let target = try resolveLatestRevision(
            of: targetMessage,
            interactionId: interactionId,
            tx: transaction,
        )

        let expiresAt: UInt64?
        if pinMessageProto.hasPinDurationSeconds {
            let pinDurationMilliseconds = UInt64(pinMessageProto.pinDurationSeconds) * 1000
            let (sum, didOverflow) = pinDurationMilliseconds.addingReportingOverflow(pinReceivedAtTimestamp)
            // On overflow the pin is stored without an expiry.
            expiresAt = didOverflow ? nil : sum
        } else if pinMessageProto.hasPinDurationForever {
            expiresAt = nil
        } else {
            throw OWSAssertionError("Pin message has no duration")
        }

        guard let threadId = thread.sqliteRowId else {
            throw OWSAssertionError("threadId not found")
        }

        // If this is a retry of an existing pinned message, delete the old entry so the expiry gets updated.
        deletePinForMessage(interactionId: target.interactionId, transaction: transaction)

        pruneOldestPinnedMessagesIfNecessary(
            threadId: threadId,
            transaction: transaction,
        )

        failIfThrows {
            _ = try PinnedMessageRecord.insertRecord(
                interactionId: target.interactionId,
                threadId: threadId,
                expiresAt: expiresAt,
                sentTimestamp: pinSentAtTimestamp,
                receivedTimestamp: pinReceivedAtTimestamp,
                tx: transaction,
            )
        }

        insertInfoMessageForPinnedMessage(
            timestamp: MessageTimestampGenerator.sharedInstance.generateTimestamp(),
            thread: thread,
            targetMessageTimestamp: target.timestamp,
            targetMessageAuthor: targetAuthorAci,
            pinAuthor: pinAuthor,
            expireTimer: expireTimer,
            expireTimerVersion: expireTimerVersion,
            tx: transaction,
        )

        expirationJob.restart()
    }

    /// Applies an incoming unpin-message proto by deleting the pin for the referenced message.
    ///
    /// If the proto targets an old edit revision, the pin on the latest revision is removed.
    ///
    /// - Returns: The message that the proto's timestamp/author resolved to. If the proto
    ///   targets an old edit revision, this is that revision, not the latest one.
    /// - Throws: `OWSAssertionError` if the timestamp is invalid, the local user is not
    ///   registered, the target author or message cannot be found, or the latest edit
    ///   revision cannot be resolved.
    public func unpinMessage(
        unpinMessageProto: SSKProtoDataMessageUnpinMessage,
        threadUniqueId: String,
        transaction: DBWriteTransaction,
    ) throws -> TSInteraction {
        guard SDS.fitsInInt64(unpinMessageProto.targetSentTimestamp) else {
            throw OWSAssertionError("Invalid timestamp.")
        }

        guard let localAci = accountManager.localIdentifiers(tx: transaction)?.aci else {
            throw OWSAssertionError("User not registered")
        }

        guard
            let targetAuthorAciBinary = unpinMessageProto.targetAuthorAciBinary,
            let targetAuthorAci = try? Aci.parseFrom(serviceIdBinary: targetAuthorAciBinary)
        else {
            throw OWSAssertionError("Target author ACI not present")
        }

        guard
            let targetMessage = try interactionStore.fetchMessage(
                timestamp: unpinMessageProto.targetSentTimestamp,
                // Messages authored by the local user are looked up with a nil incoming author.
                incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
                threadUniqueId: threadUniqueId,
                transaction: transaction,
            ),
            let interactionId = targetMessage.grdbId?.int64Value
        else {
            throw OWSAssertionError("Can't find target pinned message")
        }

        // Pins are stored against the latest revision, so unpin that one.
        let target = try resolveLatestRevision(
            of: targetMessage,
            interactionId: interactionId,
            tx: transaction,
        )

        deletePinForMessage(interactionId: target.interactionId, transaction: transaction)

        return targetMessage
    }

    /// Deletes every pinned-message record for the given interaction row id.
    ///
    /// Having no pin for the interaction is not an error.
    public func deletePinForMessage(
        interactionId: Int64,
        transaction: DBWriteTransaction,
    ) {
        _ = failIfThrows {
            try PinnedMessageRecord
                .filter(PinnedMessageRecord.Columns.interactionId == interactionId)
                .deleteAll(transaction.database)
        }
    }

    /// Deletes the oldest pins in a thread to leave room for exactly one new pin.
    ///
    /// This keeps the `pinnedMessageLimit - 1` records with the highest ids and deletes the
    /// rest. It is meant to be called right before inserting a new pin.
    public func pruneOldestPinnedMessagesIfNecessary(
        threadId: Int64,
        transaction: DBWriteTransaction,
    ) {
        let maxNumberOfPinnedMessages = RemoteConfig.current.pinnedMessageLimit
        // Keep the newest pinned messages up to the limit minus one, since we're about to insert.
        // Clamp at zero: SQLite treats a negative LIMIT as "no limit", which would otherwise keep
        // every pin when the configured limit is zero.
        let numberOfPinsToKeep = max(Int(maxNumberOfPinnedMessages) - 1, 0)

        failIfThrows {
            let mostRecentPinnedMessageIds: [Int64] = try PinnedMessageRecord
                .filter(PinnedMessageRecord.Columns.threadId == threadId)
                .order(PinnedMessageRecord.Columns.id.desc)
                .limit(numberOfPinsToKeep)
                .select(PinnedMessageRecord.Columns.id)
                .fetchAll(transaction.database)

            // Delete all others
            try PinnedMessageRecord
                .filter(PinnedMessageRecord.Columns.threadId == threadId)
                .filter(!mostRecentPinnedMessageIds.contains(PinnedMessageRecord.Columns.id))
                .deleteAll(transaction.database)
        }
    }

    /// Returns whether to warn the user that the message being pinned is a disappearing message.
    ///
    /// The answer is `false` for messages with `expiresInSeconds == 0`. Otherwise it is `true`
    /// until the warning has been shown `maxDisappearingMessageWarningCount` times.
    public func shouldShowDisappearingMessageWarning(
        message: TSMessage,
        tx: DBReadTransaction,
    ) -> Bool {
        if message.expiresInSeconds == 0 {
            return false
        }
        let numberOfTimesWarningShown = keyValueStore.fetchValue(
            Int64.self,
            forKey: Self.disappearingMessageWarningShownKey,
            tx: tx,
        ) ?? 0

        return numberOfTimesWarningShown < Self.maxDisappearingMessageWarningCount
    }

    /// Records one more display of the disappearing-message warning.
    public func incrementDisappearingMessageWarningCount(tx: DBWriteTransaction) {
        let numberOfTimesWarningShown = keyValueStore.fetchValue(
            Int64.self,
            forKey: Self.disappearingMessageWarningShownKey,
            tx: tx,
        ) ?? 0

        keyValueStore.writeValue(
            numberOfTimesWarningShown + 1,
            forKey: Self.disappearingMessageWarningShownKey,
            tx: tx,
        )
    }

    /// Permanently suppresses the disappearing-message warning by saturating its display count.
    public func stopShowingDisappearingMessageWarning(tx: DBWriteTransaction) {
        keyValueStore.writeValue(
            Self.maxDisappearingMessageWarningCount,
            forKey: Self.disappearingMessageWarningShownKey,
            tx: tx,
        )
    }

    /// Applies a locally initiated pin or unpin to the local database.
    ///
    /// For an unpin, this deletes the pin record and touches the message. For a pin, it
    /// replaces any existing pin, prunes old pins, inserts the new record, touches the
    /// message, inserts a pinned-message info message using the thread's disappearing-message
    /// configuration, and restarts the expiration job.
    ///
    /// Returns silently if the target message, its thread, or either row id cannot be found.
    public func applyPinMessageChangeToLocalState(
        targetTimestamp: UInt64,
        targetAuthorAci: Aci,
        expiresAt: UInt64?,
        isPin: Bool,
        sentTimestamp: UInt64,
        threadUniqueId: String,
        tx: DBWriteTransaction,
    ) {
        // Use the injected clock, matching `pinMessage(...)`.
        let pinnedAtTimestamp = dateProvider().ows_millisecondsSince1970

        guard let localAci = accountManager.localIdentifiers(tx: tx)?.aci else {
            owsFailDebug("User not registered")
            return
        }

        guard
            let targetMessage = try? interactionStore.fetchMessage(
                timestamp: targetTimestamp,
                incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
                threadUniqueId: threadUniqueId,
                transaction: tx,
            ),
            let thread = threadStore.fetchThread(
                uniqueId: targetMessage.uniqueThreadId,
                tx: tx,
            ),
            let threadId = thread.sqliteRowId,
            let interactionId = targetMessage.sqliteRowId
        else {
            return
        }

        guard isPin else {
            deletePinForMessage(interactionId: interactionId, transaction: tx)
            db.touch(
                interaction: targetMessage,
                shouldReindex: false,
                tx: tx,
            )
            return
        }

        // If this is a retry of an existing pinned message, delete the old entry so the expiry gets updated.
        deletePinForMessage(interactionId: interactionId, transaction: tx)

        pruneOldestPinnedMessagesIfNecessary(
            threadId: threadId,
            transaction: tx,
        )

        failIfThrows {
            _ = try PinnedMessageRecord.insertRecord(
                interactionId: interactionId,
                threadId: threadId,
                expiresAt: expiresAt,
                sentTimestamp: sentTimestamp,
                receivedTimestamp: pinnedAtTimestamp,
                tx: tx,
            )
        }

        db.touch(
            interaction: targetMessage,
            shouldReindex: false,
            tx: tx,
        )

        let dmConfig = disappearingMessagesConfigurationStore.fetchOrBuildDefault(for: .thread(thread), tx: tx)

        insertInfoMessageForPinnedMessage(
            timestamp: MessageTimestampGenerator.sharedInstance.generateTimestamp(),
            thread: thread,
            targetMessageTimestamp: targetTimestamp,
            targetMessageAuthor: targetAuthorAci,
            pinAuthor: localAci,
            expireTimer: dmConfig.durationSeconds,
            expireTimerVersion: dmConfig.timerVersion,
            tx: tx,
        )

        expirationJob.restart()
    }

    /// Returns the author ACI of `interaction`.
    ///
    /// Outgoing messages resolve to the local ACI. Incoming messages resolve to their parsed
    /// `authorUUID`. Anything else, or an unparseable UUID, yields `nil`.
    private func getMessageAuthorAci(interaction: TSMessage, tx: DBReadTransaction) -> Aci? {
        guard let localAci = accountManager.localIdentifiers(tx: tx)?.aci else {
            owsFailDebug("Can't find data for original message")
            return nil
        }

        if interaction is TSOutgoingMessage {
            return localAci
        }

        guard
            let incomingMessage = interaction as? TSIncomingMessage,
            let authorUUID = incomingMessage.authorUUID
        else {
            return nil
        }
        return try? Aci.parseFrom(serviceIdString: authorUUID)
    }

    /// Builds an outgoing pin message for `interaction`.
    ///
    /// - Parameter expiresAt: Despite its name, this is used as the pin *duration* in seconds:
    ///   it becomes `pinDurationSeconds`, truncated toward zero. `nil` means pin forever.
    /// - Returns: `nil` if the author cannot be determined, or if `expiresAt` is negative,
    ///   non-finite, or does not fit in `UInt32`.
    public func getOutgoingPinMessage(
        interaction: TSMessage,
        thread: TSThread,
        expiresAt: TimeInterval?,
        tx: DBWriteTransaction,
    ) -> OutgoingPinMessage? {
        guard let authorAci = getMessageAuthorAci(interaction: interaction, tx: tx) else {
            owsFailDebug("unable to parse authorAci")
            return nil
        }

        var pinDurationSeconds: UInt32?
        if let expiresAt {
            // `UInt32(_:)` traps on negative, non-finite, or out-of-range doubles. Truncate
            // toward zero as before, but convert checked and reject values that don't fit.
            guard let seconds = UInt32(exactly: expiresAt.rounded(.towardZero)) else {
                owsFailDebug("invalid pin duration")
                return nil
            }
            pinDurationSeconds = seconds
        }

        return OutgoingPinMessage(
            thread: thread,
            targetMessageTimestamp: interaction.timestamp,
            targetMessageAuthorAciBinary: authorAci,
            pinDurationSeconds: pinDurationSeconds ?? 0,
            pinDurationForever: expiresAt == nil,
            messageExpiresInSeconds: disappearingMessagesConfigurationStore.durationSeconds(for: thread, tx: tx),
            tx: tx,
        )
    }

    /// Builds an outgoing unpin message for `interaction`.
    ///
    /// - Parameter expiresAt: Currently unused; kept for source compatibility with callers.
    /// - Returns: `nil` if the author of `interaction` cannot be determined.
    public func getOutgoingUnpinMessage(
        interaction: TSMessage,
        thread: TSThread,
        expiresAt: Int64?,
        tx: DBWriteTransaction,
    ) -> OutgoingUnpinMessage? {
        guard let authorAci = getMessageAuthorAci(interaction: interaction, tx: tx) else {
            owsFailDebug("unable to parse authorAci")
            return nil
        }

        return OutgoingUnpinMessage(
            thread: thread,
            targetMessageTimestamp: interaction.timestamp,
            targetMessageAuthorAci: authorAci,
            messageExpiresInSeconds: disappearingMessagesConfigurationStore.durationSeconds(for: thread, tx: tx),
            tx: tx,
        )
    }

    /// Inserts a `.typePinnedMessage` info message into `thread` describing who pinned which message.
    ///
    /// - Parameters:
    ///   - timestamp: Timestamp of the info message itself.
    ///   - targetMessageTimestamp: Timestamp of the pinned message; converted with `Int64(_:)`,
    ///     which traps if it exceeds `Int64.max`.
    ///   - expireTimer: Disappearing timer for the info message; `nil` is stored as 0.
    ///   - expireTimerVersion: Timer version for the info message, if any.
    public func insertInfoMessageForPinnedMessage(
        timestamp: UInt64,
        thread: TSThread,
        targetMessageTimestamp: UInt64,
        targetMessageAuthor: Aci,
        pinAuthor: Aci,
        expireTimer: UInt32?,
        expireTimerVersion: UInt32?,
        tx: DBWriteTransaction,
    ) {
        let userInfoForNewMessage: [InfoMessageUserInfoKey: Any] = [
            .pinnedMessage: PersistablePinnedMessageItem(
                pinnedMessageAuthorAci: pinAuthor,
                originalMessageAuthorAci: targetMessageAuthor,
                timestamp: Int64(targetMessageTimestamp),
            ),
        ]

        let timerVersion = expireTimerVersion.map { NSNumber(value: $0) }

        let infoMessage = TSInfoMessage(
            thread: thread,
            timestamp: timestamp,
            serverGuid: nil,
            messageType: .typePinnedMessage,
            expireTimerVersion: timerVersion,
            expiresInSeconds: expireTimer ?? 0,
            infoMessageUserInfo: userInfoForNewMessage,
        )

        infoMessage.anyInsert(transaction: tx)
    }

    /// Returns the pin record with the smallest non-nil `expiresAt`, across all threads, if any.
    public class func nextExpiringPinnedMessage(tx: DBReadTransaction) -> PinnedMessageRecord? {
        return failIfThrows {
            try PinnedMessageRecord
                .filter(PinnedMessageRecord.Columns.expiresAt != nil)
                .order(PinnedMessageRecord.Columns.expiresAt)
                .fetchOne(tx.database)
        }
    }

    // MARK: - Backups

    /// Returns the pin details (received timestamp and expiry) for an interaction row id,
    /// or `nil` if that interaction is not pinned.
    public func pinMessageDetails(
        interactionId: Int64,
        tx: DBReadTransaction,
    ) -> PinMessageDetails? {
        let sql = """
            SELECT * FROM \(PinnedMessageRecord.databaseTableName)
            WHERE \(PinnedMessageRecord.CodingKeys.interactionId.rawValue) = ?
        """
        return failIfThrows {
            let statement = try tx.database.cachedStatement(sql: sql)
            return try PinnedMessageRecord.fetchOne(statement, arguments: [interactionId])
                .map {
                    PinMessageDetails(
                        pinnedAtTimestamp: $0.receivedTimestamp,
                        expiresAtTimestamp: $0.expiresAt,
                    )
                }
        }
    }

    /// Returns the number of pin records stored for the given thread row id.
    private func numberOfPinnedMessagesForThread(threadId: Int64, tx: DBReadTransaction) -> Int {
        failIfThrows {
            try PinnedMessageRecord
                .filter(PinnedMessageRecord.Columns.threadId == threadId)
                .fetchCount(tx.database)
        }
    }

    /// Restores a pin from a backup frame.
    ///
    /// - Returns: `.failure` if `message` has no row id. Returns `.partialRestore` without
    ///   inserting if the thread already holds `pinnedMessageLimit` pins. Otherwise returns
    ///   `.success` after inserting the record.
    public func applyPinMessageFromBackup(
        message: TSMessage,
        threadId: Int64,
        pinDetails: PinMessageDetails,
        chatItemId: BackupArchive.ChatItemId,
        tx: DBWriteTransaction,
    ) -> BackupArchive.RestoreFrameResult<BackupArchive.ChatItemId> {
        guard let interactionId = message.sqliteRowId else {
            return .failure([.restoreFrameError(.databaseModelMissingRowId(modelClass: TSMessage.self), chatItemId)])
        }

        // Refuse to exceed the pin limit; the frame is reported as a partial restore instead.
        let numExistingPins = numberOfPinnedMessagesForThread(threadId: threadId, tx: tx)
        guard numExistingPins < RemoteConfig.current.pinnedMessageLimit else {
            return .partialRestore([.restoreFrameError(.invalidProtoData(.invalidNumberOfPinnedMessages), chatItemId)])
        }

        failIfThrows {
            _ = try PinnedMessageRecord.insertRecord(
                interactionId: interactionId,
                threadId: threadId,
                expiresAt: pinDetails.expiresAtTimestamp,
                sentTimestamp: 0, // Currently not used.
                receivedTimestamp: pinDetails.pinnedAtTimestamp,
                tx: tx,
            )
        }

        return .success
    }
}