Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Payments/PaymentFinder.swift
1 views
//
// Copyright 2021 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB

/// Read-only lookups of `TSPaymentModel` rows.
///
/// Every method takes a `DBReadTransaction` and issues a `SELECT` against
/// `TSPaymentModel.databaseTableName`.
public class PaymentFinder {

    /// Returns every payment model whose payment state is one of `paymentStates`.
    ///
    /// - Parameters:
    ///   - paymentStates: The states to match. An empty array yields an empty
    ///     result without touching the database.
    ///   - transaction: The read transaction to query in.
    /// - Returns: All matching models, in whatever order the database yields them.
    public class func paymentModels(
        paymentStates: [TSPaymentState],
        transaction: DBReadTransaction,
    ) -> [TSPaymentModel] {
        // `IN ()` can never match anything, so skip the query entirely.
        guard !paymentStates.isEmpty else {
            return []
        }

        // Bind each state's raw value as a statement argument, like the other
        // finders in this class, rather than splicing values into the SQL text.
        let placeholders = Array(repeating: "?", count: paymentStates.count).joined(separator: ",")
        let arguments = StatementArguments(paymentStates.map { $0.rawValue })

        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .paymentState) IN (\(placeholders))
        """

        return fetchAll(sql: sql, arguments: arguments, tx: transaction)
    }

    /// Returns one payment model that is marked unread, or `nil` if there is none.
    ///
    /// The query has no `ORDER BY`, so when several models are unread, which one
    /// is returned is unspecified.
    ///
    /// - Parameter transaction: The read transaction to query in.
    public class func firstUnreadPaymentModel(transaction: DBReadTransaction) -> TSPaymentModel? {
        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .isUnread) = 1
        LIMIT 1
        """
        return TSPaymentModel.anyFetch(
            sql: sql,
            arguments: [],
            transaction: transaction,
        )
    }

    /// Returns every payment model that is marked unread.
    ///
    /// - Parameter transaction: The read transaction to query in.
    public class func allUnreadPaymentModels(transaction: DBReadTransaction) -> [TSPaymentModel] {
        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .isUnread) = 1
        """
        return fetchAll(sql: sql, arguments: [], tx: transaction)
    }

    /// Returns the number of payment models marked unread.
    ///
    /// The query runs inside `failIfThrows`; presumably a database error is
    /// treated as fatal rather than surfaced to the caller (NOTE(review): confirm).
    ///
    /// - Parameter transaction: The read transaction to query in.
    public class func unreadCount(transaction: DBReadTransaction) -> UInt {
        return failIfThrows {
            // `COUNT(*)` without `GROUP BY` always produces exactly one row,
            // so the force unwrap cannot fail.
            return try UInt.fetchOne(
                transaction.database,
                sql: """
                SELECT COUNT(*)
                FROM \(TSPaymentModel.databaseTableName)
                WHERE \(paymentModelColumn: .isUnread) = 1
                """,
                arguments: [],
            )!
        }
    }

    /// Returns every payment model recorded at the given MobileCoin ledger block index.
    ///
    /// - Parameters:
    ///   - mcLedgerBlockIndex: The block index to match exactly.
    ///   - transaction: The read transaction to query in.
    public class func paymentModels(
        forMcLedgerBlockIndex mcLedgerBlockIndex: UInt64,
        transaction: DBReadTransaction,
    ) -> [TSPaymentModel] {
        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .mcLedgerBlockIndex) = ?
        """
        return fetchAll(sql: sql, arguments: [mcLedgerBlockIndex], tx: transaction)
    }

    /// Returns every payment model whose stored receipt data equals `mcReceiptData`.
    ///
    /// - Parameters:
    ///   - mcReceiptData: The receipt bytes to match exactly.
    ///   - transaction: The read transaction to query in.
    public class func paymentModels(
        forMcReceiptData mcReceiptData: Data,
        transaction: DBReadTransaction,
    ) -> [TSPaymentModel] {
        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .mcReceiptData) = ?
        """
        return fetchAll(sql: sql, arguments: [mcReceiptData], tx: transaction)
    }

    /// Returns every payment model whose stored transaction data equals `mcTransactionData`.
    ///
    /// - Parameters:
    ///   - mcTransactionData: The transaction bytes to match exactly.
    ///   - transaction: The read transaction to query in.
    public class func paymentModels(
        forMcTransactionData mcTransactionData: Data,
        transaction: DBReadTransaction,
    ) -> [TSPaymentModel] {
        let sql = """
        SELECT * FROM \(TSPaymentModel.databaseTableName)
        WHERE \(paymentModelColumn: .mcTransactionData) = ?
        """
        return fetchAll(sql: sql, arguments: [mcTransactionData], tx: transaction)
    }

    /// Runs `sql` with `arguments` and collects every model that
    /// `TSPaymentModel.anyEnumerate` hands to its block, in delivery order.
    private static func fetchAll(sql: String, arguments: StatementArguments, tx: DBReadTransaction) -> [TSPaymentModel] {
        var results = [TSPaymentModel]()
        TSPaymentModel.anyEnumerate(
            transaction: tx,
            sql: sql,
            arguments: arguments,
            block: { paymentModel, _ in
                results.append(paymentModel)
            },
        )
        return results
    }
}