Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Devices/SentMessageTranscriptReceiver/SentMessageTranscriptReceiverImpl.swift
1 views
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import LibSignalClient

/// Applies sent-message transcripts (sync messages describing messages the
/// local user sent from another device) to the local database.
///
/// Each `SentMessageTranscript.type` is handled separately: recipient updates
/// only touch existing outgoing group messages, end-session updates archive
/// sessions and record an info message, payment transcripts go through
/// `PaymentsHelper` or insert an archived-payment message, timer updates adjust
/// a contact thread's disappearing-message configuration, and ordinary messages
/// insert a `TSOutgoingMessage` together with its attachments.
public class SentMessageTranscriptReceiverImpl: SentMessageTranscriptReceiver {

    private let attachmentDownloads: AttachmentDownloadManager
    private let attachmentManager: AttachmentManager
    private let disappearingMessagesExpirationJob: DisappearingMessagesExpirationJob
    private let earlyMessageManager: Shims.EarlyMessageManager
    private let groupManager: Shims.GroupManager
    private let interactionDeleteManager: InteractionDeleteManager
    private let interactionStore: InteractionStore
    private let messageStickerManager: MessageStickerManager
    private let paymentsHelper: PaymentsHelper
    private let pollMessageManager: PollMessageManager
    private let signalProtocolStoreManager: SignalProtocolStoreManager
    private let tsAccountManager: TSAccountManager
    private let viewOnceMessages: Shims.ViewOnceMessages

    public init(
        attachmentDownloads: AttachmentDownloadManager,
        attachmentManager: AttachmentManager,
        disappearingMessagesExpirationJob: DisappearingMessagesExpirationJob,
        earlyMessageManager: Shims.EarlyMessageManager,
        groupManager: Shims.GroupManager,
        interactionDeleteManager: InteractionDeleteManager,
        interactionStore: InteractionStore,
        messageStickerManager: MessageStickerManager,
        paymentsHelper: PaymentsHelper,
        pollMessageManager: PollMessageManager,
        signalProtocolStoreManager: SignalProtocolStoreManager,
        tsAccountManager: TSAccountManager,
        viewOnceMessages: Shims.ViewOnceMessages,
    ) {
        self.attachmentDownloads = attachmentDownloads
        self.attachmentManager = attachmentManager
        self.disappearingMessagesExpirationJob = disappearingMessagesExpirationJob
        self.earlyMessageManager = earlyMessageManager
        self.groupManager = groupManager
        self.interactionDeleteManager = interactionDeleteManager
        self.interactionStore = interactionStore
        self.messageStickerManager = messageStickerManager
        self.paymentsHelper = paymentsHelper
        self.pollMessageManager = pollMessageManager
        self.signalProtocolStoreManager = signalProtocolStoreManager
        self.tsAccountManager = tsAccountManager
        self.viewOnceMessages = viewOnceMessages
    }

    /// Processes a single sent-message transcript.
    ///
    /// - Parameters:
    ///   - transcript: The transcript to apply.
    ///   - registeredState: Registration state; supplies the local identifiers.
    ///   - tx: The write transaction in which all changes are made.
    /// - Returns: `.success` with the inserted, reused, or updated outgoing
    ///   message when there is one; `.success(nil)` for transcripts that do not
    ///   yield a message (end session, payment notification, timer update, or a
    ///   recipient update whose message cannot be found); `.failure` when
    ///   validation or insertion fails.
    public func process(
        _ transcript: SentMessageTranscript,
        registeredState: RegisteredState,
        tx: DBWriteTransaction,
    ) -> Result<TSOutgoingMessage?, Error> {

        /// Checks only that the timestamp can be stored as an `Int64`.
        func validateTimestampInt64() -> Bool {
            guard SDS.fitsInInt64(transcript.timestamp) else {
                owsFailDebug("Invalid timestamp.")
                return false
            }
            return true
        }

        /// Checks that the timestamp fits in an `Int64` and is non-zero.
        func validateTimestampValue() -> Bool {
            guard validateTimestampInt64() else {
                return false
            }
            guard transcript.timestamp >= 1 else {
                owsFailDebug("Transcript is missing timestamp.")
                // This transcript is invalid, discard it.
                return false
            }
            return true
        }

        switch transcript.type {
        case .recipientUpdate(let groupThread):
            // "Recipient updates" are processed completely separately in order
            // to avoid resurrecting threads or messages.
            // Timestamp validation happens inside processRecipientUpdate.
            return self.processRecipientUpdate(transcript, groupThread: groupThread, tx: tx)
        case .endSessionUpdate(let thread):
            guard validateTimestampInt64() else {
                return .failure(OWSAssertionError("Timestamp validation failed"))
            }
            Logger.info("EndSession was sent to recipient: \(thread.contactAddress)")
            self.archiveSessions(for: thread.contactAddress, tx: tx)

            let infoMessage = TSInfoMessage(thread: thread, messageType: .typeLocalUserEndedSession)
            interactionStore.insertInteraction(infoMessage, tx: tx)

            // Don't continue processing lest we print a bubble for the session reset.
            return .success(nil)
        case .paymentNotification(let paymentNotification):
            Logger.info("Recording payment notification from sync transcript in thread: \(paymentNotification.target.thread.logString) timestamp: \(transcript.timestamp)")
            guard validateTimestampValue() else {
                return .failure(OWSAssertionError("Timestamp validation failed"))
            }
            guard validateProtocolVersion(for: transcript, thread: paymentNotification.target.thread, tx: tx) else {
                return .failure(OWSAssertionError("Protocol version validation failed"))
            }

            let messageTimestamp = paymentNotification.serverTimestamp
            owsAssertDebug(messageTimestamp > 0)

            self.paymentsHelper.processReceivedTranscriptPaymentNotification(
                thread: paymentNotification.target.thread,
                paymentNotification: paymentNotification.notification,
                messageTimestamp: messageTimestamp,
                transaction: tx,
            )
            return .success(nil)
        case .archivedPayment(let archivedPayment):
            // NOTE(review): unlike the other cases that record a message, the
            // transcript timestamp is not validated here — confirm intentional.
            guard validateProtocolVersion(for: transcript, thread: archivedPayment.target.thread, tx: tx) else {
                return .failure(OWSAssertionError("Protocol version validation failed"))
            }

            let message = interactionStore.buildOutgoingArchivedPaymentMessage(
                builder: .withDefaultValues(
                    thread: archivedPayment.target.thread,
                    timestamp: transcript.timestamp,
                    expiresInSeconds: archivedPayment.expirationDurationSeconds,
                    // Archived payments don't set the chat timer; version is irrelevant.
                    expireTimerVersion: nil,
                    expireStartedAt: archivedPayment.expirationStartedAt,
                ),
                amount: archivedPayment.amount,
                fee: archivedPayment.fee,
                note: archivedPayment.note,
                tx: tx,
            )

            interactionStore.insertInteraction(message, tx: tx)
            interactionStore.updateRecipientsFromNonLocalDevice(
                message,
                recipientStates: transcript.recipientStates,
                isSentUpdate: false,
                tx: tx,
            )

            return .success(message)
        case .expirationTimerUpdate(let target):
            Logger.info("Recording expiration timer update transcript in thread: \(target.thread.logString) timestamp: \(transcript.timestamp)")
            guard validateTimestampValue() else {
                return .failure(OWSAssertionError("Timestamp validation failed"))
            }
            guard validateProtocolVersion(for: transcript, thread: target.thread, tx: tx) else {
                return .failure(OWSAssertionError("Protocol version validation failed"))
            }

            updateDisappearingMessageTokenIfNecessary(target: target, localIdentifiers: registeredState.localIdentifiers, tx: tx)
            return .success(nil)
        case .message(let messageParams):
            Logger.info("Recording transcript in thread: \(messageParams.target.thread.logString) timestamp: \(transcript.timestamp)")
            guard validateTimestampValue() else {
                return .failure(OWSAssertionError("Timestamp validation failed"))
            }
            // Widen Result<TSOutgoingMessage, Error> to the optional success type.
            return self.process(
                messageParams: messageParams,
                transcript: transcript,
                registeredState: registeredState,
                tx: tx,
            ).map { $0 }
        }
    }

    /// Inserts (or reuses) the outgoing message described by a `.message`
    /// transcript, creates its attachment pointers and poll, applies recipient
    /// states and expiration, and enqueues attachment downloads.
    ///
    /// If an outgoing message with the same timestamp by the local user already
    /// exists in the thread, that message is updated instead of inserting a new
    /// copy, and no poll or attachments are created for it.
    ///
    /// - Returns: The inserted or existing outgoing message; or `.failure` if the
    ///   protocol version is unsupported, the message has no renderable content,
    ///   the thread has not been inserted, or poll/attachment creation fails. In
    ///   the last case the newly inserted message is deleted before returning.
    private func process(
        messageParams: SentMessageTranscriptType.Message,
        transcript: SentMessageTranscript,
        registeredState: RegisteredState,
        tx: DBWriteTransaction,
    ) -> Result<TSOutgoingMessage, Error> {
        guard validateProtocolVersion(for: transcript, thread: messageParams.target.thread, tx: tx) else {
            return .failure(OWSAssertionError("Protocol version validation failed"))
        }

        let localIdentifiers = registeredState.localIdentifiers

        updateDisappearingMessageTokenIfNecessary(target: messageParams.target, localIdentifiers: localIdentifiers, tx: tx)

        let outgoingMessageBuilder = TSOutgoingMessageBuilder(
            thread: messageParams.target.thread,
            timestamp: transcript.timestamp,
            receivedAtTimestamp: nil,
            messageBody: messageParams.body,
            editState: .none, // Sent transcripts with edit state are handled by a different codepath
            expiresInSeconds: messageParams.expirationDurationSeconds,
            expireTimerVersion: messageParams.expireTimerVersion,
            expireStartedAt: messageParams.expirationStartedAt,
            isVoiceMessage: false,
            isSmsMessageRestoredFromBackup: false,
            isViewOnceMessage: messageParams.isViewOnceMessage,
            isViewOnceComplete: false,
            wasRemotelyDeleted: false,
            wasNotCreatedLocally: true,
            groupChangeProtoData: nil,
            storyAuthorAci: messageParams.storyAuthorAci,
            storyTimestamp: messageParams.storyTimestamp,
            storyReactionEmoji: nil,
            quotedMessage: messageParams.validatedQuotedReply?.quotedReply,
            contactShare: messageParams.validatedContactShare?.contact,
            linkPreview: messageParams.validatedLinkPreview?.preview,
            messageSticker: messageParams.validatedMessageSticker?.sticker,
            giftBadge: messageParams.giftBadge,
            isPoll: messageParams.validatedPollCreate != nil,
        )
        var outgoingMessage = interactionStore.buildOutgoingMessage(builder: outgoingMessageBuilder, tx: tx)

        let hasRenderableContent = outgoingMessageBuilder.hasRenderableContent(
            hasBodyAttachments: messageParams.attachmentPointerProtos.isEmpty.negated,
            hasLinkPreview: messageParams.validatedLinkPreview != nil,
            hasQuotedReply: messageParams.validatedQuotedReply != nil,
            hasContactShare: messageParams.validatedContactShare != nil,
            hasSticker: messageParams.validatedMessageSticker != nil,
            // Payment notifications go through a different path.
            hasPayment: false,
            hasPoll: messageParams.validatedPollCreate != nil,
        )
        if !hasRenderableContent, !outgoingMessage.isViewOnceMessage {
            switch messageParams.target {
            case .group(let thread):
                if thread.isGroupV2Thread {
                    // This is probably a v2 group update.
                    Logger.warn("Ignoring message transcript for empty v2 group message.")
                } else {
                    owsFailDebug("Got empty message transcript for v1 group. Who sent this?")
                }
            case .contact:
                Logger.warn("Ignoring message transcript for empty message.")
            }

            struct EmptyMessageTranscriptError: Error {}
            return .failure(EmptyMessageTranscriptError())
        }

        let existingFailedMessage = interactionStore.findMessage(
            withTimestamp: outgoingMessage.timestamp,
            threadId: outgoingMessage.uniqueThreadId,
            author: localIdentifiers.aciAddress,
            tx: tx,
        )
        if let existingFailedMessage = existingFailedMessage as? TSOutgoingMessage {
            // Update the reference to the outgoing message so that we apply all updates to the
            // existing copy, and just throw away the new copy before we insert it.
            outgoingMessage = existingFailedMessage
        } else {

            guard let threadRowId = messageParams.target.thread.sqliteRowId else {
                return .failure(OWSAssertionError("Uninserted thread"))
            }

            // Check for any placeholders inserted because of a previously undecryptable message
            // The sender may have resent the message. If so, we should swap it in place of the placeholder
            interactionStore.insertOrReplacePlaceholder(for: outgoingMessage, from: localIdentifiers.aciAddress, tx: tx)

            if let validatedPollCreate = messageParams.validatedPollCreate {
                do {
                    try pollMessageManager.processIncomingPollCreate(
                        interactionId: outgoingMessage.sqliteRowId!,
                        pollCreateProto: validatedPollCreate.pollCreateProto,
                        transaction: tx,
                    )
                } catch {
                    Logger.error("Failed to insert poll \(error)")
                    // Roll back the message and stop here, as the attachment
                    // failure path below does. Continuing would attach
                    // attachments, recipient states, and downloads to a message
                    // that has just been deleted.
                    interactionDeleteManager.delete(outgoingMessage, sideEffects: .default(), tx: tx)
                    return .failure(error)
                }
            }

            do {
                for (idx, proto) in messageParams.attachmentPointerProtos.enumerated() {
                    let attachmentID = try attachmentManager.createAttachmentPointer(
                        from: OwnedAttachmentPointerProto(
                            proto: proto,
                            owner: .messageBodyAttachment(.init(
                                messageRowId: outgoingMessage.sqliteRowId!,
                                receivedAtTimestamp: outgoingMessage.receivedAtTimestamp,
                                threadRowId: threadRowId,
                                isViewOnce: outgoingMessage.isViewOnceMessage,
                                isPastEditRevision: outgoingMessage.isPastEditRevision(),
                                orderInMessage: UInt32(idx),
                            )),
                        ),
                        tx: tx,
                    )
                    Logger.info("Created body attachment \(attachmentID) (idx \(idx)) for received sent-transcript \(transcript.timestamp)")
                }

                if
                    let quotedReplyAttachmentDataSource = messageParams.validatedQuotedReply?.thumbnailDataSource,
                    MimeTypeUtil.isSupportedVisualMediaMimeType(quotedReplyAttachmentDataSource.originalAttachmentMimeType)
                {
                    let attachmentID = try attachmentManager.createQuotedReplyMessageThumbnail(
                        from: quotedReplyAttachmentDataSource,
                        owningMessageAttachmentBuilder: .init(
                            messageRowId: outgoingMessage.sqliteRowId!,
                            receivedAtTimestamp: outgoingMessage.receivedAtTimestamp,
                            threadRowId: threadRowId,
                            isPastEditRevision: outgoingMessage.isPastEditRevision(),
                        ),
                        tx: tx,
                    )
                    Logger.info("Created quoted-reply attachment \(attachmentID) for received sent-transcript \(transcript.timestamp)")
                }

                if let linkPreviewImageProto = messageParams.validatedLinkPreview?.imageProto {
                    let attachmentID = try attachmentManager.createAttachmentPointer(
                        from: OwnedAttachmentPointerProto(
                            proto: linkPreviewImageProto,
                            owner: .messageLinkPreview(.init(
                                messageRowId: outgoingMessage.sqliteRowId!,
                                receivedAtTimestamp: outgoingMessage.receivedAtTimestamp,
                                threadRowId: threadRowId,
                                isPastEditRevision: outgoingMessage.isPastEditRevision(),
                            )),
                        ),
                        tx: tx,
                    )
                    Logger.info("Created link preview attachment \(attachmentID) for received sent-transcript \(transcript.timestamp)")
                }

                if let validatedMessageSticker = messageParams.validatedMessageSticker {
                    let attachmentID = try attachmentManager.createAttachmentPointer(
                        from: OwnedAttachmentPointerProto(
                            proto: validatedMessageSticker.proto,
                            owner: .messageSticker(.init(
                                messageRowId: outgoingMessage.sqliteRowId!,
                                receivedAtTimestamp: outgoingMessage.receivedAtTimestamp,
                                threadRowId: threadRowId,
                                isPastEditRevision: outgoingMessage.isPastEditRevision(),
                                stickerPackId: validatedMessageSticker.sticker.packId,
                                stickerId: validatedMessageSticker.sticker.stickerId,
                            )),
                        ),
                        tx: tx,
                    )
                    Logger.info("Created sticker attachment \(attachmentID) for received sent-transcript \(transcript.timestamp)")
                }

                if let contactAvatarProto = messageParams.validatedContactShare?.avatarProto {
                    let attachmentID = try attachmentManager.createAttachmentPointer(
                        from: OwnedAttachmentPointerProto(
                            proto: contactAvatarProto,
                            owner: .messageContactAvatar(.init(
                                messageRowId: outgoingMessage.sqliteRowId!,
                                receivedAtTimestamp: outgoingMessage.receivedAtTimestamp,
                                threadRowId: threadRowId,
                                isPastEditRevision: outgoingMessage.isPastEditRevision(),
                            )),
                        ),
                        tx: tx,
                    )
                    Logger.info("Created contact avatar attachment \(attachmentID) for received sent-transcript \(transcript.timestamp)")
                }
            } catch {
                Logger.error("Attachment failure: \(error)")
                // Roll back the message
                interactionDeleteManager.delete(outgoingMessage, sideEffects: .default(), tx: tx)
                return .failure(error)
            }
        }
        owsAssertDebug(interactionStore.insertedMessageHasRenderableContent(
            message: outgoingMessage,
            rowId: outgoingMessage.sqliteRowId!,
            tx: tx,
        ))

        let recipientStates: [SignalServiceAddress: TSOutgoingMessageRecipientState] = {
            switch messageParams.target {
            case .contact(let contactThread, _) where localIdentifiers.contains(address: contactThread.contactAddress):
                // If this is a sent transcript that went to our Note to Self,
                // we should force it as read.
                return [
                    localIdentifiers.aciAddress: TSOutgoingMessageRecipientState(status: .read),
                ]
            case .contact, .group:
                return transcript.recipientStates
            }
        }()
        interactionStore.updateRecipientsFromNonLocalDevice(
            outgoingMessage,
            recipientStates: recipientStates,
            isSentUpdate: false,
            tx: tx,
        )

        if let expirationStartedAt = messageParams.expirationStartedAt {
            /// The insert and update methods above may start expiration for
            /// this message, but transcript.expirationStartedAt may be earlier,
            /// so we need to pass that to DisappearingMessagesExpirationJob in
            /// case it needs to back-date the expiration.
            disappearingMessagesExpirationJob.startExpiration(
                forMessage: outgoingMessage,
                expirationStartedAt: expirationStartedAt,
                tx: tx,
            )
        }

        self.earlyMessageManager.applyPendingMessages(for: outgoingMessage, registeredState: registeredState, tx: tx)

        if outgoingMessage.isViewOnceMessage {
            // Don't download attachments for "view-once" messages from linked devices.
            // To be extra-conservative, always mark as complete immediately.
            viewOnceMessages.markAsComplete(message: outgoingMessage, sendSyncMessages: false, tx: tx)
        } else {
            attachmentDownloads.enqueueDownloadOfAttachmentsForMessage(outgoingMessage, tx: tx)
        }

        return .success(outgoingMessage)
    }

    /// Rejects transcripts that require a newer protocol version than this
    /// client supports (`SSKProtos.currentProtocolVersion`).
    ///
    /// On rejection, inserts an `OWSUnknownProtocolVersionMessage` into
    /// `thread` so that the user can see something was dropped.
    ///
    /// - Returns: `true` if the transcript may be processed, `false` otherwise.
    private func validateProtocolVersion(
        for transcript: SentMessageTranscript,
        thread: TSThread,
        tx: DBWriteTransaction,
    ) -> Bool {
        if
            let requiredProtocolVersion = transcript.requiredProtocolVersion,
            requiredProtocolVersion > SSKProtos.currentProtocolVersion
        {
            owsFailDebug("Unknown protocol version: \(requiredProtocolVersion)")

            let message = OWSUnknownProtocolVersionMessage(
                thread: thread,
                timestamp: MessageTimestampGenerator.sharedInstance.generateTimestamp(),
                sender: nil,
                protocolVersion: UInt(requiredProtocolVersion),
            )
            interactionStore.insertInteraction(message, tx: tx)
            return false
        }
        return true
    }

    /// Applies the transcript's disappearing-message token to a contact thread,
    /// attributing the change to the local ACI. This is a no-op for group targets.
    private func updateDisappearingMessageTokenIfNecessary(
        target: SentMessageTranscriptTarget,
        localIdentifiers: LocalIdentifiers,
        tx: DBWriteTransaction,
    ) {
        switch target {
        case .group:
            return
        case .contact(let thread, let disappearingMessageToken):
            groupManager.remoteUpdateDisappearingMessages(
                withContactThread: thread,
                disappearingMessageToken: disappearingMessageToken,
                changeAuthor: localIdentifiers.aci,
                localIdentifiers: localIdentifiers,
                tx: tx,
            )
        }
    }

    // MARK: -

    /// Applies a "recipient update" transcript to existing outgoing messages in
    /// `groupThread` that have the transcript's timestamp.
    ///
    /// Only messages marked `wasNotCreatedLocally` are updated, and no threads or
    /// messages are ever created.
    ///
    /// - Returns: The last matching message updated, `.success(nil)` if none
    ///   matched, or `.failure` for empty recipient states, an invalid
    ///   timestamp or group id, or a fetch error.
    private func processRecipientUpdate(
        _ transcript: SentMessageTranscript,
        groupThread: TSGroupThread,
        tx: DBWriteTransaction,
    ) -> Result<TSOutgoingMessage?, Error> {

        if transcript.recipientStates.isEmpty {
            return .failure(OWSAssertionError("Ignoring empty 'recipient update' transcript."))
        }

        let timestamp = transcript.timestamp
        if timestamp < 1 {
            return .failure(OWSAssertionError("'recipient update' transcript has invalid timestamp."))
        }
        if !SDS.fitsInInt64(timestamp) {
            return .failure(OWSAssertionError("Invalid timestamp."))
        }

        let groupId = groupThread.groupId
        if groupId.isEmpty {
            return .failure(OWSAssertionError("'recipient update' transcript has invalid groupId."))
        }

        let messages: [TSOutgoingMessage]
        do {
            messages = try interactionStore
                .fetchInteractions(timestamp: timestamp, tx: tx)
                .compactMap { $0 as? TSOutgoingMessage }
        } catch {
            return .failure(OWSAssertionError("Error loading interactions: \(error)"))
        }

        if messages.isEmpty {
            // This message may have disappeared.
            Logger.error("No matching message with timestamp: \(timestamp)")
            return .success(nil)
        }

        var messageFound: TSOutgoingMessage?
        for message in messages {
            guard message.wasNotCreatedLocally else {
                // wasNotCreatedLocally isn't always set for very old linked messages, but:
                //
                // a) We should never receive a "sent update" for a very old message.
                // b) It's safe to discard suspicious "sent updates."
                continue
            }
            guard message.uniqueThreadId == groupThread.uniqueId else {
                continue
            }

            Logger.info("Processing 'recipient update' transcript in thread: \(groupThread.logString), timestamp: \(timestamp), recipientIds: \(transcript.recipientStates.keys)")

            interactionStore.updateRecipientsFromNonLocalDevice(
                message,
                recipientStates: transcript.recipientStates,
                isSentUpdate: true,
                tx: tx,
            )

            // In theory more than one message could be found.
            // In practice, this should never happen, as we functionally
            // use timestamps as unique identifiers.
            messageFound = message
        }

        if messageFound == nil {
            // This message may have disappeared.
            Logger.error("No matching message with timestamp: \(timestamp)")
        }

        return .success(messageFound)
    }

    /// Archives all sessions with `address` in the ACI protocol store.
    private func archiveSessions(for address: SignalServiceAddress, tx: DBWriteTransaction) {
        self.signalProtocolStoreManager.signalProtocolStore(for: .aci).sessionStore.archiveSessions(forAddress: address, tx: tx)
    }
}