Path: blob/main/SignalServiceKit/Messages/Interactions/PinnedMessages/PinnedMessageManager.swift
1 views
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import GRDB
public import LibSignalClient
public struct PinMessageDetails {
let pinnedAtTimestamp: UInt64
let expiresAtTimestamp: UInt64? // nil is forever
}
public class PinnedMessageManager {
private let disappearingMessagesConfigurationStore: DisappearingMessagesConfigurationStore
private let accountManager: TSAccountManager
private let interactionStore: InteractionStore
private let keyValueStore: NewKeyValueStore
private let db: DB
private let threadStore: ThreadStore
private let dateProvider: DateProvider
private let expirationJob: PinnedMessageExpirationJob
// Int value of how many times the disappearing message warning has been shown.
// If 3 or greater, don't show again.
private static let disappearingMessageWarningShownKey = "disappearingMessageWarningShownKey"
init(
disappearingMessagesConfigurationStore: DisappearingMessagesConfigurationStore,
interactionStore: InteractionStore,
accountManager: TSAccountManager,
db: DB,
threadStore: ThreadStore,
dateProvider: @escaping DateProvider,
expirationJob: PinnedMessageExpirationJob,
) {
self.disappearingMessagesConfigurationStore = disappearingMessagesConfigurationStore
self.interactionStore = interactionStore
self.accountManager = accountManager
self.db = db
self.threadStore = threadStore
self.keyValueStore = NewKeyValueStore(collection: "PinnedMessage")
self.dateProvider = dateProvider
self.expirationJob = expirationJob
}
public func fetchPinnedMessagesForThread(
threadId: Int64,
tx: DBReadTransaction,
) -> [TSMessage] {
return failIfThrows {
return try InteractionRecord.fetchAll(
tx.database,
sql: """
SELECT m.* FROM \(InteractionRecord.databaseTableName) as m
JOIN \(PinnedMessageRecord.databaseTableName) as p
ON p.\(PinnedMessageRecord.CodingKeys.interactionId.rawValue) = m.\(InteractionRecord.CodingKeys.id.rawValue)
WHERE \(PinnedMessageRecord.CodingKeys.threadId.rawValue) = ?
ORDER BY p.\(PinnedMessageRecord.CodingKeys.receivedTimestamp.rawValue) DESC
""",
arguments: [threadId],
).compactMap { try TSInteraction.fromRecord($0) as? TSMessage }
}
}
private func validateInputsForPinMessage(
pinMessageProto: SSKProtoDataMessagePinMessage,
) throws {
guard SDS.fitsInInt64(pinMessageProto.targetSentTimestamp) else {
throw OWSAssertionError("Invalid timestamp.")
}
guard pinMessageProto.pinDurationSeconds < Int32.max else {
throw OWSAssertionError("Invalid timestamp.")
}
}
public func pinMessage(
pinMessageProto: SSKProtoDataMessagePinMessage,
pinAuthor: Aci,
thread: TSThread,
pinSentAtTimestamp: UInt64,
expireTimer: UInt32?,
expireTimerVersion: UInt32?,
transaction: DBWriteTransaction,
) throws {
try validateInputsForPinMessage(pinMessageProto: pinMessageProto)
let pinReceivedAtTimestamp = dateProvider().ows_millisecondsSince1970
guard let localAci = accountManager.localIdentifiers(tx: transaction)?.aci else {
throw OWSAssertionError("User not registered")
}
guard
let targetAuthorAciBinary = pinMessageProto.targetAuthorAciBinary,
let targetAuthorAci = try? Aci.parseFrom(serviceIdBinary: targetAuthorAciBinary)
else {
throw OWSAssertionError("Target author ACI not present")
}
var targetMessageInteractionId: Int64
var targetMessageTimestamp: UInt64
guard
let targetMessage = try interactionStore.fetchMessage(
timestamp: pinMessageProto.targetSentTimestamp,
incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
threadUniqueId: thread.uniqueId,
transaction: transaction,
), let interactionId = targetMessage.grdbId?.int64Value,
targetMessage.giftBadge == nil,
!targetMessage.wasRemotelyDeleted
else {
throw OWSAssertionError("Invalid or missing target pinned message")
}
targetMessageInteractionId = interactionId
targetMessageTimestamp = targetMessage.timestamp
if targetMessage.editState == .pastRevision {
// Pin targeted an old edit revision, fetch the latest
// version to ensure the pin shows up properly.
if
let latestEdit = DependenciesBridge.shared.editMessageStore.findMessage(
fromEdit: targetMessage,
tx: transaction,
)
{
targetMessageInteractionId = latestEdit.sqliteRowId!
targetMessageTimestamp = latestEdit.timestamp
} else {
throw OWSAssertionError("Can't find latest edit for pinned message")
}
}
var expiresAt: UInt64?
if pinMessageProto.hasPinDurationSeconds {
let pinDurationMilliseconds = UInt64(pinMessageProto.pinDurationSeconds) * 1000
let result = pinDurationMilliseconds.addingReportingOverflow(pinReceivedAtTimestamp)
if !result.overflow {
expiresAt = result.partialValue
}
} else if pinMessageProto.hasPinDurationForever {
// expiresAt should stay nil
} else {
throw OWSAssertionError("Pin message has no duration")
}
guard let threadId = thread.sqliteRowId else {
throw OWSAssertionError("threadId not found")
}
// If this is a retry of an existing pinned message, delete the old entry so the expiry gets updated.
deletePinForMessage(interactionId: targetMessageInteractionId, transaction: transaction)
pruneOldestPinnedMessagesIfNecessary(
threadId: threadId,
transaction: transaction,
)
failIfThrows {
_ = try PinnedMessageRecord.insertRecord(
interactionId: targetMessageInteractionId,
threadId: threadId,
expiresAt: expiresAt,
sentTimestamp: pinSentAtTimestamp,
receivedTimestamp: pinReceivedAtTimestamp,
tx: transaction,
)
}
insertInfoMessageForPinnedMessage(
timestamp: MessageTimestampGenerator.sharedInstance.generateTimestamp(),
thread: thread,
targetMessageTimestamp: targetMessageTimestamp,
targetMessageAuthor: targetAuthorAci,
pinAuthor: pinAuthor,
expireTimer: expireTimer,
expireTimerVersion: expireTimerVersion,
tx: transaction,
)
expirationJob.restart()
}
public func unpinMessage(
unpinMessageProto: SSKProtoDataMessageUnpinMessage,
threadUniqueId: String,
transaction: DBWriteTransaction,
) throws -> TSInteraction {
guard SDS.fitsInInt64(unpinMessageProto.targetSentTimestamp) else {
throw OWSAssertionError("Invalid timestamp.")
}
guard let localAci = accountManager.localIdentifiers(tx: transaction)?.aci else {
throw OWSAssertionError("User not registered")
}
guard
let targetAuthorAciBinary = unpinMessageProto.targetAuthorAciBinary,
let targetAuthorAci = try? Aci.parseFrom(serviceIdBinary: targetAuthorAciBinary)
else {
throw OWSAssertionError("Target author ACI not present")
}
var targetMessageInteractionId: Int64
guard
let targetMessage = try interactionStore.fetchMessage(
timestamp: unpinMessageProto.targetSentTimestamp,
incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
threadUniqueId: threadUniqueId,
transaction: transaction,
), let interactionId = targetMessage.grdbId?.int64Value
else {
throw OWSAssertionError("Can't find target pinned message")
}
targetMessageInteractionId = interactionId
if targetMessage.editState == .pastRevision {
// Pin targeted an old edit revision, fetch the latest
// version to ensure the pin shows up properly.
if
let latestEdit = DependenciesBridge.shared.editMessageStore.findMessage(
fromEdit: targetMessage,
tx: transaction,
)
{
targetMessageInteractionId = latestEdit.sqliteRowId!
} else {
throw OWSAssertionError("Can't find latest edit for pinned message")
}
}
failIfThrows {
_ = try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.interactionId == targetMessageInteractionId)
.deleteAll(transaction.database)
}
return targetMessage
}
public func deletePinForMessage(
interactionId: Int64,
transaction: DBWriteTransaction,
) {
_ = failIfThrows {
try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.interactionId == interactionId)
.deleteAll(transaction.database)
}
}
public func pruneOldestPinnedMessagesIfNecessary(
threadId: Int64,
transaction: DBWriteTransaction,
) {
let maxNumberOfPinnedMessages = RemoteConfig.current.pinnedMessageLimit
failIfThrows {
// Keep the newest pinned messages up to the limit minus one, since we're about to insert.
let mostRecentPinnedMessageIds: [Int64] = try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.threadId == threadId)
.order(PinnedMessageRecord.Columns.id.desc)
.limit(Int(maxNumberOfPinnedMessages) - 1)
.select(PinnedMessageRecord.Columns.id)
.fetchAll(transaction.database)
// Delete all others
try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.threadId == threadId)
.filter(!mostRecentPinnedMessageIds.contains(PinnedMessageRecord.Columns.id))
.deleteAll(transaction.database)
}
}
public func shouldShowDisappearingMessageWarning(
message: TSMessage,
tx: DBReadTransaction,
) -> Bool {
if message.expiresInSeconds == 0 {
return false
}
let numberOfTimesWarningShown = keyValueStore.fetchValue(
Int64.self,
forKey: Self.disappearingMessageWarningShownKey,
tx: tx,
) ?? 0
return numberOfTimesWarningShown < 3
}
public func incrementDisappearingMessageWarningCount(tx: DBWriteTransaction) {
let numberOfTimesWarningShown = keyValueStore.fetchValue(Int64.self, forKey: Self.disappearingMessageWarningShownKey, tx: tx) ?? 0
keyValueStore.writeValue(numberOfTimesWarningShown + 1, forKey: Self.disappearingMessageWarningShownKey, tx: tx)
}
public func stopShowingDisappearingMessageWarning(tx: DBWriteTransaction) {
keyValueStore.writeValue(3, forKey: Self.disappearingMessageWarningShownKey, tx: tx)
}
public func applyPinMessageChangeToLocalState(
targetTimestamp: UInt64,
targetAuthorAci: Aci,
expiresAt: UInt64?,
isPin: Bool,
sentTimestamp: UInt64,
threadUniqueId: String,
tx: DBWriteTransaction,
) {
let pinnedAtTimestamp = NSDate.ows_millisecondTimeStamp()
guard let localAci = accountManager.localIdentifiers(tx: tx)?.aci else {
owsFailDebug("User not registered")
return
}
guard
let targetMessage = try? interactionStore.fetchMessage(
timestamp: targetTimestamp,
incomingMessageAuthor: targetAuthorAci == localAci ? nil : targetAuthorAci,
threadUniqueId: threadUniqueId,
transaction: tx,
), let thread = threadStore.fetchThread(
uniqueId: targetMessage.uniqueThreadId,
tx: tx,
), let threadId = thread.sqliteRowId,
let interactionId = targetMessage.sqliteRowId
else {
return
}
if !isPin {
return failIfThrows {
try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.interactionId == interactionId)
.deleteAll(tx.database)
db.touch(
interaction: targetMessage,
shouldReindex: false,
tx: tx,
)
}
}
// If this is a retry of an existing pinned message, delete the old entry so the expiry gets updated.
deletePinForMessage(interactionId: interactionId, transaction: tx)
pruneOldestPinnedMessagesIfNecessary(
threadId: threadId,
transaction: tx,
)
failIfThrows {
_ = try PinnedMessageRecord.insertRecord(
interactionId: interactionId,
threadId: threadId,
expiresAt: expiresAt,
sentTimestamp: sentTimestamp,
receivedTimestamp: pinnedAtTimestamp,
tx: tx,
)
}
db.touch(
interaction: targetMessage,
shouldReindex: false,
tx: tx,
)
let dmConfig = disappearingMessagesConfigurationStore.fetchOrBuildDefault(for: .thread(thread), tx: tx)
insertInfoMessageForPinnedMessage(
timestamp: MessageTimestampGenerator.sharedInstance.generateTimestamp(),
thread: thread,
targetMessageTimestamp: targetTimestamp,
targetMessageAuthor: targetAuthorAci,
pinAuthor: localAci,
expireTimer: dmConfig.durationSeconds,
expireTimerVersion: dmConfig.timerVersion,
tx: tx,
)
expirationJob.restart()
}
private func getMessageAuthorAci(interaction: TSMessage, tx: DBReadTransaction) -> Aci? {
guard let localAci = accountManager.localIdentifiers(tx: tx)?.aci else {
owsFailDebug("Can't find data for original message")
return nil
}
if let _ = interaction as? TSOutgoingMessage {
return localAci
} else if
let incomingMessage = interaction as? TSIncomingMessage,
let authorUUID = incomingMessage.authorUUID,
let incomingAci = try? Aci.parseFrom(serviceIdString: authorUUID)
{
return incomingAci
} else {
return nil
}
}
public func getOutgoingPinMessage(
interaction: TSMessage,
thread: TSThread,
expiresAt: TimeInterval?,
tx: DBWriteTransaction,
) -> OutgoingPinMessage? {
guard let authorAci = getMessageAuthorAci(interaction: interaction, tx: tx) else {
owsFailDebug("unable to parse authorAci")
return nil
}
var pinDurationSeconds: UInt32?
if let expiresAt {
pinDurationSeconds = UInt32(expiresAt)
}
return OutgoingPinMessage(
thread: thread,
targetMessageTimestamp: interaction.timestamp,
targetMessageAuthorAciBinary: authorAci,
pinDurationSeconds: pinDurationSeconds ?? 0,
pinDurationForever: expiresAt == nil,
messageExpiresInSeconds: disappearingMessagesConfigurationStore.durationSeconds(for: thread, tx: tx),
tx: tx,
)
}
public func getOutgoingUnpinMessage(
interaction: TSMessage,
thread: TSThread,
expiresAt: Int64?,
tx: DBWriteTransaction,
) -> OutgoingUnpinMessage? {
guard let authorAci = getMessageAuthorAci(interaction: interaction, tx: tx) else {
owsFailDebug("unable to parse authorAci")
return nil
}
return OutgoingUnpinMessage(
thread: thread,
targetMessageTimestamp: interaction.timestamp,
targetMessageAuthorAci: authorAci,
messageExpiresInSeconds: disappearingMessagesConfigurationStore.durationSeconds(for: thread, tx: tx),
tx: tx,
)
}
public func insertInfoMessageForPinnedMessage(
timestamp: UInt64,
thread: TSThread,
targetMessageTimestamp: UInt64,
targetMessageAuthor: Aci,
pinAuthor: Aci,
expireTimer: UInt32?,
expireTimerVersion: UInt32?,
tx: DBWriteTransaction,
) {
var userInfoForNewMessage: [InfoMessageUserInfoKey: Any] = [:]
userInfoForNewMessage[.pinnedMessage] = PersistablePinnedMessageItem(
pinnedMessageAuthorAci: pinAuthor,
originalMessageAuthorAci: targetMessageAuthor,
timestamp: Int64(targetMessageTimestamp),
)
var timerVersion: NSNumber?
if let expireTimerVersion {
timerVersion = NSNumber(value: expireTimerVersion)
}
let infoMessage = TSInfoMessage(
thread: thread,
timestamp: timestamp,
serverGuid: nil,
messageType: .typePinnedMessage,
expireTimerVersion: timerVersion,
expiresInSeconds: expireTimer ?? 0,
infoMessageUserInfo: userInfoForNewMessage,
)
infoMessage.anyInsert(transaction: tx)
}
public class func nextExpiringPinnedMessage(tx: DBReadTransaction) -> PinnedMessageRecord? {
return failIfThrows {
try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.expiresAt != nil)
.order(PinnedMessageRecord.Columns.expiresAt)
.fetchOne(tx.database)
}
}
// MARK: - Backups
public func pinMessageDetails(
interactionId: Int64,
tx: DBReadTransaction,
) -> PinMessageDetails? {
let sql = """
SELECT * FROM \(PinnedMessageRecord.databaseTableName)
WHERE \(PinnedMessageRecord.CodingKeys.interactionId.rawValue) = ?
"""
return failIfThrows {
let statement = try tx.database.cachedStatement(sql: sql)
return try PinnedMessageRecord.fetchOne(statement, arguments: [interactionId])
.map {
PinMessageDetails(
pinnedAtTimestamp: $0.receivedTimestamp,
expiresAtTimestamp: $0.expiresAt,
)
}
}
}
private func numberOfPinnedMessagesForThread(threadId: Int64, tx: DBReadTransaction) -> Int {
failIfThrows {
try PinnedMessageRecord
.filter(PinnedMessageRecord.Columns.threadId == threadId)
.fetchCount(tx.database)
}
}
public func applyPinMessageFromBackup(
message: TSMessage,
threadId: Int64,
pinDetails: PinMessageDetails,
chatItemId: BackupArchive.ChatItemId,
tx: DBWriteTransaction,
) -> BackupArchive.RestoreFrameResult<BackupArchive.ChatItemId> {
guard let interactionId = message.sqliteRowId else {
return .failure([.restoreFrameError(.databaseModelMissingRowId(modelClass: TSMessage.self), chatItemId)])
}
// check if there's already max limit and throw if so
let numExistingPins = numberOfPinnedMessagesForThread(threadId: threadId, tx: tx)
guard numExistingPins < RemoteConfig.current.pinnedMessageLimit else {
return .partialRestore([.restoreFrameError(.invalidProtoData(.invalidNumberOfPinnedMessages), chatItemId)])
}
failIfThrows {
_ = try PinnedMessageRecord.insertRecord(
interactionId: interactionId,
threadId: threadId,
expiresAt: pinDetails.expiresAtTimestamp,
sentTimestamp: 0, // Currently not used.
receivedTimestamp: pinDetails.pinnedAtTimestamp,
tx: tx,
)
}
return .success
}
}