Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Storage/Database/Records/ThreadFinder.swift
1 view
//
// Copyright 2019 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import GRDB

/// Read-only queries over the `TSThread` table (joined with
/// `ThreadAssociatedData`, `TSInteraction`, or `StoryMessage` where needed)
/// used to fetch, count, and enumerate conversation threads.
public class ThreadFinder {
    public init() {}

    /// Builds an optional `OR <thread uniqueId> IN (...)` SQL fragment that
    /// forces the given threads into a result set regardless of the other
    /// predicates it is OR-ed with.
    ///
    /// The IDs are bound as statement arguments rather than spliced into the
    /// SQL text, so an ID containing a quote character cannot break or alter
    /// the query.
    ///
    /// - Parameter threadIds: Thread unique IDs that must be included.
    /// - Returns: An empty fragment with no arguments when `threadIds` is
    ///   empty; otherwise the fragment plus its arguments, in matching order.
    private func requiredVisibleThreadsClause(
        forThreadIds threadIds: Set<String>,
    ) -> (sql: String, arguments: StatementArguments) {
        guard !threadIds.isEmpty else {
            return ("", StatementArguments())
        }
        // Fix one iteration order so the placeholders and the bound values line up.
        let orderedThreadIds = Array(threadIds)
        let placeholders = sqlPlaceholders(count: orderedThreadIds.count)
        let sql = "OR \(threadColumnFullyQualified: .uniqueId) IN (\(placeholders))"
        return (sql, StatementArguments(orderedThreadIds))
    }

    /// Fetch a thread with the given SQLite row ID, if one exists.
    ///
    /// A missing row triggers a debug failure, because callers are expected to
    /// pass a row ID that was obtained from an existing thread.
    public func fetch(rowId: Int64, tx: DBReadTransaction) -> TSThread? {
        guard
            let thread = TSThread.anyFetch(
                sql: """
                    SELECT *
                    FROM \(TSThread.databaseTableName)
                    WHERE \(threadColumn: .id) = ?
                """,
                arguments: [rowId],
                transaction: tx,
            )
        else {
            owsFailDebug("Missing thread with row ID - how did we get this row ID?")
            return nil
        }

        return thread
    }

    /// Returns the unique ID of every row in the thread table, in unspecified
    /// order. Database errors are routed through `failIfThrows`.
    public func fetchUniqueIds(tx: DBReadTransaction) -> [String] {
        return failIfThrows {
            do {
                return try String.fetchAll(
                    tx.database,
                    sql: "SELECT \(threadColumn: .uniqueId) FROM \(TSThread.databaseTableName)",
                )
            } catch {
                throw error.grdbErrorForLogging
            }
        }
    }

    /// Enumerates all story threads (distribution lists).
    /// - Parameter block
    /// A block executed for each enumerated thread. Returns `true` if
    /// enumeration should continue, and `false` otherwise.
    public func enumerateStoryThreads(
        transaction: DBReadTransaction,
        block: (TSPrivateStoryThread) throws -> Bool,
    ) throws {
        let sql = """
            SELECT *
            FROM \(TSThread.databaseTableName)
            WHERE \(threadColumn: .recordType) = \(SDSRecordType.privateStoryThread.rawValue)
        """
        let cursor = try TSPrivateStoryThread.fetchCursor(
            transaction.database,
            sql: sql,
        )
        while let storyThread = try cursor.next() {
            guard try block(storyThread) else {
                break
            }
        }
    }

    /// Enumerates group threads in "last interaction" order (newest first).
    /// Rows that decode to something other than a `TSGroupThread` are skipped
    /// with a debug failure.
    /// - Parameter block
    /// A block executed for each enumerated thread. Returns `true` if
    /// enumeration should continue, and `false` otherwise.
    public func enumerateGroupThreads(
        transaction: DBReadTransaction,
        block: (TSGroupThread) throws -> Bool,
    ) throws {
        let sql = """
            SELECT *
            FROM \(TSThread.databaseTableName)
            WHERE \(groupThreadColumn: .groupModel) IS NOT NULL
            ORDER BY \(threadColumn: .lastInteractionRowId) DESC
        """

        let cursor = try TSThread.fetchCursor(
            transaction.database,
            sql: sql,
        )
        while let threadRecord = try cursor.next() {
            guard let groupThread = threadRecord as? TSGroupThread else {
                owsFailDebug("Skipping thread that's not a group.")
                continue
            }
            guard try block(groupThread) else {
                break
            }
        }
    }

    /// Enumerates all non-story threads in arbitrary order.
    /// - Parameter block
    /// A block executed for each enumerated thread. Returns `true` if
    /// enumeration should continue, and `false` otherwise.
    public func enumerateNonStoryThreads(
        transaction: DBReadTransaction,
        block: (TSThread) throws -> Bool,
    ) throws {
        let sql = """
            SELECT *
            FROM \(TSThread.databaseTableName)
            WHERE \(threadColumn: .recordType) IS NOT ?
        """

        let cursor = try TSThread.fetchCursor(
            transaction.database,
            sql: sql,
            arguments: [SDSRecordType.privateStoryThread.rawValue],
        )
        while let thread = try cursor.next(), try block(thread) {}
    }

    /// Counts threads that are marked visible and whose associated data has
    /// the given archive state. Returns 0 (after a debug failure) if the count
    /// query unexpectedly yields no row.
    public func visibleThreadCount(
        isArchived: Bool,
        transaction: DBReadTransaction,
    ) throws -> UInt {
        let sql = """
        SELECT COUNT(*)
        FROM \(TSThread.databaseTableName)
        \(threadAssociatedDataJoinClause(isArchived: isArchived))
        WHERE \(threadColumn: .shouldThreadBeVisible) = 1
        """

        guard
            let count = try UInt.fetchOne(
                transaction.database,
                sql: sql,
            )
        else {
            owsFailDebug("count was unexpectedly nil")
            return 0
        }

        return count
    }

    /// Enumerates visible threads with the given archive state, newest
    /// `lastInteractionRowId` first. Database errors are routed through
    /// `failIfThrows`.
    public func enumerateVisibleThreads(
        isArchived: Bool,
        transaction: DBReadTransaction,
        block: (TSThread) -> Void,
    ) {
        let sql = """
        SELECT *
        FROM \(TSThread.databaseTableName)
        \(threadAssociatedDataJoinClause(isArchived: isArchived))
        WHERE \(threadColumn: .shouldThreadBeVisible) = 1
        ORDER BY \(threadColumn: .lastInteractionRowId) DESC
        """

        failIfThrows {
            try TSThread.fetchCursor(
                transaction.database,
                sql: sql,
            ).forEach { thread in
                block(thread)
            }
        }
    }

    /// Returns the row IDs of all visible threads, regardless of archive
    /// state, newest `lastInteractionRowId` first.
    public func fetchContactSyncThreadRowIds(tx: DBReadTransaction) throws -> [Int64] {
        let sql = """
        SELECT \(threadColumn: .id)
        FROM \(TSThread.databaseTableName)
        WHERE \(threadColumn: .shouldThreadBeVisible) = 1
        ORDER BY \(threadColumn: .lastInteractionRowId) DESC
        """
        do {
            return try Int64.fetchAll(tx.database, sql: sql)
        } catch {
            throw error.grdbErrorForLogging
        }
    }

    /// Whether the conversation UI should present a message request for
    /// `thread` instead of treating it as an accepted conversation.
    public func hasPendingMessageRequest(
        thread: TSThread,
        transaction: DBReadTransaction,
    ) -> Bool {
        // TODO: Should we consult isRequestingMember() here?
        if let groupThread = thread as? TSGroupThread, groupThread.isGroupV2Thread, groupThread.groupModel.groupMembership.isLocalUserInvitedMember {
            return true
        }

        // If we're creating the thread, don't show the message request view
        if !thread.shouldThreadBeVisible {
            return false
        }

        // If this thread is blocked AND we're still in the thread, show the message
        // request view regardless of whether we have sent messages.
        if SSKEnvironment.shared.blockingManagerRef.isThreadBlocked(thread, transaction: transaction) {
            return true
        }

        let isGroupThread = thread is TSGroupThread
        let isLocalUserInGroup = (thread as? TSGroupThread)?.groupModel.groupMembership.isLocalUserFullOrInvitedMember == true

        // If this is a group thread and we're not a member, never show the message request.
        if isGroupThread, !isLocalUserInGroup {
            return false
        }

        let interactionFinder = InteractionFinder(threadUniqueId: thread.uniqueId)

        // Hidden recipients have their own message-request rules; defer to the
        // hiding manager for contact threads whose recipient is hidden.
        let recipientDatabaseTable = DependenciesBridge.shared.recipientDatabaseTable
        let recipientHidingManager = DependenciesBridge.shared.recipientHidingManager
        if
            let contactThread = thread as? TSContactThread,
            let signalRecipient = recipientDatabaseTable.fetchRecipient(
                contactThread: contactThread,
                tx: transaction,
            ),
            let hiddenRecipient = recipientHidingManager.fetchHiddenRecipient(
                recipientId: signalRecipient.id,
                tx: transaction,
            )
        {
            return recipientHidingManager.isHiddenRecipientThreadInMessageRequest(
                hiddenRecipient: hiddenRecipient,
                contactThread: contactThread,
                tx: transaction,
            )
        }

        // If the thread is already whitelisted, do nothing. The user has already
        // accepted the request for this thread.
        if SSKEnvironment.shared.profileManagerRef.isThread(inProfileWhitelist: thread, transaction: transaction) {
            return false
        }

        // At this point, we know the thread is not whitelisted. If it is a group
        // thread and someone added us to the group, there will be a group update
        // info message, in which case we want to show a pending message request.
        // If the group thread is otherwise empty, we fall through to the
        // incoming-message check below.
        if isGroupThread, interactionFinder.hasGroupUpdateInfoMessage(transaction: transaction) {
            return true
        }

        // This thread is likely only visible because of system messages like so-and-so
        // is on signal or sync status. Some of the "possibly" incoming messages might
        // actually have been triggered by us, but if we sent one of these then the thread
        // should be in our profile white list and not make it to this check.
        return interactionFinder.possiblyHasIncomingMessages(transaction: transaction)
    }

    /// Whether we should set the default timer for the given contact thread.
    ///
    /// Returns `true` only when the universal timer is enabled, the thread's
    /// own timer is disabled, and the thread has no user-initiated interaction.
    ///
    /// - Note
    /// We never set the default timer for group threads, which are instead set
    /// during group creation.
    public func shouldSetDefaultDisappearingMessageTimer(
        contactThread: TSContactThread,
        transaction tx: DBReadTransaction,
    ) -> Bool {
        let dmConfigurationStore = DependenciesBridge.shared.disappearingMessagesConfigurationStore

        // Make sure the universal timer is enabled.
        guard
            dmConfigurationStore.fetchOrBuildDefault(
                for: .universal,
                tx: tx,
            ).isEnabled
        else {
            return false
        }

        // Make sure the current timer is disabled.
        guard
            !dmConfigurationStore.fetchOrBuildDefault(
                for: .thread(contactThread),
                tx: tx,
            ).isEnabled
        else {
            return false
        }

        // Make sure there has been no user initiated interactions.
        return !InteractionFinder(threadUniqueId: contactThread.uniqueId)
            .hasUserInitiatedInteraction(transaction: tx)
    }

    /// Whether at least one group thread row exists. Database errors are
    /// routed through `failIfThrows`.
    public func existsGroupThread(transaction: DBReadTransaction) -> Bool {
        let sql = """
            SELECT EXISTS(
                SELECT 1
                FROM \(TSThread.databaseTableName)
                WHERE \(threadColumn: .recordType) = ?
                LIMIT 1
            )
        """
        let arguments: StatementArguments = [SDSRecordType.groupThread.rawValue]
        return failIfThrows {
            return try Bool.fetchOne(
                transaction.database,
                sql: sql,
                arguments: arguments,
            ) ?? false
        }
    }

    /// Returns the threads that can be story destinations, newest
    /// `lastSentStoryTimestamp` first.
    ///
    /// Included threads are those whose story view mode is neither `disabled`
    /// nor `default`. When `includeImplicitGroupThreads` is `true`, the result
    /// also includes group threads in `default` mode that currently have story
    /// messages. Group threads for which `isStorySendEnabled` is `false` are
    /// dropped.
    public func storyThreads(
        includeImplicitGroupThreads: Bool,
        transaction: DBReadTransaction,
    ) -> [TSThread] {
        var allowedDefaultThreadIds = [String]()

        if includeImplicitGroupThreads {
            // Prefetch the group thread uniqueIds that currently have stories
            // TODO: We could potentially join on the KVS for groupId -> threadId
            // to further reduce the number of queries required here, but it
            // may be overkill.

            let storyMessageGroupIdsSQL = """
                SELECT DISTINCT \(StoryMessage.columnName(.groupId))
                FROM \(StoryMessage.databaseTableName)
                WHERE \(StoryMessage.columnName(.groupId)) IS NOT NULL
            """

            do {
                let groupIdCursor = try Data.fetchCursor(
                    transaction.database,
                    sql: storyMessageGroupIdsSQL,
                )

                while let groupId = try groupIdCursor.next() {
                    allowedDefaultThreadIds.append(TSGroupThread.threadId(
                        forGroupId: groupId,
                        transaction: transaction,
                    ))
                }
            } catch {
                owsFailDebug("Failed to query group thread ids \(error)")
            }
        }

        // The thread IDs are bound as arguments. Writing them into the SQL as
        // "double-quoted" tokens would rely on SQLite's legacy fallback from
        // identifier to string literal, which may be disabled. An empty list
        // yields `IN ()`, which SQLite accepts and which matches nothing.
        let allowedDefaultThreadIdPlaceholders = sqlPlaceholders(count: allowedDefaultThreadIds.count)
        let sql = """
            SELECT *
            FROM \(TSThread.databaseTableName)
            WHERE (
                \(threadColumn: .storyViewMode) != \(TSThreadStoryViewMode.disabled.rawValue)
                AND \(threadColumn: .storyViewMode) != \(TSThreadStoryViewMode.default.rawValue)
            )
            OR (
                \(threadColumn: .storyViewMode) = \(TSThreadStoryViewMode.default.rawValue)
                AND \(threadColumn: .recordType) = \(SDSRecordType.groupThread.rawValue)
                AND \(threadColumn: .uniqueId) IN (\(allowedDefaultThreadIdPlaceholders))
            )
            ORDER BY \(threadColumn: .lastSentStoryTimestamp) DESC
        """

        var threads = [TSThread]()
        TSThread.anyEnumerate(
            transaction: transaction,
            sql: sql,
            arguments: StatementArguments(allowedDefaultThreadIds),
            block: { thread, _ in
                if let groupThread = thread as? TSGroupThread {
                    guard groupThread.isStorySendEnabled(transaction: transaction) else {
                        return
                    }
                }
                threads.append(thread)
            },
        )
        return threads
    }

    /// Returns up to `limit` threads, newest `lastInteractionRowId` first.
    /// No visibility, archive, or record-type filter is applied.
    public func threadsWithRecentInteractions(
        limit: UInt,
        transaction: DBReadTransaction,
    ) -> [TSThread] {
        let sql = """
            SELECT *
            FROM \(TSThread.databaseTableName)
            ORDER BY \(threadColumn: .lastInteractionRowId) DESC
            LIMIT \(limit)
        """

        var threads = [TSThread]()
        TSThread.anyEnumerate(
            transaction: transaction,
            sql: sql,
            arguments: [],
            block: { thread, _ in
                threads.append(thread)
            },
        )
        return threads
    }

    // MARK: - SQL helpers

    /// `INNER JOIN` onto `ThreadAssociatedData`, keeping only threads whose
    /// associated data has the given archive state.
    private func threadAssociatedDataJoinClause(isArchived: Bool) -> String {
        """
        INNER JOIN \(ThreadAssociatedData.databaseTableName)
            ON \(ThreadAssociatedData.databaseTableName).threadUniqueId = \(threadColumnFullyQualified: .uniqueId)
            AND \(ThreadAssociatedData.databaseTableName).isArchived = \(isArchived ? "1" : "0")
        """
    }

    /// `ORDER BY` clause shared by the inbox and archive ID queries.
    ///
    /// Threads sort descending by their draft row ID when it is greater than
    /// their last interaction row ID, and by the last interaction row ID
    /// otherwise (including when the draft row ID is NULL). Ties are broken by
    /// the draft update timestamp, also descending.
    private func inboxOrderByClause() -> String {
        return """
        ORDER BY
            CASE WHEN \(threadColumn: .lastDraftInteractionRowId) > \(threadColumn: .lastInteractionRowId)
                THEN \(threadColumn: .lastDraftInteractionRowId) ELSE \(threadColumn: .lastInteractionRowId)
            END DESC,
            \(threadColumn: .lastDraftUpdateTimestamp) DESC
        """
    }

    /// Returns `count` comma-separated `?` placeholders (e.g. `"?, ?, ?"`), for
    /// use inside an `IN (...)` list whose values are bound as arguments.
    /// `count` must be non-negative.
    private func sqlPlaceholders(count: Int) -> String {
        return Array(repeating: "?", count: count).joined(separator: ", ")
    }

    // MARK: -

    /// Returns the unique IDs of visible, non-archived threads in inbox order.
    ///
    /// - Parameters:
    ///   - inboxFilter: When `.unread`, only threads that are marked unread or
    ///     have unread interactions are returned, plus any thread listed in
    ///     `requiredVisibleThreadIds`.
    ///   - requiredVisibleThreadIds: Thread IDs to keep in the `.unread`
    ///     result even if they no longer match the filter. Ignored for other
    ///     filters.
    public func visibleInboxThreadIds(
        filteredBy inboxFilter: InboxFilter? = nil,
        requiredVisibleThreadIds: Set<String> = [],
        transaction: DBReadTransaction,
    ) throws -> [String] {
        if inboxFilter == .unread {
            let requiredVisibleThreads = requiredVisibleThreadsClause(forThreadIds: requiredVisibleThreadIds)
            let sql = """
            SELECT
                \(threadColumnFullyQualified: .uniqueId) AS thread_uniqueId,
                \(ThreadAssociatedData.databaseTableName).isMarkedUnread AS thread_isMarkedUnread,
                COUNT(i.\(interactionColumn: .uniqueId)) AS interactions_unreadCount
            FROM \(TSThread.databaseTableName)
            \(threadAssociatedDataJoinClause(isArchived: false))
            LEFT OUTER JOIN \(InteractionRecord.databaseTableName) AS i
                \(DEBUG_INDEXED_BY("index_model_TSInteraction_UnreadMessages"))
                ON i.\(interactionColumn: .threadUniqueId) = thread_uniqueId
                AND \(InteractionFinder.sqlClauseForUnreadInteractionCounts(interactionsAlias: "i"))
            WHERE \(threadColumnFullyQualified: .shouldThreadBeVisible) = 1
            GROUP BY thread_uniqueId
            HAVING (
                thread_isMarkedUnread = 1
                OR interactions_unreadCount > 0
                \(requiredVisibleThreads.sql)
            )
            \(inboxOrderByClause())
            """

            // Only the first result column (the thread unique ID) is returned.
            return try String.fetchAll(
                transaction.database,
                sql: sql,
                arguments: requiredVisibleThreads.arguments,
                adapter: RangeRowAdapter(0..<1),
            )
        } else {
            let sql = """
            SELECT \(threadColumn: .uniqueId)
            FROM \(TSThread.databaseTableName)
            \(threadAssociatedDataJoinClause(isArchived: false))
            WHERE \(threadColumn: .shouldThreadBeVisible) = 1
            \(inboxOrderByClause())
            """
            return try String.fetchAll(transaction.database, sql: sql)
        }
    }

    /// Returns the unique IDs of visible, archived threads in inbox order.
    public func visibleArchivedThreadIds(
        transaction: DBReadTransaction,
    ) throws -> [String] {
        let sql = """
        SELECT \(threadColumn: .uniqueId)
        FROM \(TSThread.databaseTableName)
        \(threadAssociatedDataJoinClause(isArchived: true))
        WHERE \(threadColumn: .shouldThreadBeVisible) = 1
        \(inboxOrderByClause())
        """

        return try String.fetchAll(transaction.database, sql: sql)
    }
}