Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/Signal/test/Backups/BackupAttachmentTrackerTest.swift
1 views
//
// Copyright 2025 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import SignalServiceKit
import Testing

@testable import Signal

/// Test harness that checks one or more `AsyncStream`s of `Update` values
/// against a single, ordered list of expected updates.
///
/// Every stream is expected to emit the same sequence of values. Each time all
/// streams have advanced to the same expected update, that update's
/// `nextSteps` closure is run.
class BackupAttachmentTrackerTest<Update: Equatable> {
    /// One step of the expected update sequence.
    struct ExpectedUpdate {
        /// The value each stream is expected to emit at this step.
        let update: Update
        /// Work to run once every consumer has advanced to this step.
        let nextSteps: () async -> Void
    }

    /// Tracks, per consumer, how far through `expectedUpdates` that consumer
    /// has advanced, and runs `nextSteps` when all consumers line up.
    actor ExpectedUpdateCompletionTracker {
        private let expectedUpdates: [ExpectedUpdate]
        /// Index of the last expected update handed to each consumer.
        /// A value of -1 means the consumer has not received any update yet.
        private var completedExpectedUpdateIndexes: [UUID: Int]

        init(_ expectedUpdates: [ExpectedUpdate]) {
            self.expectedUpdates = expectedUpdates
            self.completedExpectedUpdateIndexes = [:]
        }

        /// Registers a new consumer, positioned before the first expected update.
        ///
        /// - Returns: The identifier to pass to the other per-consumer methods.
        func addNewExpectedUpdateConsumer() -> UUID {
            let uuid = UUID()
            completedExpectedUpdateIndexes[uuid] = -1
            return uuid
        }

        /// Reports whether `consumer` can still be advanced without running
        /// past the end of the expected updates.
        ///
        /// - Returns: `false` for unknown consumers and for consumers that
        ///   have already been handed the final expected update.
        func hasRemainingExpectedUpdates(forConsumer consumer: UUID) -> Bool {
            guard let completedIndex = completedExpectedUpdateIndexes[consumer] else {
                return false
            }
            return completedIndex + 1 < expectedUpdates.count
        }

        /// Advances `consumer` to its next expected update and returns it.
        ///
        /// If, after advancing, every consumer is at the same index, that
        /// update's `nextSteps` closure is awaited before returning.
        ///
        /// - Precondition: `consumer` was returned by
        ///   `addNewExpectedUpdateConsumer()` and has expected updates
        ///   remaining (see `hasRemainingExpectedUpdates(forConsumer:)`).
        func incrementExpectedUpdate(forConsumer consumer: UUID) async -> ExpectedUpdate {
            guard let completedIndex = completedExpectedUpdateIndexes[consumer] else {
                preconditionFailure("Unknown consumer \(consumer); register it with addNewExpectedUpdateConsumer() first.")
            }

            let nextIndex = completedIndex + 1
            precondition(
                expectedUpdates.indices.contains(nextIndex),
                "Consumer \(consumer) has no expected updates remaining."
            )
            completedExpectedUpdateIndexes[consumer] = nextIndex

            // Only the consumer that brings everyone level triggers `nextSteps`,
            // so each step's follow-up work runs once all streams have caught up.
            if case .yes(let sharedIndex) = completedExpectedUpdateIndexes.values.areAllEqual() {
                await expectedUpdates[sharedIndex].nextSteps()
            }

            return expectedUpdates[nextIndex]
        }

        /// Reports whether every consumer has been handed the final expected
        /// update. Always `false` while no consumer is registered.
        func areAllExpectedUpdatesComplete() -> Bool {
            switch completedExpectedUpdateIndexes.values.areAllEqual() {
            case .no:
                return false
            case .yes(let completedIndex):
                return completedIndex == expectedUpdates.count - 1
            }
        }
    }

    /// Convenience for verifying a single stream; see
    /// `runTest(updateStreams:expectedUpdates:)`.
    func runTest(
        updateStream: AsyncStream<Update>,
        expectedUpdates: [ExpectedUpdate],
    ) async {
        await runTest(updateStreams: [updateStream], expectedUpdates: expectedUpdates)
    }

    /// Consumes every stream concurrently, expecting each to emit exactly the
    /// values in `expectedUpdates`, in order.
    ///
    /// Returns once all streams have reached the final expected update, at
    /// which point the consuming tasks are cancelled. Records a test issue if
    /// a value does not match, if a stream emits more values than expected, or
    /// if a stream finishes on its own before being cancelled.
    ///
    /// - Parameters:
    ///   - updateStreams: The streams to verify.
    ///   - expectedUpdates: The ordered values every stream must emit.
    func runTest(
        updateStreams: [AsyncStream<Update>],
        expectedUpdates: [ExpectedUpdate],
    ) async {
        guard !updateStreams.isEmpty else {
            // With no consumers registered the completion check below can
            // never succeed, so polling would spin forever.
            if !expectedUpdates.isEmpty {
                Issue.record("No update streams were provided, but \(expectedUpdates.count) updates were expected!")
            }
            return
        }

        let expectedUpdateCompletionTracker = ExpectedUpdateCompletionTracker(expectedUpdates)

        // Register every consumer before any stream is read, so that no stream
        // can be considered "level with all consumers" prematurely.
        var streamIds: [UUID] = []
        for _ in updateStreams {
            streamIds.append(await expectedUpdateCompletionTracker.addNewExpectedUpdateConsumer())
        }

        let streamTasks: [Task<Void, Never>] = zip(updateStreams, streamIds).map { updateStream, id in
            Task {
                for await trackedUpdate in updateStream {
                    // Fail the test, rather than trap on an out-of-range
                    // index, if the stream emits more than was expected.
                    guard await expectedUpdateCompletionTracker.hasRemainingExpectedUpdates(forConsumer: id) else {
                        Issue.record("Received an update beyond the expected sequence: \(trackedUpdate)")
                        return
                    }

                    let nextExpectedUpdate = await expectedUpdateCompletionTracker
                        .incrementExpectedUpdate(forConsumer: id)

                    #expect(trackedUpdate == nextExpectedUpdate.update)
                }

                if Task.isCancelled {
                    return
                }

                Issue.record("Finished stream without cancellation!")
            }
        }

        // Poll until every stream has seen the final expected update, then
        // stop the consumers; the streams themselves are not expected to end.
        while await !expectedUpdateCompletionTracker.areAllExpectedUpdatesComplete() {
            await Task.yield()
        }
        streamTasks.forEach { $0.cancel() }

        for streamTask in streamTasks {
            await streamTask.value
        }
    }
}

// MARK: -

private extension Dictionary.Values where Element: Equatable {
    /// Outcome of comparing all values of a dictionary with one another.
    enum AllEqualResult {
        /// Every value compared equal; carries that shared value.
        case yes(Element)
        /// The collection is empty, or some value differs from the first one.
        case no
    }

    /// Determines whether every value in the collection equals the first one.
    ///
    /// - Returns: `.yes` with the shared value when all values are equal,
    ///   `.no` otherwise (including when there are no values at all).
    func areAllEqual() -> AllEqualResult {
        guard let reference = self.first else {
            return .no
        }

        let hasMismatch = contains(where: { candidate in candidate != reference })
        return hasMismatch ? .no : .yes(reference)
    }
}