// Path: blob/main/SignalServiceKit/Messages/Attachments/V2/Downloads/AttachmentDownloadManagerImpl.swift
// 1 views
//
// Copyright 2024 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
public class AttachmentDownloadManagerImpl: AttachmentDownloadManager {
/// Outcome of applying a finished download to an attachment, as returned by
/// `AttachmentUpdater.updateAttachmentAsDownloaded` in `downloadRecord`.
private enum DownloadResult {
/// The download produced an `AttachmentStream`.
case stream(AttachmentStream)
/// The download produced an `AttachmentBackupThumbnail`.
case thumbnail(AttachmentBackupThumbnail)
}
// Collaborators of the manager. Some are injected directly through `init`;
// others (decrypter, downloadQueue, attachmentUpdater, downloadabilityChecker,
// progressStates, queueLoader) are constructed inside `init` from the
// injected dependencies.
private let appReadiness: AppReadiness
private let attachmentDownloadStore: AttachmentDownloadStore
private let attachmentStore: AttachmentStore
private let attachmentUpdater: AttachmentUpdater
private let backupSettingsStore: BackupSettingsStore
private let db: any DB
private let decrypter: Decrypter
private let downloadQueue: DownloadQueue
private let downloadabilityChecker: DownloadabilityChecker
private let progressStates: ProgressStates
// Loads persisted download records and runs them through `DownloadTaskRunner`.
private let queueLoader: TaskQueueLoader<DownloadTaskRunner>
private let remoteConfigProvider: any RemoteConfigProvider
private let tsAccountManager: TSAccountManager
/// Builds the manager and its internal helpers from the injected dependencies.
///
/// The initializer stores the dependencies it uses directly, constructs the
/// internal `Decrypter`, `ProgressStates`, `DownloadQueue`, `AttachmentUpdater`,
/// `DownloadabilityChecker`, `DownloadTaskRunner` and `TaskQueueLoader`, and
/// finally schedules a block with `appReadiness` that begins downloading and
/// starts observing `.registrationStateDidChange`.
public init(
accountKeyStore: AccountKeyStore,
appReadiness: AppReadiness,
attachmentDownloadStore: AttachmentDownloadStore,
attachmentStore: AttachmentStore,
attachmentUploadStore: AttachmentUploadStore,
attachmentValidator: AttachmentContentValidator,
backupAttachmentUploadScheduler: BackupAttachmentUploadScheduler,
backupRequestManager: BackupRequestManager,
backupSettingsStore: BackupSettingsStore,
currentCallProvider: CurrentCallProvider,
dateProvider: @escaping DateProvider,
db: any DB,
interactionStore: InteractionStore,
mediaBandwidthPreferenceStore: MediaBandwidthPreferenceStore,
orphanedAttachmentCleaner: OrphanedAttachmentCleaner,
orphanedAttachmentStore: OrphanedAttachmentStore,
orphanedBackupAttachmentScheduler: OrphanedBackupAttachmentScheduler,
profileManager: ProfileManager,
reachabilityManager: SSKReachabilityManager,
remoteConfigProvider: any RemoteConfigProvider,
signalService: OWSSignalServiceProtocol,
stickerManager: Shims.StickerManager,
storyStore: any StoryStore,
threadStore: ThreadStore,
tsAccountManager: TSAccountManager,
) {
self.attachmentDownloadStore = attachmentDownloadStore
self.attachmentStore = attachmentStore
self.appReadiness = appReadiness
self.backupSettingsStore = backupSettingsStore
self.db = db
self.decrypter = Decrypter(
attachmentValidator: attachmentValidator,
stickerManager: stickerManager,
)
// The download queue shares the same ProgressStates instance that
// `cancelDownload` later uses to mark downloads as cancelled.
self.progressStates = ProgressStates()
self.downloadQueue = DownloadQueue(
progressStates: progressStates,
signalService: signalService,
)
self.attachmentUpdater = AttachmentUpdater(
attachmentStore: attachmentStore,
backupAttachmentUploadScheduler: backupAttachmentUploadScheduler,
db: db,
decrypter: decrypter,
interactionStore: interactionStore,
orphanedAttachmentCleaner: orphanedAttachmentCleaner,
orphanedAttachmentStore: orphanedAttachmentStore,
orphanedBackupAttachmentScheduler: orphanedBackupAttachmentScheduler,
storyStore: storyStore,
threadStore: threadStore,
)
self.downloadabilityChecker = DownloadabilityChecker(
attachmentStore: attachmentStore,
backupSettingsStore: backupSettingsStore,
currentCallProvider: currentCallProvider,
db: db,
mediaBandwidthPreferenceStore: mediaBandwidthPreferenceStore,
profileManager: profileManager,
reachabilityManager: reachabilityManager,
storyStore: storyStore,
threadStore: threadStore,
)
// The task runner reuses the helpers built above (updater, decrypter,
// download queue, downloadability checker) rather than creating its own.
let taskRunner = DownloadTaskRunner(
accountKeyStore: accountKeyStore,
attachmentDownloadStore: attachmentDownloadStore,
attachmentStore: attachmentStore,
attachmentUpdater: attachmentUpdater,
attachmentUploadStore: attachmentUploadStore,
backupRequestManager: backupRequestManager,
dateProvider: dateProvider,
db: db,
decrypter: decrypter,
downloadQueue: downloadQueue,
downloadabilityChecker: downloadabilityChecker,
remoteConfigProvider: remoteConfigProvider,
stickerManager: stickerManager,
tsAccountManager: tsAccountManager,
)
self.queueLoader = TaskQueueLoader(
maxConcurrentTasks: 12,
dateProvider: dateProvider,
db: db,
runner: taskRunner,
)
self.remoteConfigProvider = remoteConfigProvider
self.tsAccountManager = tsAccountManager
// All stored properties are set by this point, so `self` may be captured.
// The capture is weak; the block is a no-op if the manager is gone.
appReadiness.runNowOrWhenMainAppDidBecomeReadyAsync { [weak self] in
guard let self else { return }
self.beginDownloadingIfNecessary()
NotificationCenter.default.addObserver(
self,
selector: #selector(registrationStateDidChange),
name: .registrationStateDidChange,
object: nil,
)
}
}
/// Handler for `.registrationStateDidChange`: once the account reports as
/// registered, try to start processing queued downloads.
@objc
private func registrationStateDidChange() {
    AssertIsOnMainThread()
    let registrationState = tsAccountManager.registrationStateWithMaybeSneakyTransaction
    if registrationState.isRegistered {
        beginDownloadingIfNecessary()
    }
}
/// Byte limit passed as `maxDownloadSizeBytes` for encrypted backup downloads
/// and backup prefix requests.
private func maxEncryptedBackupDownloadSize() -> UInt64 {
    let byteLimit: UInt64 = 1_000_000_000
    return byteLimit
}
/// Downloads a backup file through the download queue.
///
/// - Parameters:
///   - metadata: Credential describing the backup to read.
///   - progress: Optional sink forwarded to the download queue.
/// - Returns: The file URL returned by the download queue.
/// - Throws: Whatever `DownloadQueue.enqueueDownload` throws.
public func downloadBackup(
    metadata: BackupReadCredential,
    progress: OWSProgressSink?,
) async throws -> URL {
    // Each call gets a fresh UUID as part of its download state.
    let backupState = DownloadState(type: .backup(metadata: metadata, uuid: UUID()))
    let downloadedFileUrl = try await downloadQueue.enqueueDownload(
        downloadState: backupState,
        maxDownloadSizeBytes: maxEncryptedBackupDownloadSize(),
        expectedDownloadSize: .useHeadRequest,
        progress: progress,
    )
    return downloadedFileUrl
}
/// Fetches CDN file info plus the parsed metadata header for a backup.
///
/// A prefix of the file is requested and parsed. If the parser reports that
/// it needs more bytes, the request is repeated with the larger length.
///
/// - Parameter metadata: Credential describing the backup to read.
/// - Returns: The CDN file info together with the parsed metadata header.
/// - Throws: Errors from the prefix request, or an `OWSAssertionError` when the
///   prefix cannot be parsed into a header.
public func backupCdnInfo(
    metadata: BackupReadCredential,
) async throws -> BackupCdnInfo {
    let backupState = DownloadState(type: .backup(metadata: metadata, uuid: UUID()))
    var requestedLength = BackupNonce.metadataHeaderByteLengthUpperBound
    while true {
        let (fileInfo, prefixBytes) = try await downloadQueue.performPrefixRequest(
            downloadState: backupState,
            maxDownloadSizeBytes: maxEncryptedBackupDownloadSize(),
            length: requestedLength,
        )
        do throws(BackupNonce.MetadataHeader.ParsingError) {
            let header = try BackupNonce.MetadataHeader.from(prefixBytes: prefixBytes)
            return BackupCdnInfo(
                fileInfo: fileInfo,
                metadataHeader: header,
            )
        } catch {
            switch error {
            case .moreDataNeeded(let neededLength):
                guard neededLength > requestedLength else {
                    // We got fewer bytes than we requested
                    throw OWSAssertionError("Backup file too small!")
                }
                // Go around again asking for the larger prefix.
                requestedLength = neededLength
            case .headerTooLarge:
                throw OWSAssertionError("Backup header too large")
            case .dataMissingOrEmpty:
                throw OWSAssertionError("Missing backup file prefix data")
            case .unrecognizedFileSignature:
                throw OWSAssertionError("Unrecognized backup file signature")
            }
        }
    }
}
/// Downloads an attachment that is not tracked in the persisted queue.
///
/// - Parameters:
///   - metadata: Describes where the attachment lives and how to decrypt it.
///   - progress: Optional sink forwarded to the download queue.
/// - Returns: The still-encrypted file URL for `.linkNSyncBackup` sources,
///   otherwise the URL produced by `Decrypter.decryptTransientAttachment`.
/// - Throws: Errors from the download queue or from decryption.
public func downloadTransientAttachment(
    metadata: AttachmentDownloads.DownloadMetadata,
    progress: OWSProgressSink?,
) async throws -> URL {
    // Cap the size so a compromised or buggy service cannot force a huge download.
    let sizeLimit = remoteConfigProvider.currentConfig().attachmentMaxEncryptedReceiveBytes
    let transientState = DownloadState(type: .transientAttachment(metadata, uuid: UUID()))

    // Use the known plaintext length as the size estimate when there is one;
    // otherwise let the queue issue a HEAD request.
    let expectedSize: DownloadQueue.DownloadSizeSource
    if let plaintextLength = metadata.plaintextLength {
        expectedSize = .estimatedSizeBytes(UInt(plaintextLength))
    } else {
        expectedSize = .useHeadRequest
    }

    let encryptedFileUrl = try await downloadQueue.enqueueDownload(
        downloadState: transientState,
        maxDownloadSizeBytes: sizeLimit,
        expectedDownloadSize: expectedSize,
        progress: progress,
    )

    if case .linkNSyncBackup = metadata.source {
        // This source is handed back without being decrypted here.
        return encryptedFileUrl
    }
    return try await decrypter.decryptTransientAttachment(encryptedFileUrl: encryptedFileUrl, metadata: metadata)
}
/// Enqueues downloads for the attachments owned by an inserted message.
///
/// - Parameters:
///   - message: The message whose attachments should be downloaded. It must
///     already have a sqlite row id.
///   - priority: Priority recorded with each enqueued download.
///   - tx: Write transaction used for the lookups and the enqueue.
public func enqueueDownloadOfAttachmentsForMessage(
    _ message: TSMessage,
    priority: AttachmentDownloadPriority,
    tx: DBWriteTransaction,
) {
    guard let messageRowId = message.sqliteRowId else {
        owsFailDebug("Downloading attachments for uninserted message!")
        return
    }
    let ownedAttachments = attachmentStore.fetchReferencedAttachmentsOwnedByMessage(
        messageRowId: messageRowId,
        tx: tx,
    )

    // When the quoted message exists locally we skip its thumbnail: it gets
    // filled in if and only if the original attachment is downloaded.
    let quoteTargetIsLocal = message.quotedMessage.map { $0.bodySource == .local } ?? false
    let attachmentsToEnqueue: [ReferencedAttachment]
    if quoteTargetIsLocal {
        attachmentsToEnqueue = ownedAttachments.filter { candidate in
            if case .message(.quotedReply) = candidate.reference.owner {
                return false
            }
            return true
        }
    } else {
        attachmentsToEnqueue = ownedAttachments
    }

    enqueueDownloadOfAttachments(attachmentsToEnqueue, priority: priority, tx: tx)
}
/// Enqueues downloads for the attachments owned by an inserted story message.
///
/// - Parameters:
///   - message: The story message; it must already have a row id.
///   - priority: Priority recorded with each enqueued download.
///   - tx: Write transaction used for the lookup and the enqueue.
public func enqueueDownloadOfAttachmentsForStoryMessage(
    _ message: StoryMessage,
    priority: AttachmentDownloadPriority,
    tx: DBWriteTransaction,
) {
    if let storyRowId = message.id {
        let ownedAttachments = attachmentStore.fetchReferencedAttachmentsOwnedByStory(
            storyMessageRowId: storyRowId,
            tx: tx,
        )
        enqueueDownloadOfAttachments(ownedAttachments, priority: priority, tx: tx)
    } else {
        owsFailDebug("Downloading attachments for uninserted message!")
    }
}
/// Chooses a download source for each referenced attachment, enqueues the ones
/// the downloadability checker allows, and schedules follow-up work once the
/// transaction completes.
///
/// - Parameters:
///   - referencedAttachments: Attachments (with their references) to consider.
///   - priority: Priority recorded with each enqueued download.
///   - tx: Write transaction used for the checks and the enqueue.
private func enqueueDownloadOfAttachments(
    _ referencedAttachments: [ReferencedAttachment],
    priority: AttachmentDownloadPriority,
    tx: DBWriteTransaction,
) {
    let canUseMediaTier: Bool = switch backupSettingsStore.backupPlan(tx: tx) {
    case .disabled, .disabling:
        false
    case .free, .paid, .paidExpiringSoon, .paidAsTester:
        // Media tier downloads may still be attempted while on the free tier.
        true
    }

    /// Picks which tier to download from for a single attachment.
    func preferredSource(
        for referenced: ReferencedAttachment,
    ) -> QueuedAttachmentDownloadRecord.SourceType {
        // We only download from the latest transit tier info.
        let transitInfo = referenced.attachment.latestTransitTierInfo
        let mediaInfo = referenced.attachment.mediaTierInfo
        guard let transitInfo, let mediaInfo else {
            // With at most one tier available there is nothing to decide.
            return mediaInfo == nil ? .transitTier : .mediaTierFullsize
        }
        if canUseMediaTier, mediaInfo.lastDownloadAttemptTimestamp == nil {
            // Media tier has never been attempted: always try it first.
            return .mediaTierFullsize
        }
        if transitInfo.lastDownloadAttemptTimestamp == nil {
            // Media tier was tried (or is unavailable); try transit tier next.
            return .transitTier
        }
        // Both have been attempted before: fall back to the default.
        return canUseMediaTier ? .mediaTierFullsize : .transitTier
    }

    var enqueuedAtLeastOne = false
    for referenced in referencedAttachments {
        let chosenSource = preferredSource(for: referenced)
        let verdict = downloadabilityChecker.downloadability(
            of: referenced.reference,
            priority: priority,
            source: chosenSource,
            mimeType: referenced.attachment.mimeType,
            tx: tx,
        )
        switch verdict {
        case .blockedByNetworkState:
            Logger.info("Skipping enqueue of download due to network state")
        case .blockedByAutoDownloadSettings:
            Logger.info("Skipping enqueue of download due to auto download settings")
        case .blockedByPendingMessageRequest:
            Logger.info("Skipping enqueue of download due to pending message request")
        case .blockedByActiveCall:
            Logger.info("Skipping enqueue of download during active call")
        case .downloadable:
            enqueuedAtLeastOne = true
            attachmentDownloadStore.enqueueDownloadOfAttachment(
                withId: referenced.reference.attachmentRowId,
                source: chosenSource,
                priority: priority,
                tx: tx,
            )
        }
    }

    guard enqueuedAtLeastOne else {
        return
    }
    // After the transaction completes: touch every owner in a separate write
    // and start processing the queue.
    tx.addSyncCompletion { [weak self] in
        self?.db.asyncWrite { tx in
            for referenced in referencedAttachments {
                self?.attachmentUpdater.touchOwner(referenced.reference.owner, tx: tx)
            }
        }
        self?.beginDownloadingIfNecessary()
    }
}
/// Enqueues a single attachment download and starts queue processing once the
/// transaction completes. Does nothing while running tests.
///
/// - Parameters:
///   - id: Row id of the attachment to download.
///   - priority: Priority recorded with the enqueued download.
///   - source: Tier to download from.
///   - tx: Write transaction used for the enqueue.
public func enqueueDownloadOfAttachment(
    id: Attachment.IDType,
    priority: AttachmentDownloadPriority,
    source: QueuedAttachmentDownloadRecord.SourceType,
    tx: DBWriteTransaction,
) {
    // Tests have no need for real downloads to be enqueued.
    guard !CurrentAppContext().isRunningTests else {
        return
    }
    attachmentDownloadStore.enqueueDownloadOfAttachment(
        withId: id,
        source: source,
        priority: priority,
        tx: tx,
    )
    tx.addSyncCompletion { [weak self] in
        guard let manager = self else { return }
        manager.beginDownloadingIfNecessary()
    }
}
/// Enqueues an attachment download and suspends until it finishes.
/// Does nothing while running tests.
///
/// - Parameters:
///   - id: Row id of the attachment to download.
///   - priority: Priority recorded with the enqueued download.
///   - source: Tier to download from.
///   - progress: Optional sink passed to the download queue's waiter.
/// - Throws: The error reported by the waiter if the download fails.
public func downloadAttachment(
    id: Attachment.IDType,
    priority: AttachmentDownloadPriority,
    source: QueuedAttachmentDownloadRecord.SourceType,
    progress: OWSProgressSink?,
) async throws {
    // Tests have no need for real downloads to be enqueued.
    guard !CurrentAppContext().isRunningTests else {
        return
    }

    let key = DownloadQueue.DownloadKey(id: id, source: source)
    await downloadQueue.clearOldDownloadsAndIncrementProgressID(key: key)

    // Start waiting before the record is enqueued.
    let waiter = Task {
        try await self.downloadQueue.waitForDownloadOfAttachment(
            id: id,
            source: source,
            progress: progress,
        )
    }

    var failure: Error?
    do {
        await db.awaitableWrite { tx in
            self.attachmentDownloadStore.enqueueDownloadOfAttachment(
                withId: id,
                source: source,
                priority: priority,
                tx: tx,
            )
        }
        beginDownloadingIfNecessary()
        try await waiter.value
    } catch {
        Logger.error("Error downloading attachment id \(id) from \(source): \(error)")
        failure = error
    }

    // Progress is cleared on both the success and the failure path.
    await downloadQueue.clearDownloadProgressAndMarkFinished(key: key)
    if let failure {
        throw failure
    }
}
/// Starts processing the persisted download queue if this process is allowed to.
///
/// Does nothing outside the main app or while the account is not registered.
/// The queue loader runs in an unstructured task that nothing awaits, so an
/// error thrown by `loadAndRunTasks()` is logged here instead of being
/// silently discarded with the task.
public func beginDownloadingIfNecessary() {
    guard CurrentAppContext().isMainApp else {
        return
    }
    guard tsAccountManager.registrationStateWithMaybeSneakyTransaction.isRegistered else {
        return
    }
    Task { [weak self] in
        do {
            try await self?.queueLoader.loadAndRunTasks()
        } catch {
            // Without this catch the error would be stored in the task's
            // result and never observed.
            Logger.warn("Failed to load and run attachment download tasks: \(error)")
        }
    }
}
/// Cancels any download of the given attachment.
///
/// Marks the download cancelled in the progress states, removes the attachment
/// from the queue for every source type, and touches all of its owners.
///
/// - Parameters:
///   - attachmentId: Row id of the attachment whose download is cancelled.
///   - tx: Write transaction used for the queue removal and owner touches.
public func cancelDownload(for attachmentId: Attachment.IDType, tx: DBWriteTransaction) {
    progressStates.markDownloadCancelled(for: attachmentId)
    for sourceType in QueuedAttachmentDownloadRecord.SourceType.allCases {
        attachmentDownloadStore.removeAttachmentFromQueue(
            withId: attachmentId,
            source: sourceType,
            tx: tx,
        )
    }
    attachmentUpdater.touchAllOwners(
        attachmentId: attachmentId,
        tx: tx,
    )
}
// MARK: - Persisted Queue
/// `TaskRecord` wrapper around a queued download row, pairing the row with
/// its id.
private struct DownloadTaskRecord: TaskRecord {
/// The wrapped record's id (taken from `record.id` in `DownloadTaskRecordStore.peek`).
let id: Int64
/// The underlying queued download record.
let record: QueuedAttachmentDownloadRecord
}
/// `TaskRecordStore` backed by an `AttachmentDownloadStore`.
private class DownloadTaskRecordStore: TaskRecordStore {
    typealias Record = DownloadTaskRecord

    private let store: AttachmentDownloadStore

    init(store: AttachmentDownloadStore) {
        self.store = store
    }

    /// Returns up to `count` queued downloads from the backing store, each
    /// wrapped as a `DownloadTaskRecord`.
    func peek(count: UInt, tx: DBReadTransaction) -> [DownloadTaskRecord] {
        var wrapped = [DownloadTaskRecord]()
        for queued in store.peek(count: count, tx: tx) {
            // Force-unwraps the row id, exactly as the record is expected to have one.
            wrapped.append(DownloadTaskRecord(id: queued.id!, record: queued))
        }
        return wrapped
    }

    /// Removes the queue entry matching the record's attachment id and source.
    func removeRecord(_ record: DownloadTaskRecord, tx: DBWriteTransaction) {
        let queued = record.record
        store.removeAttachmentFromQueue(
            withId: queued.attachmentId,
            source: queued.sourceType,
            tx: tx,
        )
    }
}
private final class DownloadTaskRunner: TaskRecordRunner {
// The record store type this runner works against.
typealias Store = DownloadTaskRecordStore
// Dependencies handed in through `init`.
private let accountKeyStore: AccountKeyStore
private let attachmentDownloadStore: AttachmentDownloadStore
private let attachmentStore: AttachmentStore
private let attachmentUpdater: AttachmentUpdater
private let attachmentUploadStore: AttachmentUploadStore
private let backupRequestManager: BackupRequestManager
private let dateProvider: DateProvider
private let db: any DB
private let decrypter: Decrypter
private let downloadabilityChecker: DownloadabilityChecker
private let downloadQueue: DownloadQueue
private let remoteConfigProvider: any RemoteConfigProvider
private let stickerManager: Shims.StickerManager
// Built in `init` from `attachmentDownloadStore`; not private.
let store: Store
private let tsAccountManager: TSAccountManager
/// Stores the injected dependencies and builds the `DownloadTaskRecordStore`
/// on top of `attachmentDownloadStore`.
init(
accountKeyStore: AccountKeyStore,
attachmentDownloadStore: AttachmentDownloadStore,
attachmentStore: AttachmentStore,
attachmentUpdater: AttachmentUpdater,
attachmentUploadStore: AttachmentUploadStore,
backupRequestManager: BackupRequestManager,
dateProvider: @escaping DateProvider,
db: any DB,
decrypter: Decrypter,
downloadQueue: DownloadQueue,
downloadabilityChecker: DownloadabilityChecker,
remoteConfigProvider: any RemoteConfigProvider,
stickerManager: Shims.StickerManager,
tsAccountManager: TSAccountManager,
) {
self.accountKeyStore = accountKeyStore
self.attachmentDownloadStore = attachmentDownloadStore
self.attachmentStore = attachmentStore
self.attachmentUpdater = attachmentUpdater
self.attachmentUploadStore = attachmentUploadStore
self.backupRequestManager = backupRequestManager
self.dateProvider = dateProvider
self.db = db
self.decrypter = decrypter
self.downloadQueue = downloadQueue
self.downloadabilityChecker = downloadabilityChecker
self.remoteConfigProvider = remoteConfigProvider
self.stickerManager = stickerManager
// The task record store wraps the same download store used directly above.
self.store = DownloadTaskRecordStore(store: attachmentDownloadStore)
self.tsAccountManager = tsAccountManager
}
// MARK: TaskRecordRunner conformance
/// Runs one queued download and reports its outcome.
///
/// - Parameters:
///   - record: The task record wrapping the queued download.
///   - loader: The loader driving this runner (not used here).
/// - Returns: The result produced by `downloadRecord`.
func runTask(
    record: DownloadTaskRecord,
    loader: TaskQueueLoader<DownloadTaskRunner>,
) async -> TaskRecordResult {
    let queued = record.record
    Logger.info("Starting download of attachment \(queued.attachmentId) from \(queued.sourceType)")
    let outcome = await downloadRecord(queued)
    return outcome
}
/// Called when a download task succeeded: logs it and notifies the download
/// queue's observers with no error, from a new unstructured task.
func didSucceed(
    record: DownloadTaskRecord,
    tx: DBWriteTransaction,
) throws {
    let queued = record.record
    Logger.info("Succeeded download of attachment \(queued.attachmentId) from \(queued.sourceType)")
    let observerKey = DownloadQueue.downloadKey(record: queued)
    Task {
        await downloadQueue.updateObservers(downloadKey: observerKey, error: nil)
    }
}
/// Called when a download task became obsolete: logs it and notifies the
/// download queue's observers with no error, from a new unstructured task.
func didObsolete(
    record: DownloadTaskRecord,
    tx: DBWriteTransaction,
) throws {
    let queued = record.record
    Logger.info("Obsoleted download of attachment \(queued.attachmentId) from \(queued.sourceType)")
    let observerKey = DownloadQueue.downloadKey(record: queued)
    Task {
        await downloadQueue.updateObservers(downloadKey: observerKey, error: nil)
    }
}
/// Handles a failed download task inside the given write transaction.
///
/// If the failure is retryable and `retryTime(for:)` yields a timestamp, the
/// queued record is marked failed with that minimum retry time. Otherwise the
/// record is removed from the queue, the attachment's state is updated
/// (transit tier info removed for `TransitTierExpiredError`, or marked as
/// failed to download for non-backup-restore priorities), and the download is
/// either re-enqueued from the transit tier or reported to observers as failed.
///
/// - Parameters:
///   - record: The task record that failed.
///   - error: The error the task failed with.
///   - isRetryable: Whether the task result marked the error as retryable.
///   - tx: Write transaction for all store updates.
func didFail(record: DownloadTaskRecord, error: Error, isRetryable: Bool, tx: DBWriteTransaction) {
// Shadow the wrapper with the underlying queued download record.
let record = record.record
Logger.warn("Failed download of attachment \(record.attachmentId) from \(record.sourceType). \(error)")
if isRetryable, let retryTime = self.retryTime(for: record) {
// Don't update observers; they'll be updated when the retry succeeds.
attachmentDownloadStore.markQueuedDownloadFailed(
withId: record.id!,
minRetryTimestamp: retryTime,
tx: tx,
)
} else {
let attachment = attachmentStore.fetch(id: record.attachmentId, tx: tx)
// If we tried to download as media tier, and failed, and we have
// a transit tier fallback available, try downloading from that.
let shouldReEnqueueAsTransitTier =
record.sourceType == .mediaTierFullsize
&& attachment?.latestTransitTierInfo != nil
// Backup restore download queue does its own fallbacks
&& record.priority != .backupRestore
// Not retrying; just delete the enqueued download
attachmentDownloadStore.removeAttachmentFromQueue(
withId: record.attachmentId,
source: record.sourceType,
tx: tx,
)
if
let error = error as? TransitTierExpiredError,
let attachment = attachmentStore.fetch(id: record.attachmentId, tx: tx)
{
Logger.info("Expiring transit tier due to failed download")
attachmentStore.removeTransitTierInfo(
error.transitTierInfo,
attachment: attachment,
tx: tx,
)
} else if
record.priority != .backupRestore,
let attachment = attachmentStore.fetch(id: record.attachmentId, tx: tx)
{
// Backup restore download queue does its own marking of failed state.
attachmentStore.updateAttachmentAsFailedToDownload(
attachment: attachment,
sourceType: record.sourceType,
timestamp: self.dateProvider().ows_millisecondsSince1970,
tx: tx,
)
}
if shouldReEnqueueAsTransitTier {
attachmentDownloadStore.enqueueDownloadOfAttachment(
withId: record.attachmentId,
source: .transitTier,
priority: record.priority,
tx: tx,
)
} else {
// If we aren't re-enqueuing, tell observers it has failed.
let downloadKey = DownloadQueue.downloadKey(record: record)
Task {
await downloadQueue.updateObservers(downloadKey: downloadKey, error: error)
}
}
// Once this transaction completes, touch all owners in a separate write.
tx.addSyncCompletion { [weak self] in
guard let self else { return }
self.db.asyncWrite { tx in
self.attachmentUpdater.touchAllOwners(
attachmentId: record.attachmentId,
tx: tx,
)
}
}
}
}
/// Error returned by `wrapDownloadError` when a transit tier download should
/// be treated as expired; `didFail` reacts by removing `transitTierInfo` from
/// the attachment.
private struct TransitTierExpiredError: Error {
/// The transit tier info that was in place before the download attempt.
let transitTierInfo: Attachment.TransitTierInfo
}
/// Converts a download error into a task result.
///
/// A transit tier download is treated as "expired" (unretryable, carrying a
/// `TransitTierExpiredError` so the transit tier info gets wiped) only when
/// every check below passes. Every other error is reported as retryable.
///
/// - Parameters:
///   - error: The error thrown by the download.
///   - record: The queued download record that was being processed.
///   - attachmentBeforeDownloadAttempt: The attachment as fetched before the
///     download started.
/// - Returns: `.unretryableError` wrapping `TransitTierExpiredError`, or
///   `.retryableError(error)`.
private func wrapDownloadError(
    error: Error,
    record: QueuedAttachmentDownloadRecord,
    attachmentBeforeDownloadAttempt: Attachment,
) -> TaskRecordResult {
    let now = dateProvider().ows_millisecondsSince1970

    // All other network-level errors are retried (with an exponential
    // backoff). Even a 404 may just mean the file is not available _yet_.
    // Another expected failure is CDN credentials expiring between enqueueing
    // and executing; the outcome is the same: fail this attempt and retry.
    let retry = TaskRecordResult.retryableError(error)

    // Only a 404 from the server, on a transit tier download, can expire.
    guard error.httpStatusCode == 404, record.sourceType == .transitTier else {
        return retry
    }

    // The transit tier info must not have changed while we were downloading.
    let refetched = db.read { tx in
        attachmentStore.fetch(id: record.attachmentId, tx: tx)
    }
    guard
        let refetched,
        let infoBeforeAttempt = attachmentBeforeDownloadAttempt.latestTransitTierInfo
    else {
        return retry
    }
    guard refetched.latestTransitTierInfo?.cdnKey == infoBeforeAttempt.cdnKey else {
        return retry
    }

    // Only proactively expire if the upload is old enough.
    guard
        let uploadTimestamp = refetched.latestTransitTierInfo?.uploadTimestamp,
        uploadTimestamp < now,
        now - uploadTimestamp >= remoteConfigProvider.currentConfig().messageQueueTimeMs
    else {
        return retry
    }

    return .unretryableError(TransitTierExpiredError(transitTierInfo: infoBeforeAttempt))
}
/// Timestamp before which a failed record must not be retried, or nil when
/// the record should not get a persisted retry at all.
///
/// These are not network-level retries (those happen separately); they are
/// persisted retries, usually for longer running retry attempts. Every source
/// type currently yields nil.
private nonisolated func retryTime(for record: QueuedAttachmentDownloadRecord) -> UInt64? {
    let noPersistedRetry: UInt64? = nil
    switch record.sourceType {
    case .mediaTierFullsize, .mediaTierThumbnail:
        // User initiated downloads aren't retried, and backups-initiated
        // downloads get retried at the backup manager level.
        return noPersistedRetry
    case .transitTier:
        // No persistent retries from the transit tier.
        return noPersistedRetry
    }
}
// MARK: Downloading
/// Performs the download for one queued record and reports the outcome.
///
/// Steps, in order:
/// 1. Fetch the attachment; the record is obsolete if the attachment is gone
///    or is already a stream.
/// 2. Re-check downloadability and bail out with the matching result if blocked.
/// 3. Try to satisfy the download locally (original stream of a quoted reply,
///    or an installed sticker).
/// 4. Build download metadata, expected size and size limit for the record's
///    source type.
/// 5. Download through the download queue, validate/prepare the file with the
///    decrypter, and update the attachment as downloaded.
/// 6. For stream results, copy the thumbnail for quoted replies (best effort).
///
/// - Parameter record: The queued download to process.
/// - Returns: `.success`, `.obsolete`, or a retryable/unretryable error result.
private nonisolated func downloadRecord(
_ record: QueuedAttachmentDownloadRecord,
) async -> TaskRecordResult {
guard
let attachment = db.read(block: { tx in
attachmentStore.fetch(id: record.attachmentId, tx: tx)
})
else {
// Because of the foreign key relationship and cascading deletes, this should
// only happen if the attachment got deleted between when we fetched the
// download queue record and now. Regardless, the record should now be deleted.
owsFailDebug("Attempting to download an attachment that doesn't exist!")
return .obsolete
}
guard attachment.asStream() == nil else {
// Already a stream! No need to download.
return .obsolete
}
switch self.downloadabilityChecker.downloadability(record, attachment: attachment) {
case .downloadable:
break
case nil:
// Because of the foreign key relationship and cascading deletes, this should
// only happen if all the references got deleted between when we fetched the
// download queue record and now. Regardless, the record should now be deleted.
owsFailDebug("Attempting to download an attachment with no references \(record.attachmentId)")
return .obsolete
case .blockedByActiveCall:
// This is a temporary setback; retry in a bit if the source allows it.
Logger.info("Skipping attachment download due to active call \(record.attachmentId)")
return .retryableError(AttachmentDownloads.Error.blockedByActiveCall)
case .blockedByPendingMessageRequest:
Logger.info("Skipping attachment download due to pending message request \(record.attachmentId)")
// These can only be resolved by user action; cancel the enqueued download.
return .unretryableError(AttachmentDownloads.Error.blockedByPendingMessageRequest)
case .blockedByAutoDownloadSettings:
Logger.info("Skipping attachment download due to auto download settings \(record.attachmentId)")
// These can only be resolved by user action; cancel the enqueued download.
return .unretryableError(AttachmentDownloads.Error.blockedByAutoDownloadSettings)
case .blockedByNetworkState:
Logger.info("Skipping attachment download due to network state \(record.attachmentId)")
return .unretryableError(AttachmentDownloads.Error.blockedByNetworkState)
}
Logger.info("Downloading attachment \(record.attachmentId) from \(record.sourceType)")
// Local source 1: a quoted-reply attachment whose original is already a stream.
if
let originalAttachmentIdForQuotedReply = attachment.originalAttachmentIdForQuotedReply,
await quoteUnquoteDownloadQuotedReplyFromOriginalStream(
originalAttachmentIdForQuotedReply: originalAttachmentIdForQuotedReply,
record: record,
)
{
// Done!
Logger.info("Sourced quote attachment from original \(record.attachmentId)")
return .success
}
// Local source 2: a sticker that is already installed locally.
if await quoteUnquoteDownloadStickerFromInstalledPackIfPossible(record: record) {
// Done!
Logger.info("Sourced sticker attachment from installed sticker \(record.attachmentId)")
return .success
}
if record.priority == .localClone {
// Local clone happens in two ways:
// 1. Original's local stream for a quoted reply
// 2. Local installed sticker for a sticker message
// If we were trying for either of these and got this far,
// we failed to use the local data, so just fail the whole thing.
return .unretryableError(OWSAssertionError("Failed local clone"))
}
// Everything the actual network download needs; each source type below
// must assign all three.
let downloadMetadata: DownloadMetadata
let downloadSizeSource: DownloadQueue.DownloadSizeSource
let maxDownloadSizeBytes: UInt64
switch record.sourceType {
case .transitTier:
// We only download from the latest transit tier info.
guard let transitTierInfo = attachment.latestTransitTierInfo else {
return .unretryableError(OWSAssertionError("Attempting to download an attachment without cdn info"))
}
downloadMetadata = .init(
mimeType: attachment.mimeType,
cdnNumber: transitTierInfo.cdnNumber,
encryptionKey: transitTierInfo.encryptionKey,
source: .transitTier(
cdnKey: transitTierInfo.cdnKey,
integrityCheck: transitTierInfo.integrityCheck,
plaintextLength: transitTierInfo.unencryptedByteCount,
),
)
// Estimate from the plaintext size when known (falling back to UInt32.max
// if the estimate is nil); otherwise let the queue issue a HEAD request.
downloadSizeSource = transitTierInfo.unencryptedByteCount.map({
.estimatedSizeBytes(UInt(Cryptography.estimatedTransitTierCDNSize(
unencryptedSize: UInt64(safeCast: $0),
) ?? UInt64(UInt32.max)))
}) ?? .useHeadRequest
let attachmentLimits = IncomingAttachmentLimits.currentLimits(remoteConfig: remoteConfigProvider.currentConfig())
// Images use a separate encrypted size limit from other content types.
switch Attachment.ContentTypeRaw(mimeType: attachment.mimeType) {
case .image, .animatedImage:
maxDownloadSizeBytes = attachmentLimits.maxEncryptedImageBytes
case .audio, .video, .file, .invalid:
maxDownloadSizeBytes = attachmentLimits.maxEncryptedBytes
}
case .mediaTierFullsize:
let cdnNumber = attachment.mediaTierInfo?.cdnNumber ?? remoteConfigProvider.currentConfig().mediaTierFallbackCdnNumber
guard
let mediaTierInfo = attachment.mediaTierInfo,
let mediaName = attachment.mediaName,
let backupKey = db.read(block: { accountKeyStore.getMediaRootBackupKey(tx: $0) }),
let encryptionMetadata = buildCdnEncryptionMetadata(mediaName: mediaName, backupKey: backupKey, type: .outerLayerFullsizeOrThumbnail),
let cdnCredential = await fetchBackupCdnReadCredential(
for: cdnNumber,
backupKey: backupKey,
logger: PrefixedLogger(prefix: "[Backups]"),
)
else {
return .unretryableError(OWSAssertionError("Attempting to download an attachment without cdn info"))
}
let integrityCheck = AttachmentIntegrityCheck.sha256ContentHash(mediaTierInfo.sha256ContentHash)
downloadMetadata = .init(
mimeType: attachment.mimeType,
cdnNumber: cdnNumber,
encryptionKey: attachment.encryptionKey,
source: .mediaTierFullsize(
cdnReadCredential: cdnCredential,
outerEncryptionMetadata: encryptionMetadata,
integrityCheck: integrityCheck,
plaintextLength: mediaTierInfo.unencryptedByteCount,
),
)
downloadSizeSource = .estimatedSizeBytes(UInt(Cryptography.estimatedMediaTierCDNSize(
unencryptedSize: UInt64(safeCast: mediaTierInfo.unencryptedByteCount),
) ?? UInt64(UInt32.max)))
maxDownloadSizeBytes = remoteConfigProvider.currentConfig().attachmentMaxEncryptedReceiveBytes
case .mediaTierThumbnail:
let cdnNumber = attachment.thumbnailMediaTierInfo?.cdnNumber ?? remoteConfigProvider.currentConfig().mediaTierFallbackCdnNumber
guard
// Proceed with existing thumbnail media tier info, or without it when
// the mime type is a supported visual media type.
attachment.thumbnailMediaTierInfo != nil || MimeTypeUtil.isSupportedVisualMediaMimeType(attachment.mimeType),
let mediaName = attachment.mediaName,
let backupKey = db.read(block: { accountKeyStore.getMediaRootBackupKey(tx: $0) }),
// This is the outer encryption
let outerEncryptionMetadata = buildCdnEncryptionMetadata(
mediaName: AttachmentBackupThumbnail.thumbnailMediaName(fullsizeMediaName: mediaName),
backupKey: backupKey,
type: .outerLayerFullsizeOrThumbnail,
),
// inner encryption
let innerEncryptionMetadata = buildCdnEncryptionMetadata(
mediaName: AttachmentBackupThumbnail.thumbnailMediaName(fullsizeMediaName: mediaName),
backupKey: backupKey,
type: .transitTierThumbnail,
),
let cdnReadCredential = await fetchBackupCdnReadCredential(
for: cdnNumber,
backupKey: backupKey,
logger: PrefixedLogger(prefix: "[Backups]"),
)
else {
return .unretryableError(OWSAssertionError("Attempting to download an attachment without cdn info"))
}
let thumbnailMimeType = MimeTypeUtil.thumbnailMimetype(
fullsizeMimeType: attachment.mimeType,
quality: .backupThumbnail,
)
downloadMetadata = .init(
mimeType: thumbnailMimeType,
cdnNumber: cdnNumber,
encryptionKey: attachment.encryptionKey,
source: .mediaTierThumbnail(
cdnReadCredential: cdnReadCredential,
outerEncyptionMetadata: outerEncryptionMetadata,
innerEncryptionMetadata: innerEncryptionMetadata,
),
)
// We don't know thumbnail sizes and don't want to issue a
// request for each one to check. Just estimate as the max size.
downloadSizeSource = .estimatedSizeBytes(UInt(Cryptography.estimatedMediaTierCDNSize(
unencryptedSize: UInt64(safeCast: AttachmentThumbnailQuality.backupThumbnailMaxSizeBytes),
) ?? UInt64(UInt32.max)))
maxDownloadSizeBytes = remoteConfigProvider.currentConfig().attachmentMaxEncryptedReceiveBytes
}
// Network download. Failures are classified by `wrapDownloadError`.
let downloadedFileUrl: URL
do {
downloadedFileUrl = try await downloadQueue.enqueueDownload(
downloadState: .init(type: .attachment(downloadMetadata, id: attachment.id)),
maxDownloadSizeBytes: maxDownloadSizeBytes,
expectedDownloadSize: downloadSizeSource,
progress: nil,
)
} catch let error {
Logger.error("Failed to download: \(error)")
return wrapDownloadError(
error: error,
record: record,
attachmentBeforeDownloadAttempt: attachment,
)
}
// Validation failures are not retried.
let pendingAttachment: PendingAttachment
do {
pendingAttachment = try await decrypter.validateAndPrepare(
encryptedFileUrl: downloadedFileUrl,
metadata: downloadMetadata,
)
} catch let error {
return .unretryableError(OWSAssertionError("Failed to validate: \(error)"))
}
// Failures to update the attachment are not retried either.
let result: DownloadResult
do {
result = try await attachmentUpdater.updateAttachmentAsDownloaded(
attachmentId: attachment.id,
pendingAttachment: pendingAttachment,
source: record.sourceType,
priority: record.priority,
timestamp: dateProvider().ows_millisecondsSince1970,
)
} catch let error {
return .unretryableError(OWSAssertionError("Failed to update attachment: \(error)"))
}
if case .stream(let attachmentStream) = result {
do {
try await attachmentUpdater.copyThumbnailForQuotedReplyIfNeeded(
attachmentStream,
)
} catch let error {
// Log error but don't block finishing; the thumbnails
// can update themselves later.
Logger.error("Failed to update thumbnails: \(error)")
}
}
return .success
}
/// Tries to satisfy a quoted-reply "download" from the original attachment's
/// local stream instead of the network.
///
/// - Parameters:
///   - originalAttachmentIdForQuotedReply: Row id of the original attachment.
///   - record: The queued download being processed (not used here).
/// - Returns: `true` if the original is a stream and the thumbnail copy did
///   not throw; `false` otherwise.
private nonisolated func quoteUnquoteDownloadQuotedReplyFromOriginalStream(
    originalAttachmentIdForQuotedReply: Attachment.IDType,
    record: QueuedAttachmentDownloadRecord,
) async -> Bool {
    let sourceStream = db.read { tx in
        attachmentStore.fetch(id: originalAttachmentIdForQuotedReply, tx: tx)?.asStream()
    }
    guard let sourceStream else {
        // The original is missing or is not a stream, so nothing to copy from.
        return false
    }
    do {
        try await attachmentUpdater.copyThumbnailForQuotedReplyIfNeeded(
            sourceStream,
        )
    } catch let error {
        Logger.error("Failed to update thumbnails: \(error)")
        return false
    }
    return true
}
/// Tries to satisfy a sticker "download" from a locally installed sticker.
///
/// Looks for a sticker-message reference to the attachment, fetches the
/// matching installed sticker, and uses it as if it were the downloaded file.
///
/// - Parameter record: The queued download being processed.
/// - Returns: `true` if the attachment was updated from an installed sticker;
///   `false` if there was no installed sticker or a step failed.
private nonisolated func quoteUnquoteDownloadStickerFromInstalledPackIfPossible(
    record: QueuedAttachmentDownloadRecord,
) async -> Bool {
    let localSticker: InstalledStickerRecord? = db.read { tx in
        var foundMetadata: AttachmentReference.Owner.MessageSource.StickerMetadata?
        attachmentStore.enumerateAllReferences(
            toAttachmentId: record.attachmentId,
            tx: tx,
            block: { reference, stop in
                // Stop at the first sticker-message owner.
                if case .message(.sticker(let metadata)) = reference.owner {
                    stop = true
                    foundMetadata = metadata
                }
            },
        )
        guard let foundMetadata else {
            return nil
        }
        return self.stickerManager.fetchInstalledSticker(
            packId: foundMetadata.stickerPackId,
            stickerId: foundMetadata.stickerId,
            tx: tx,
        )
    }
    guard let localSticker else {
        return false
    }

    // Pretend the installed sticker is the file we've downloaded.
    let preparedSticker: PendingAttachment
    do {
        preparedSticker = try await decrypter.validateAndPrepareInstalledSticker(localSticker)
    } catch let error {
        Logger.error("Failed to validate sticker: \(error)")
        return false
    }

    let updatedStream: AttachmentStream
    do {
        updatedStream = try await attachmentUpdater.updateAttachmentFromInstalledSticker(
            attachmentId: record.attachmentId,
            pendingAttachment: preparedSticker,
        )
    } catch let error {
        Logger.error("Failed to update attachment: \(error)")
        return false
    }

    do {
        try await attachmentUpdater.copyThumbnailForQuotedReplyIfNeeded(
            updatedStream,
        )
    } catch let error {
        // A thumbnail failure is logged but does not fail the whole
        // operation; the thumbnails can update themselves later.
        Logger.error("Failed to update thumbnails: \(error)")
    }
    return true
}
/// Derives the media-tier encryption metadata for `mediaName` from `backupKey`.
///
/// - Returns: The derived metadata, or `nil` (after `owsFailDebug`) if the
///   derivation throws.
private func buildCdnEncryptionMetadata(
    mediaName: String,
    backupKey: MediaRootBackupKey,
    type: MediaTierEncryptionType,
) -> MediaTierEncryptionMetadata? {
    let derivation = Result {
        try backupKey.mediaEncryptionMetadata(
            mediaName: mediaName,
            type: type,
        )
    }
    switch derivation {
    case .success(let metadata):
        return metadata
    case .failure:
        owsFailDebug("Failed to build backup media metadata")
        return nil
    }
}
/// Fetches the credential needed to read from the given media-tier CDN.
///
/// Looks up the local ACI, fetches backup service auth for `backupKey`, then
/// requests the CDN read metadata for `cdn` using that auth.
///
/// - Parameters:
///   - cdn: CDN number to request read metadata for.
///   - backupKey: Key used to obtain backup service auth.
///   - logger: Logger forwarded to the backup request manager.
/// - Returns: The read credential, or `nil` (after `owsFailDebug`) if the CDN
///   number is out of range, the local identifiers are missing, or either
///   fetch fails.
private func fetchBackupCdnReadCredential(
    for cdn: UInt32,
    backupKey: MediaRootBackupKey,
    logger: PrefixedLogger,
) async -> MediaTierReadCredential? {
    // `Int32(_:)` traps for values above `Int32.max`; reject such a CDN number
    // up front instead of crashing after the auth round trip.
    guard let cdnNumber = Int32(exactly: cdn) else {
        owsFailDebug("Invalid cdn number")
        return nil
    }
    guard
        let localAci = db.read(block: { tx in
            self.tsAccountManager.localIdentifiers(tx: tx)?.aci
        })
    else {
        owsFailDebug("Missing local identifier")
        return nil
    }
    guard
        let auth = try? await backupRequestManager.fetchBackupServiceAuth(
            for: backupKey,
            localAci: localAci,
            auth: .implicit(),
            logger: logger,
        )
    else {
        owsFailDebug("Failed to fetch backup credential")
        return nil
    }
    guard
        let metadata = try? await backupRequestManager.fetchMediaTierCdnRequestMetadata(
            cdn: cdnNumber,
            auth: auth,
            logger: logger,
        )
    else {
        // Distinct from the auth failure above so the two can be told apart in logs.
        owsFailDebug("Failed to fetch media tier cdn metadata")
        return nil
    }
    return metadata
}
}
// MARK: - Downloadability
private class DownloadabilityChecker {
private let attachmentStore: AttachmentStore
private let backupSettingsStore: BackupSettingsStore
private let currentCallProvider: CurrentCallProvider
private let db: any DB
private let mediaBandwidthPreferenceStore: MediaBandwidthPreferenceStore
private let profileManager: ProfileManager
private let reachabilityManager: SSKReachabilityManager
private let storyStore: any StoryStore
private let threadStore: ThreadStore
/// Memberwise initializer; stores each dependency for use by the
/// downloadability checks below.
init(
attachmentStore: AttachmentStore,
backupSettingsStore: BackupSettingsStore,
currentCallProvider: CurrentCallProvider,
db: any DB,
mediaBandwidthPreferenceStore: MediaBandwidthPreferenceStore,
profileManager: ProfileManager,
reachabilityManager: SSKReachabilityManager,
storyStore: any StoryStore,
threadStore: ThreadStore,
) {
self.attachmentStore = attachmentStore
self.backupSettingsStore = backupSettingsStore
self.currentCallProvider = currentCallProvider
self.db = db
self.mediaBandwidthPreferenceStore = mediaBandwidthPreferenceStore
self.profileManager = profileManager
self.reachabilityManager = reachabilityManager
self.storyStore = storyStore
self.threadStore = threadStore
}
/// Outcome of checking whether a queued download may proceed right now.
enum Downloadability {
/// Nothing blocks the download.
case downloadable
/// A call is in progress and this kind of download is held back during calls.
case blockedByActiveCall
/// The owning thread has a pending message request (or no thread could be found).
case blockedByPendingMessageRequest
/// The user's media auto-download preferences exclude this media type.
case blockedByAutoDownloadSettings
/// A media-tier download while cellular backup downloads are disallowed and wifi is not reachable.
case blockedByNetworkState
}
/// Determines whether the queued download for `attachment` may run now.
///
/// User-initiated and local-clone downloads are always downloadable. For
/// other priorities every reference to the attachment is examined, stopping
/// at the first one that permits the download.
///
/// - Returns: `.downloadable` if any reference allows it; otherwise the
///   verdict for the last reference enumerated, or `nil` if no reference
///   was enumerated.
func downloadability(
    _ record: QueuedAttachmentDownloadRecord,
    attachment: Attachment,
) -> Downloadability? {
    // Check priority before opening a read.
    let alwaysDownload: Bool
    switch record.priority {
    case .userInitiated, .localClone:
        // Always download at these priorities.
        alwaysDownload = true
    case .default, .backupRestore:
        alwaysDownload = false
    }
    if alwaysDownload {
        return .downloadable
    }

    return db.read { tx in
        var verdict: Downloadability?
        self.attachmentStore.enumerateAllReferences(
            toAttachmentId: record.attachmentId,
            tx: tx,
        ) { reference, stop in
            let referenceVerdict = self.downloadability(
                of: reference,
                priority: record.priority,
                source: record.sourceType,
                mimeType: attachment.mimeType,
                tx: tx,
            )
            verdict = referenceVerdict
            // If one reference marks it downloadable, don't check further ones.
            if referenceVerdict == .downloadable {
                stop = true
            }
        }
        return verdict
    }
}
/// Evaluates a single attachment reference against every blocking rule.
///
/// Rules are checked in order: active call, network state (media tier
/// only), auto-download settings, then pending message request. The first
/// rule that blocks determines the result.
func downloadability(
    of reference: AttachmentReference,
    priority: AttachmentDownloadPriority,
    source: QueuedAttachmentDownloadRecord.SourceType,
    mimeType: String,
    tx: DBReadTransaction,
) -> Downloadability {
    let owner = reference.owner

    if isDownloadBlockedByActiveCall(priority: priority, owner: owner, tx: tx) {
        return .blockedByActiveCall
    }

    switch source {
    case .transitTier:
        // Transit tier can download regardless of reachability
        break
    case .mediaTierFullsize, .mediaTierThumbnail:
        let cellularAllowed = backupSettingsStore.shouldAllowBackupDownloadsOnCellular(tx: tx)
        if !cellularAllowed && !reachabilityManager.isReachable(via: .wifi) {
            return .blockedByNetworkState
        }
    }

    if
        isDownloadBlockedByAutoDownloadSettings(
            priority: priority,
            owner: owner,
            renderingFlag: reference.renderingFlag,
            mimeType: mimeType,
            tx: tx,
        )
    {
        return .blockedByAutoDownloadSettings
    }

    if
        isDownloadBlockedByPendingMessageRequest(
            priority: priority,
            source: source,
            owner: owner,
            tx: tx,
        )
    {
        return .blockedByPendingMessageRequest
    }

    // If we made it this far, it's downloadable.
    return .downloadable
}
/// Whether an in-progress call should hold back this download.
///
/// Only `.default` priority downloads of body attachments, stickers, story
/// media and wallpapers are held back, and only while a call is current.
private func isDownloadBlockedByActiveCall(
    priority: AttachmentDownloadPriority,
    owner: AttachmentReference.Owner,
    tx: DBReadTransaction,
) -> Bool {
    switch priority {
    case .default:
        break
    case .userInitiated, .localClone:
        // Always download at these priorities.
        return false
    case .backupRestore:
        // Don't suspend during a call until we have mechanisms
        // to resume once the call ends.
        // TODO: [Backups] suspend downloads during calls and resume after
        return false
    }

    switch owner {
    case
        .message(.bodyAttachment),
        .message(.sticker),
        .storyMessage(.media),
        .thread(.threadWallpaperImage),
        .thread(.globalThreadWallpaperImage):
        return currentCallProvider.hasCurrentCall
    case
        .message(.oversizeText),
        .message(.quotedReply),
        .message(.linkPreview),
        .message(.contactAvatar),
        .storyMessage(.textStoryLinkPreview):
        return false
    }
}
/// Whether a pending message request on the owning thread should hold back
/// this download.
///
/// Never blocks user-initiated or local-clone downloads, media-tier
/// downloads, oversize text, outgoing stories, or thread-owned attachments.
/// Blocks when the owning story message or thread cannot be found.
private func isDownloadBlockedByPendingMessageRequest(
    priority: AttachmentDownloadPriority,
    source: QueuedAttachmentDownloadRecord.SourceType,
    owner: AttachmentReference.Owner,
    tx: DBReadTransaction,
) -> Bool {
    switch priority {
    case .default, .backupRestore:
        break
    case .userInitiated, .localClone:
        // Always download at these priorities.
        return false
    }

    switch source {
    case .transitTier:
        break
    case .mediaTierFullsize, .mediaTierThumbnail:
        // Even if we are in message request state, the fact
        // that these downloads are on media tier means the
        // client that backed them up had them downloaded
        // before. Therefore, they should be good to redownload.
        return false
    }

    let owningThread: TSThread?
    switch owner {
    case .thread:
        // Ignore non-message cases for purposes of pending message request.
        return false

    case .message(let messageSource):
        let threadRowId: TSThread.RowId
        switch messageSource {
        case .oversizeText:
            // These are bodyAttachments that branch based on their MIME type and
            // aren't processed as media files.
            return false
        case .sticker(let metadata):
            threadRowId = metadata.threadRowId
        case .bodyAttachment(let metadata):
            threadRowId = metadata.threadRowId
        case .quotedReply(let metadata):
            threadRowId = metadata.threadRowId
        case .linkPreview(let metadata):
            threadRowId = metadata.threadRowId
        case .contactAvatar(let metadata):
            threadRowId = metadata.threadRowId
        }
        owningThread = threadStore.fetchThread(rowId: threadRowId, tx: tx)

    case .storyMessage(let storySource):
        guard
            let story = storyStore.fetchStoryMessage(rowId: storySource.storyMessageRowId, tx: tx)
        else {
            owsFailDebug("can't check downloadability for non-existent owner")
            return true
        }
        switch story.direction {
        case .incoming:
            break
        case .outgoing:
            // Ignore outgoing stories for purposes of pending message requests.
            return false
        }
        if let groupId = story.groupId {
            owningThread = threadStore.fetchGroupThread(groupId: groupId, tx: tx)
        } else {
            owningThread = threadStore.fetchContactThreads(serviceId: story.authorAci, tx: tx).first
        }
    }

    // If there's not a thread, err on the safe side and don't download it.
    guard let owningThread else {
        return true
    }
    return threadStore.hasPendingMessageRequest(thread: owningThread, tx: tx)
}
/// Whether the user's media auto-download preferences hold back this download.
///
/// Only `.default` priority downloads are subject to the preference. Body
/// attachments and story media are classified by MIME type (image, video,
/// audio, otherwise document); voice messages are never blocked. Stickers
/// follow the photo preference. All other owners are never blocked.
private func isDownloadBlockedByAutoDownloadSettings(
    priority: AttachmentDownloadPriority,
    owner: AttachmentReference.Owner,
    renderingFlag: AttachmentReference.RenderingFlag,
    mimeType: String,
    tx: DBReadTransaction,
) -> Bool {
    switch priority {
    case .default:
        break
    case .userInitiated, .localClone:
        // Always download at these priorities.
        return false
    case .backupRestore:
        // Despite being lower priority than default,
        // these actually should download despite the setting.
        return false
    }

    let allowedTypes = mediaBandwidthPreferenceStore.autoDownloadableMediaTypes(tx: tx)

    switch owner {
    case .message(.bodyAttachment), .storyMessage(.media):
        if MimeTypeUtil.isSupportedImageMimeType(mimeType) {
            return !allowedTypes.contains(.photo)
        } else if MimeTypeUtil.isSupportedVideoMimeType(mimeType) {
            return !allowedTypes.contains(.video)
        } else if MimeTypeUtil.isSupportedAudioMimeType(mimeType) {
            return renderingFlag != .voiceMessage && !allowedTypes.contains(.audio)
        } else {
            return !allowedTypes.contains(.document)
        }
    case .message(.sticker):
        return !allowedTypes.contains(.photo)
    case
        .message(.oversizeText),
        .message(.quotedReply),
        .message(.linkPreview),
        .message(.contactAvatar),
        .storyMessage(.textStoryLinkPreview),
        .thread(.threadWallpaperImage),
        .thread(.globalThreadWallpaperImage):
        return false
    }
}
}
// MARK: - Downloads
typealias DownloadMetadata = AttachmentDownloads.DownloadMetadata
/// File-private download error type.
/// NOTE(review): `oversize` is not thrown anywhere in the portion of the file
/// reviewed here — presumably it signals a download exceeding a size limit; confirm at its use sites.
private enum DownloadError: Error {
case oversize
}
/// What is being downloaded, together with the information needed to
/// locate and authenticate the request against the CDN.
private enum DownloadType {
    case backup(metadata: BackupReadCredential, uuid: UUID)
    case transientAttachment(DownloadMetadata, uuid: UUID)
    case attachment(DownloadMetadata, id: Attachment.IDType)

    // MARK: - Helpers

    /// Path component of the download URL.
    ///
    /// - Throws: `OWSAssertionError` if a CDN key cannot be percent-encoded.
    func urlPath() throws -> String {
        switch self {
        case let .backup(credential, _):
            return credential.backupLocationUrl()
        case let .attachment(metadata, _), let .transientAttachment(metadata, _):
            switch metadata.source {
            case let .transitTier(cdnKey, _, _), let .linkNSyncBackup(cdnKey):
                let escapedKey = cdnKey.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed)
                guard let escapedKey else {
                    throw OWSAssertionError("Invalid cdnKey.")
                }
                return "attachments/\(escapedKey)"
            case
                let .mediaTierFullsize(credential, outerMetadata, _, _),
                let .mediaTierThumbnail(credential, outerMetadata, _):
                return "\(credential.mediaTierUrlPrefix())/\(outerMetadata.mediaId.asBase64Url)"
            }
        }
    }

    /// CDN number to download from; the backup credential's value is clamped into `UInt32`.
    func cdnNumber() -> UInt32 {
        switch self {
        case let .backup(credential, _):
            return UInt32(clamping: credential.cdn)
        case let .attachment(metadata, _), let .transientAttachment(metadata, _):
            return metadata.cdnNumber
        }
    }

    /// Extra request headers; CDN auth headers for backup and media-tier
    /// downloads, empty otherwise.
    func additionalHeaders() -> HttpHeaders {
        switch self {
        case let .backup(credential, _):
            return credential.cdnAuthHeaders
        case let .attachment(metadata, _), let .transientAttachment(metadata, _):
            switch metadata.source {
            case .transitTier, .linkNSyncBackup:
                return [:]
            case let .mediaTierFullsize(credential, _, _, _), let .mediaTierThumbnail(credential, _, _):
                return credential.cdnAuthHeaders
            }
        }
    }

    /// Whether the credential backing this download has expired. Transit-tier
    /// and link'n'sync sources carry no credential and never report expiry.
    func isExpired() -> Bool {
        switch self {
        case let .backup(credential, _):
            return credential.isExpired
        case let .attachment(metadata, _), let .transientAttachment(metadata, _):
            switch metadata.source {
            case .transitTier, .linkNSyncBackup:
                return false
            case let .mediaTierFullsize(credential, _, _, _), let .mediaTierThumbnail(credential, _, _):
                return credential.isExpired
            }
        }
    }
}
/// Per-download state: the download's type plus the date this state object
/// was created. All request details are forwarded to `type`.
private class DownloadState {
    let startDate = Date()
    let type: DownloadType

    init(type: DownloadType) {
        self.type = type
    }

    /// See `DownloadType.urlPath()`.
    func urlPath() throws -> String { try type.urlPath() }

    /// See `DownloadType.cdnNumber()`.
    func cdnNumber() -> UInt32 { type.cdnNumber() }

    /// See `DownloadType.additionalHeaders()`.
    func additionalHeaders() -> HttpHeaders { type.additionalHeaders() }

    /// See `DownloadType.isExpired()`.
    func isExpired() -> Bool { type.isExpired() }
}
/// Mutex-guarded record of per-attachment progress values and of which
/// attachment downloads have been cancelled.
private final class ProgressStates: Sendable {
    private struct States {
        var states: [Attachment.IDType: Double] = [:]
        var cancelledAttachmentIds: Set<Attachment.IDType> = []
    }

    private let states = TSMutex(initialState: States())

    /// Drops any stored progress for `attachmentId` and flags it as cancelled.
    func markDownloadCancelled(for attachmentId: Attachment.IDType) {
        states.withLock { guarded in
            guarded.states.removeValue(forKey: attachmentId)
            guarded.cancelledAttachmentIds.insert(attachmentId)
        }
    }

    /// Clears the cancellation flag for `attachmentId`.
    ///
    /// - Returns: `true` if the attachment had been flagged as cancelled.
    func consumeCancellation(of attachmentId: Attachment.IDType) -> Bool {
        return states.withLock { guarded in
            let removedId = guarded.cancelledAttachmentIds.remove(attachmentId)
            return removedId != nil
        }
    }
}
private actor DownloadQueue {
private let progressStates: ProgressStates
private nonisolated let signalService: OWSSignalServiceProtocol
/// Stores the shared progress/cancellation state and the service used to
/// obtain CDN URL sessions.
init(
progressStates: ProgressStates,
signalService: OWSSignalServiceProtocol,
) {
self.progressStates = progressStates
self.signalService = signalService
}
private let queue = ConcurrentTaskQueue(concurrentLimit: 12)
/// Non-transient attachments have an in-memory disconnect between the downloadQueue and the actual job runner;
/// they are enqueued to disk and then read off disk to be downloaded. In order to get an in-memory object
/// like a Continuation or an OWSProgress across, we need to cache them in memory using this key.
/// This does not apply to transient downloads since they stay in memory the whole time.
struct DownloadKey: Hashable {
/// Id of the attachment being downloaded.
let id: Attachment.IDType
/// Source the download is queued against.
let source: QueuedAttachmentDownloadRecord.SourceType
}
private var downloadObservers = [DownloadKey: [CheckedContinuation<Void, Error>]]()
private var downloadProgresses = [DownloadKey: [OWSProgressSink]]()
private var finishedOrFailedDownloads = Set<DownloadKey>()
private var progressIDs = [DownloadKey: UInt64]()
/// The current progress generation for `downloadKey`.
///
/// - Returns: The stored id, or `0` when the key is `nil` or has no entry.
func latestProgressID(downloadKey: DownloadKey?) -> UInt64 {
    if let downloadKey, let storedID = progressIDs[downloadKey] {
        return storedID
    }
    return 0
}
/// Forgets that `key` previously finished or failed and bumps its progress
/// generation (starting from 0 if there was none).
func clearOldDownloadsAndIncrementProgressID(key: DownloadKey) {
    finishedOrFailedDownloads.remove(key)
    progressIDs[key, default: 0] += 1
}
/// Drops the progress sinks registered for `key` and records the download
/// as finished or failed.
func clearDownloadProgressAndMarkFinished(key: DownloadKey) {
    finishedOrFailedDownloads.insert(key)
    downloadProgresses[key] = nil
}
/// Suspends until `updateObservers` is called for the matching download key.
///
/// Registers the caller's continuation, and `progress` if provided, under
/// the key formed from `id` and `source`.
///
/// - Throws: The error passed to `updateObservers`, if any.
func waitForDownloadOfAttachment(
    id: Attachment.IDType,
    source: QueuedAttachmentDownloadRecord.SourceType,
    progress: OWSProgressSink?,
) async throws {
    let key = DownloadKey(id: id, source: source)
    try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
        downloadObservers[key, default: []].append(continuation)
        if let progress {
            downloadProgresses[key, default: []].append(progress)
        }
    }
}
/// Removes every continuation waiting on `downloadKey` and resumes it,
/// throwing `error` if one is given and returning normally otherwise.
func updateObservers(downloadKey: DownloadKey, error: Error?) {
    guard let waiting = downloadObservers.removeValue(forKey: downloadKey) else {
        return
    }
    for continuation in waiting {
        if let error {
            continuation.resume(throwing: error)
        } else {
            continuation.resume()
        }
    }
}
/// Builds the in-memory lookup key for a queued download record.
nonisolated static func downloadKey(record: QueuedAttachmentDownloadRecord) -> DownloadKey {
    let attachmentId = record.attachmentId
    let source = record.sourceType
    return DownloadKey(id: attachmentId, source: source)
}
/// Builds the in-memory lookup key for a download state.
///
/// - Returns: `nil` for backup and transient downloads, which are not
///   tracked by key.
private nonisolated static func downloadKey(state: DownloadState) -> DownloadKey? {
    guard case .attachment(let downloadMetadata, let id) = state.type else {
        return nil
    }
    return DownloadKey(id: id, source: downloadMetadata.source.asQueuedDownloadSource)
}
/// Issues a HEAD request for the object described by `downloadState` and
/// parses the response headers into `CdnInfo`.
///
/// - Throws: Errors from building the URL path or request, performing the
///   request, or parsing the response headers.
fileprivate func performHeadRequest(
    downloadState: DownloadState,
    maxDownloadSizeBytes: UInt64,
) async throws -> AttachmentDownloads.CdnInfo {
    // We don't need maxDownloadSizeBytes for this request, but we include it
    // to increase the likelihood of a shared URLSession cache hit.
    let session = await self.signalService.sharedUrlSessionForCdn(
        cdnNumber: downloadState.cdnNumber(),
        maxResponseSize: maxDownloadSizeBytes,
    )
    let path = try downloadState.urlPath()
    var requestHeaders = downloadState.additionalHeaders()
    requestHeaders["Content-Type"] = MimeType.applicationOctetStream.rawValue

    // Perform a HEAD request to get the byte length & last modified date from cdn.
    let headRequest = try session.endpoint.buildRequest(path, method: .head, headers: requestHeaders)
    let headResponse = try await session.performRequest(request: headRequest, ignoreAppExpiry: true)
    return try AttachmentDownloads.CdnInfo(headResponse.headers)
}
/// Fetch the first `length` bytes of the object from the CDN, returning the fetched bytes,
/// (or nil if the response was empty) alongside headers from the response (which are the
/// same headers from a HEAD response).
///
/// Length is limited to UInt16, and really should be even smaller, because this is _not_
/// a download task, is not resumable, and should therefore only be used to fetch a very
/// limited number of bytes.
///
/// A `length` of zero requests no bytes: a HEAD request is made instead and the
/// returned data is `nil`.
///
/// - Throws: Errors from building the URL path or request, performing the
///   request, or parsing the response headers.
fileprivate func performPrefixRequest(
    downloadState: DownloadState,
    maxDownloadSizeBytes: UInt64,
    length: UInt16,
) async throws -> (AttachmentDownloads.CdnInfo, Data?) {
    // `length - 1` below is unsigned arithmetic and would trap for zero, and
    // a byte range cannot express "no bytes" anyway; only the headers are needed.
    guard length > 0 else {
        let cdnInfo = try await performHeadRequest(
            downloadState: downloadState,
            maxDownloadSizeBytes: maxDownloadSizeBytes,
        )
        return (cdnInfo, nil)
    }

    let urlSession = await self.signalService.sharedUrlSessionForCdn(
        cdnNumber: downloadState.cdnNumber(),
        maxResponseSize: maxDownloadSizeBytes,
    )
    let urlPath = try downloadState.urlPath()
    var headers = downloadState.additionalHeaders()
    headers["Content-Type"] = MimeType.applicationOctetStream.rawValue
    // Byte ranges are inclusive on both ends, so the last index is `length - 1`.
    headers["range"] = "bytes=0-\(length - 1)"
    let request = try urlSession.endpoint.buildRequest(urlPath, method: .get, headers: headers)
    let response = try await urlSession.performRequest(request: request, ignoreAppExpiry: true)
    let cdnInfo = try AttachmentDownloads.CdnInfo(response.headers)
    return (cdnInfo, response.responseBodyData)
}
/// Runs the download on the concurrency-limited queue and returns the URL
/// of the downloaded file.
///
/// Progress is reported to `progress` (if given) plus any sinks registered
/// under this download's key at the time of the call.
func enqueueDownload(
    downloadState: DownloadState,
    maxDownloadSizeBytes: UInt64,
    expectedDownloadSize: DownloadSizeSource,
    progress: OWSProgressSink?,
) async throws -> URL {
    var sinks: [OWSProgressSink] = []
    if let progress {
        sinks.append(progress)
    }
    if
        let key = Self.downloadKey(state: downloadState),
        let registeredSinks = self.downloadProgresses[key]
    {
        sinks.append(contentsOf: registeredSinks)
    }
    // Freeze into a constant before handing off to the queued closure.
    let progresses = sinks
    return try await queue.run {
        return try await performDownload(
            downloadState: downloadState,
            progresses: progresses,
            maxDownloadSizeBytes: maxDownloadSizeBytes,
            expectedDownloadSize: expectedDownloadSize,
        )
    }
}
/// How `performDownload` determines the expected byte count used for progress reporting.
enum DownloadSizeSource {
/// Issue a HEAD request first and use the content length it reports.
case useHeadRequest
/// Use the supplied size without making an extra request.
case estimatedSizeBytes(UInt)
}
/// Downloads the object described by `downloadState` into a temporary file,
/// reporting progress and retrying on network failures.
///
/// - Parameters:
///   - downloadState: Describes what to download (URL path, CDN number, headers, expiry).
///   - progresses: Sinks that each receive a progress source sized to the expected byte count.
///   - maxDownloadSizeBytes: Passed as the max response size when obtaining the CDN URL session.
///   - expectedDownloadSize: Whether to use a supplied size estimate or a HEAD request for the expected size.
/// - Returns: URL of a temporary file that the downloaded file was moved to.
/// - Throws: `AttachmentDownloads.Error.expiredCredentials` if the download state
///   reports expiry (checked up front and at the start of each attempt); errors from
///   the HEAD request, request building, or file move; and the last download error once
///   it is not a network failure/timeout or the attempt count reaches 16.
private nonisolated func performDownload(
downloadState: DownloadState,
progresses: [OWSProgressSink],
maxDownloadSizeBytes: UInt64,
expectedDownloadSize: DownloadSizeSource,
) async throws -> URL {
let urlPath = try downloadState.urlPath()
var headers = downloadState.additionalHeaders()
headers["Content-Type"] = MimeType.applicationOctetStream.rawValue
// Only keyed (non-transient, non-backup) downloads carry an attachment id;
// it is used below for cancellation checks.
let attachmentId: Attachment.IDType?
switch downloadState.type {
case .backup, .transientAttachment:
attachmentId = nil
case .attachment(_, let id):
attachmentId = id
}
if downloadState.isExpired() {
throw AttachmentDownloads.Error.expiredCredentials
}
let expectedDownloadSizeBytes: UInt
switch expectedDownloadSize {
case .estimatedSizeBytes(let size):
expectedDownloadSizeBytes = UInt(size)
case .useHeadRequest:
// Perform a HEAD request just to get the byte length from cdn.
let downloadInfo = try await performHeadRequest(
downloadState: downloadState,
maxDownloadSizeBytes: maxDownloadSizeBytes,
)
expectedDownloadSizeBytes = downloadInfo.contentLength
}
// One progress source per caller-provided sink, created once and shared by
// every retry attempt below.
var progressSources: [OWSProgressSource] = []
for progress in progresses {
progressSources.append(await progress.addSource(
withLabel: AttachmentDownloads.downloadProgressLabel,
unitCount: UInt64(expectedDownloadSizeBytes),
))
}
// Resume data captured from the previous failed attempt, if any (set in `onError`).
var resumeData: Data?
return try await Retry.performRepeatedly {
// Credentials may have expired while waiting between attempts.
if downloadState.isExpired() {
throw AttachmentDownloads.Error.expiredCredentials
}
let urlSession = await self.signalService.sharedUrlSessionForCdn(
cdnNumber: downloadState.cdnNumber(),
maxResponseSize: maxDownloadSizeBytes,
)
// Declared before the progress sink so the sink's closure can hand the
// task to `handleDownloadProgress`, which may cancel it.
var downloadTask: Task<OWSUrlDownloadResponse, Error>?
// Snapshot the progress generation for this attempt; the sink compares it
// against the latest value to detect that a newer download has started.
let wrappedProgressID = await latestProgressID(downloadKey: DownloadQueue.downloadKey(state: downloadState))
let wrappedProgress = OWSProgress.createSink { [weak self] progressValue in
if let k = DownloadQueue.downloadKey(state: downloadState) {
if await self?.latestProgressID(downloadKey: k) != wrappedProgressID {
// A new download has started, don't send progress updates or notifications.
return
}
if let self, await finishedOrFailedDownloads.contains(k) {
// If we've already finished the download, send the notification so
// handlers can get 100% updates but don't update the progress sources,
// which may be double counting.
handleDownloadProgress(
downloadState: downloadState,
task: downloadTask,
progress: progressValue,
expectedDownloadSizeBytes: UInt(expectedDownloadSizeBytes),
attachmentId: attachmentId,
)
return
}
}
// Advance each source only by the delta it has not yet seen, so the
// sources never move backwards.
for progressSource in progressSources {
if progressSource.completedUnitCount < progressValue.completedUnitCount {
progressSource.incrementCompletedUnitCount(by: progressValue.completedUnitCount - progressSource.completedUnitCount)
}
}
self?.handleDownloadProgress(
downloadState: downloadState,
task: downloadTask,
progress: progressValue,
expectedDownloadSizeBytes: UInt(expectedDownloadSizeBytes),
attachmentId: attachmentId,
)
}
let wrappedProgressSource = await wrappedProgress.addSource(
withLabel: "source",
unitCount: UInt64(expectedDownloadSizeBytes),
)
let downloadResponse: OWSUrlDownloadResponse
if let resumeData {
// Resume the earlier attempt using the resume data it left behind.
let request = try urlSession.endpoint.buildRequest(urlPath, method: .get, headers: headers)
guard let requestUrl = request.url else {
throw OWSAssertionError("Request missing url.")
}
downloadTask = Task {
return try await urlSession.performDownload(
requestUrl: requestUrl,
resumeData: resumeData,
progressBlock: wrappedProgressSource.asProgressBlock(),
)
}
downloadResponse = try await downloadTask!.value
} else {
// No resume data: start a fresh download.
downloadTask = Task {
return try await urlSession.performDownload(
urlPath,
method: .get,
headers: headers,
progressBlock: wrappedProgressSource.asProgressBlock(),
)
}
downloadResponse = try await downloadTask!.value
}
// Move the downloaded file to a temporary location returned to the caller.
let downloadUrl = downloadResponse.downloadUrl
let tmpFile = OWSFileSystem.temporaryFileUrl(
fileExtension: nil,
isAvailableWhileDeviceLocked: false,
)
try OWSFileSystem.moveFile(from: downloadUrl, to: tmpFile)
return tmpFile
} onError: { error, attemptCount in
Logger.warn("Error: \(error)")
let maxAttemptCount = 16
// Give up on non-network errors or once the attempt budget is spent.
guard attemptCount < maxAttemptCount, error.isNetworkFailureOrTimeout else {
throw error
}
// Wait briefly before retrying.
try await Task.sleep(nanoseconds: 250 * NSEC_PER_MSEC)
// Carry any resume data from the failed attempt into the next one;
// empty data is treated as absent.
resumeData = ((error as NSError).userInfo[NSURLSessionDownloadTaskResumeData] as? Data)?.nilIfEmpty
}
}
/// Reacts to a progress update: cancels the task if the attachment's
/// download was cancelled, otherwise posts a progress notification for
/// keyed attachment downloads.
///
/// - Parameters:
///   - downloadState: The download being reported on.
///   - task: The in-flight download task, cancelled if a cancellation is pending.
///   - progress: Latest progress value.
///   - expectedDownloadSizeBytes: Expected size, if known.
///   - attachmentId: Attachment id used to look up pending cancellations.
private nonisolated func handleDownloadProgress(
    downloadState: DownloadState,
    task: Task<OWSUrlDownloadResponse, Error>?,
    progress: OWSProgress,
    expectedDownloadSizeBytes: UInt?,
    attachmentId: Attachment.IDType?,
) {
    if let attachmentId, progressStates.consumeCancellation(of: attachmentId) {
        Logger.info("Cancelling download.")
        // Cancelling will inform the URLSessionTask delegate.
        task?.cancel()
        return
    }

    let hasReceivedBytes = progress.completedUnitCount > 0
    // Don't do anything until we've received at least one byte of data,
    // or estimated the download size.
    guard hasReceivedBytes || expectedDownloadSizeBytes != nil else {
        return
    }

    // Use a slightly non-zero value to ensure that the progress
    // indicator shows up as quickly as possible.
    let minimumFraction: Float = 0.001
    let fractionCompleted: Float = hasReceivedBytes
        ? max(minimumFraction, progress.percentComplete)
        : minimumFraction

    // Only keyed attachment downloads broadcast progress.
    guard case .attachment(_, let notifiedAttachmentId) = downloadState.type else {
        return
    }
    NotificationCenter.default.postOnMainThread(
        name: AttachmentDownloads.attachmentDownloadProgressNotification,
        object: nil,
        userInfo: [
            AttachmentDownloads.attachmentDownloadProgressKey: fractionCompleted,
            AttachmentDownloads.attachmentDownloadAttachmentIDKey: notifiedAttachmentId,
        ],
    )
}
}
private class Decrypter {
private let attachmentValidator: AttachmentContentValidator
private let stickerManager: Shims.StickerManager
/// Stores the content validator and sticker manager used by the
/// validation and decryption methods below.
init(
attachmentValidator: AttachmentContentValidator,
stickerManager: Shims.StickerManager,
) {
self.attachmentValidator = attachmentValidator
self.stickerManager = stickerManager
}
// Use concurrent=1 queue to ensure that we only load into memory
// & decrypt a single attachment at a time.
private let decryptionQueue = ConcurrentTaskQueue(concurrentLimit: 1)
/// Decrypts a downloaded transient attachment into a new temporary file.
///
/// Runs on the serial decryption queue. If decryption fails, the encrypted
/// input file is deleted (best effort) before the error is rethrown.
///
/// - Returns: URL of the decrypted temporary file.
func decryptTransientAttachment(
    encryptedFileUrl: URL,
    metadata: DownloadMetadata,
) async throws -> URL {
    return try await decryptionQueue.run {
        // Transient attachments decrypt to a tmp file.
        let decryptedUrl = OWSFileSystem.temporaryFileUrl(
            fileExtension: nil,
            isAvailableWhileDeviceLocked: false,
        )
        do {
            try Cryptography.decryptAttachment(
                at: encryptedFileUrl,
                metadata: DecryptionMetadata(
                    key: AttachmentKey(combinedKey: metadata.encryptionKey),
                    integrityCheck: metadata.integrityCheck,
                    plaintextLength: metadata.plaintextLength.map(UInt64.init(safeCast:)),
                ),
                output: decryptedUrl,
            )
        } catch {
            do {
                try OWSFileSystem.deleteFileIfExists(url: encryptedFileUrl)
            } catch let deleteFileError {
                owsFailDebug("Error: \(deleteFileError).")
            }
            throw error
        }
        return decryptedUrl
    }
}
/// Validates an installed sticker's data file and prepares it as a pending
/// attachment, as if it had just been downloaded.
///
/// The MIME type comes from the file's image metadata when it can be read,
/// and falls back to WebP otherwise. Runs on the serial decryption queue.
///
/// - Throws: `OWSAssertionError` if the sticker's data file is missing, or
///   any error from content validation.
func validateAndPrepareInstalledSticker(
    _ sticker: InstalledStickerRecord,
) async throws -> PendingAttachment {
    // Copy dependencies into locals so the queued closure doesn't capture `self`.
    let validator = self.attachmentValidator
    let stickers = self.stickerManager
    return try await decryptionQueue.run {
        let locatedUrl = stickers.stickerDataUrl(
            forInstalledSticker: sticker,
            verifyExists: true,
        )
        guard let stickerDataUrl = locatedUrl else {
            throw OWSAssertionError("Missing sticker")
        }

        let sniffedMetadata = try? DataImageSource.forPath(stickerDataUrl.path).imageMetadata()
        let mimeType: MimeType = sniffedMetadata?.imageFormat.mimeType ?? MimeType.imageWebp

        return try await validator.validateDataSourceContents(
            DataSourcePath(fileUrl: stickerDataUrl, ownership: .borrowed),
            mimeType: mimeType.rawValue,
            renderingFlag: .borderless,
            sourceFilename: nil,
        )
    }
}
/// Validates a downloaded encrypted file and prepares it as a pending
/// attachment, choosing the validation path from `metadata.source`.
///
/// Transit-tier files are validated directly with the attachment key.
/// Media-tier files are validated with an outer and an inner layer of
/// decryption metadata. Runs on the serial decryption queue.
///
/// - Throws: `OWSAssertionError` for link'n'sync backup sources, or any
///   error from key construction or content validation.
func validateAndPrepare(
    encryptedFileUrl: URL,
    metadata: DownloadMetadata,
) async throws -> PendingAttachment {
    // Copy the dependency into a local so the queued closure doesn't capture `self`.
    let validator = self.attachmentValidator
    return try await decryptionQueue.run {
        switch metadata.source {
        case let .transitTier(_, integrityCheck, plaintextLength):
            return try await validator.validateDownloadedContents(
                ofEncryptedFileAt: encryptedFileUrl,
                attachmentKey: AttachmentKey(combinedKey: metadata.encryptionKey),
                plaintextLength: plaintextLength,
                integrityCheck: integrityCheck,
                mimeType: metadata.mimeType,
                renderingFlag: .default,
                sourceFilename: nil,
            )

        case let .mediaTierFullsize(_, outerMetadata, integrityCheck, plaintextLength):
            return try await validator.validateBackupMediaFileContents(
                fileUrl: encryptedFileUrl,
                outerDecryptionData: DecryptionMetadata(key: outerMetadata.attachmentKey()),
                innerDecryptionData: DecryptionMetadata(
                    key: AttachmentKey(combinedKey: metadata.encryptionKey),
                    integrityCheck: integrityCheck,
                    plaintextLength: plaintextLength.map { UInt64(safeCast: $0) },
                ),
                finalAttachmentKey: AttachmentKey(combinedKey: metadata.encryptionKey),
                mimeType: metadata.mimeType,
                renderingFlag: .default,
                sourceFilename: nil,
            )

        case let .mediaTierThumbnail(_, outerMetadata, innerMetadata):
            return try await validator.validateBackupMediaFileContents(
                fileUrl: encryptedFileUrl,
                outerDecryptionData: DecryptionMetadata(key: outerMetadata.attachmentKey()),
                innerDecryptionData: DecryptionMetadata(key: innerMetadata.attachmentKey()),
                finalAttachmentKey: AttachmentKey(combinedKey: metadata.encryptionKey),
                mimeType: metadata.mimeType,
                renderingFlag: .default,
                sourceFilename: nil,
            )

        case .linkNSyncBackup:
            throw OWSAssertionError("Should not be validating link'n'sync backups")
        }
    }
}
/// Prepares a quoted-reply thumbnail from an existing attachment stream,
/// on the serial decryption queue.
func prepareQuotedReplyThumbnail(originalAttachmentStream: AttachmentStream) async throws -> PendingAttachment {
    // Copy the dependency into a local so the queued closure doesn't capture `self`.
    let validator = self.attachmentValidator
    let thumbnail = try await decryptionQueue.run {
        try await validator.prepareQuotedReplyThumbnail(
            fromOriginalAttachmentStream: originalAttachmentStream,
        )
    }
    return thumbnail
}
}
private class AttachmentUpdater {
private let attachmentStore: AttachmentStore
private let backupAttachmentUploadScheduler: BackupAttachmentUploadScheduler
private let db: any DB
private let decrypter: Decrypter
private let interactionStore: InteractionStore
private let orphanedAttachmentCleaner: OrphanedAttachmentCleaner
private let orphanedAttachmentStore: OrphanedAttachmentStore
private let orphanedBackupAttachmentScheduler: OrphanedBackupAttachmentScheduler
private let storyStore: StoryStore
private let threadStore: ThreadStore
/// Memberwise initializer; stores each dependency used when writing
/// download results back to the attachment store.
init(
attachmentStore: AttachmentStore,
backupAttachmentUploadScheduler: BackupAttachmentUploadScheduler,
db: any DB,
decrypter: Decrypter,
interactionStore: InteractionStore,
orphanedAttachmentCleaner: OrphanedAttachmentCleaner,
orphanedAttachmentStore: OrphanedAttachmentStore,
orphanedBackupAttachmentScheduler: OrphanedBackupAttachmentScheduler,
storyStore: StoryStore,
threadStore: ThreadStore,
) {
self.attachmentStore = attachmentStore
self.backupAttachmentUploadScheduler = backupAttachmentUploadScheduler
self.db = db
self.decrypter = decrypter
self.interactionStore = interactionStore
self.orphanedAttachmentCleaner = orphanedAttachmentCleaner
self.orphanedAttachmentStore = orphanedAttachmentStore
self.orphanedBackupAttachmentScheduler = orphanedBackupAttachmentScheduler
self.storyStore = storyStore
self.threadStore = threadStore
}
func updateAttachmentAsDownloaded(
attachmentId: Attachment.IDType,
pendingAttachment: PendingAttachment,
source: QueuedAttachmentDownloadRecord.SourceType,
priority: AttachmentDownloadPriority,
timestamp: UInt64,
) async throws -> DownloadResult {
return try await db.awaitableWrite { tx in
guard orphanedAttachmentStore.orphanAttachmentExists(with: pendingAttachment.orphanRecordId, tx: tx) else {
throw OWSAssertionError("Attachment file deleted before creation")
}
// We may find, when we go to update the attachment, that we had
// already downloaded this attachment. If so, we'll thereafter
// want to refer to the existing attachment instead.
var attachmentId = attachmentId
guard var attachmentWeJustDownloaded = attachmentStore.fetch(id: attachmentId, tx: tx) else {
throw OWSGenericError("Missing attachment after download; could have been deleted while downloading.")
}
if let stream = attachmentWeJustDownloaded.asStream() {
// Its already a stream?
return .stream(stream)
}
let mediaName = Attachment.mediaName(
sha256ContentHash: pendingAttachment.sha256ContentHash,
encryptionKey: pendingAttachment.encryptionKey,
)
let streamInfo = Attachment.StreamInfo(
sha256ContentHash: pendingAttachment.sha256ContentHash,
mediaName: mediaName,
encryptedByteCount: pendingAttachment.encryptedByteCount,
unencryptedByteCount: pendingAttachment.unencryptedByteCount,
contentType: pendingAttachment.validatedContentType,
digestSHA256Ciphertext: pendingAttachment.digestSHA256Ciphertext,
localRelativeFilePath: pendingAttachment.localRelativeFilePath,
)
// Try and update the attachment.
do throws(AttachmentInsertError) {
try self.attachmentStore.updateAttachmentAsDownloaded(
attachment: attachmentWeJustDownloaded,
sourceType: source,
priority: priority,
validatedMimeType: pendingAttachment.mimeType,
streamInfo: streamInfo,
timestamp: timestamp,
tx: tx,
)
// Make sure to clear out the pending attachment from the orphan table so it isn't deleted!
self.orphanedAttachmentCleaner.releasePendingAttachment(withId: pendingAttachment.orphanRecordId, tx: tx)
self.orphanedBackupAttachmentScheduler.didCreateOrUpdateAttachment(
withMediaName: mediaName,
tx: tx,
)
} catch {
let existingAttachmentId: Attachment.IDType = try AttachmentManagerImpl.handleAttachmentInsertError(
error,
pendingAttachmentStreamInfo: streamInfo,
pendingAttachmentEncryptionKey: pendingAttachment.encryptionKey,
pendingAttachmentMimeType: pendingAttachment.mimeType,
pendingAttachmentOrphanRecordId: pendingAttachment.orphanRecordId,
pendingAttachmentLatestTransitTierInfo: attachmentWeJustDownloaded.latestTransitTierInfo,
pendingAttachmentOriginalTransitTierInfo: attachmentWeJustDownloaded.originalTransitTierInfo,
attachmentStore: attachmentStore,
orphanedAttachmentCleaner: orphanedAttachmentCleaner,
orphanedAttachmentStore: orphanedAttachmentStore,
backupAttachmentUploadScheduler: backupAttachmentUploadScheduler,
orphanedBackupAttachmentScheduler: orphanedBackupAttachmentScheduler,
tx: tx,
)
// Already have an attachment with the same plaintext hash or media name!
// Move all existing references to that copy, instead.
// Doing so should delete the original attachment pointer.
// Just hold all refs in memory; there shouldn't in practice be
// so many pointers to the same attachment.
var references = [AttachmentReference]()
self.attachmentStore.enumerateAllReferences(
toAttachmentId: attachmentId,
tx: tx,
) { reference, _ in
references.append(reference)
}
for reference in references {
self.attachmentStore.removeReference(
reference: reference,
tx: tx,
)
let newOwnerParams = AttachmentReference.ConstructionParams(
owner: reference.owner.forReassignmentWithContentType(pendingAttachment.validatedContentType.raw),
sourceFilename: reference.sourceFilename,
sourceUnencryptedByteCount: reference.sourceUnencryptedByteCount,
sourceMediaSizePixels: reference.sourceMediaSizePixels,
)
attachmentStore.addReference(
newOwnerParams,
attachmentRowId: existingAttachmentId,
tx: tx,
)
}
attachmentId = existingAttachmentId
}
// Refetch the attachment, to reflect `updateAttachmentAsDownloaded`.
attachmentWeJustDownloaded = attachmentStore.fetch(id: attachmentId, tx: tx)!
tx.addSyncCompletion { [self] in
db.asyncWrite { [self] tx in
touchAllOwners(attachmentId: attachmentId, tx: tx)
}
}
switch source {
case .transitTier:
guard let stream = attachmentWeJustDownloaded.asStream() else {
throw OWSAssertionError("Not a stream")
}
// After we download an attachment and verify its digest, we can
// schedule it for "upload" to the media (backup) tier. "Upload"
// really means "copy from transit tier" and since we just downloaded
// we shouldn't need to reupload to do that; we just needed to verify
// the digest before copying.
backupAttachmentUploadScheduler.enqueueUsingHighestPriorityOwnerIfNeeded(
attachmentWeJustDownloaded,
mode: .all,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
return .stream(stream)
case .mediaTierFullsize:
guard let stream = attachmentWeJustDownloaded.asStream() else {
throw OWSAssertionError("Not a stream")
}
return .stream(stream)
case .mediaTierThumbnail:
guard let thumbnail = attachmentWeJustDownloaded.asBackupThumbnail() else {
throw OWSAssertionError("Not a thumbnail")
}
return .thumbnail(thumbnail)
}
}
}
/// Turns the attachment identified by `attachmentId` into a stream backed by
/// `pendingAttachment`, re-pointing every existing reference at that stream.
///
/// All work happens inside one `awaitableWriteWithRollbackIfThrows` write
/// transaction. A new stream attachment is inserted from `pendingAttachment`
/// with the first existing reference attached to it. If that insert fails with
/// an `AttachmentInsertError`, the attachment id returned by
/// `AttachmentManagerImpl.handleAttachmentInsertError` is used instead. The
/// remaining references are then removed and re-added against the resulting
/// attachment, and the owner of every original reference is touched.
///
/// - Parameters:
///   - attachmentId: Id of the attachment whose references are moved.
///   - pendingAttachment: The prepared content (hash, key, sizes, file path,
///     orphan record id) used to build the new stream.
/// - Returns: The existing stream if the attachment already is one; otherwise
///   the newly inserted stream, or the already-existing stream that the insert
///   collided with.
/// - Throws: `OWSGenericError` if the attachment cannot be fetched;
///   `OWSAssertionError` if it has no references, if the pending attachment's
///   orphan record is missing, or if the attachment matched after an insert
///   error is not a stream; any error thrown by the insert that is not an
///   `AttachmentInsertError`; and anything `handleAttachmentInsertError` throws.
func updateAttachmentFromInstalledSticker(
attachmentId: Attachment.IDType,
pendingAttachment: PendingAttachment,
) async throws -> AttachmentStream {
return try await db.awaitableWriteWithRollbackIfThrows { tx -> AttachmentStream in
guard let existingAttachment = self.attachmentStore.fetch(id: attachmentId, tx: tx) else {
throw OWSGenericError("Missing attachment after download; could have been deleted while downloading.")
}
if let stream = existingAttachment.asStream() {
// It's already a stream; return it as-is.
return stream
}
// Snapshot every reference to the attachment; they are removed and
// re-added against the new stream below.
var references = [AttachmentReference]()
self.attachmentStore.enumerateAllReferences(
toAttachmentId: attachmentId,
tx: tx,
) { reference, _ in
references.append(reference)
}
// Arbitrarily pick the first reference as the one we will use as the initial ref to
// the new stream. The others' references will be re-pointed to the new stream afterwards.
guard let firstReference = references.first else {
throw OWSAssertionError("Attachments should never have zero references")
}
// Remove it up front; the insert below re-creates it (same owner)
// against the new stream.
self.attachmentStore.removeReference(
reference: firstReference,
tx: tx,
)
let mediaName = Attachment.mediaName(
sha256ContentHash: pendingAttachment.sha256ContentHash,
encryptionKey: pendingAttachment.encryptionKey,
)
let streamInfo = Attachment.StreamInfo(
sha256ContentHash: pendingAttachment.sha256ContentHash,
mediaName: mediaName,
encryptedByteCount: pendingAttachment.encryptedByteCount,
unencryptedByteCount: pendingAttachment.unencryptedByteCount,
contentType: pendingAttachment.validatedContentType,
digestSHA256Ciphertext: pendingAttachment.digestSHA256Ciphertext,
localRelativeFilePath: pendingAttachment.localRelativeFilePath,
)
// True when the insert succeeded and therefore already attached
// `firstReference`'s owner to the new stream.
let alreadyAssignedFirstReference: Bool
let newAttachmentStream: AttachmentStream
do {
// This throws inside the `do`, but it is not an `AttachmentInsertError`,
// so the `catch` below rethrows it unchanged.
guard self.orphanedAttachmentStore.orphanAttachmentExists(with: pendingAttachment.orphanRecordId, tx: tx) else {
throw OWSAssertionError("Attachment file deleted before creation")
}
// Only image, video and animated-image content types carry a pixel size.
let mediaSizePixels: CGSize?
switch pendingAttachment.validatedContentType {
case .invalid, .file, .audio:
mediaSizePixels = nil
case .image(let pixelSize), .video(_, let pixelSize, _), .animatedImage(let pixelSize):
mediaSizePixels = pixelSize
}
let referenceParams = AttachmentReference.ConstructionParams(
owner: firstReference.owner,
sourceFilename: firstReference.sourceFilename,
sourceUnencryptedByteCount: pendingAttachment.unencryptedByteCount,
sourceMediaSizePixels: mediaSizePixels,
)
var attachmentRecord = Attachment.Record.forInsertingStream(
blurHash: pendingAttachment.blurHash,
mimeType: pendingAttachment.mimeType,
encryptionKey: pendingAttachment.encryptionKey,
streamInfo: streamInfo,
sha256ContentHash: pendingAttachment.sha256ContentHash,
mediaName: mediaName,
)
let attachment = try self.attachmentStore.insert(
&attachmentRecord,
reference: referenceParams,
tx: tx,
)
self.orphanedBackupAttachmentScheduler.didCreateOrUpdateAttachment(
withMediaName: mediaName,
tx: tx,
)
backupAttachmentUploadScheduler.enqueueUsingHighestPriorityOwnerIfNeeded(
attachment,
mode: .all,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
// Make sure to clear out the pending attachment from the orphan table so it isn't deleted!
self.orphanedAttachmentCleaner.releasePendingAttachment(withId: pendingAttachment.orphanRecordId, tx: tx)
newAttachmentStream = AttachmentStream(
attachment: attachment,
info: streamInfo,
)
alreadyAssignedFirstReference = true
} catch let error {
// Only `AttachmentInsertError` is recoverable here; anything else is rethrown.
let existingAttachmentId: Attachment.IDType
if let error = error as? AttachmentInsertError {
existingAttachmentId = try AttachmentManagerImpl.handleAttachmentInsertError(
error,
pendingAttachmentStreamInfo: streamInfo,
pendingAttachmentEncryptionKey: pendingAttachment.encryptionKey,
pendingAttachmentMimeType: pendingAttachment.mimeType,
pendingAttachmentOrphanRecordId: pendingAttachment.orphanRecordId,
pendingAttachmentLatestTransitTierInfo: nil,
pendingAttachmentOriginalTransitTierInfo: nil,
attachmentStore: attachmentStore,
orphanedAttachmentCleaner: orphanedAttachmentCleaner,
orphanedAttachmentStore: orphanedAttachmentStore,
backupAttachmentUploadScheduler: backupAttachmentUploadScheduler,
orphanedBackupAttachmentScheduler: orphanedBackupAttachmentScheduler,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
} else {
throw error
}
// Already have an attachment with the same plaintext hash or media name!
// We will instead re-point all references to this attachment.
guard
let attachment = attachmentStore.fetch(
id: existingAttachmentId,
tx: tx,
),
let attachmentStream = attachment.asStream()
else {
throw OWSAssertionError("Missing stream for attachment we just matched against!")
}
newAttachmentStream = attachmentStream
alreadyAssignedFirstReference = false
}
// Move the existing references to the new stream. If the insert succeeded,
// the first reference is already attached to it and is skipped; otherwise
// every reference (including the first) is re-pointed.
let referencesToUpdate = alreadyAssignedFirstReference
? references.suffix(max(references.count - 1, 0))
: references
for reference in referencesToUpdate {
attachmentStore.removeReference(
reference: reference,
tx: tx,
)
let newOwnerParams = AttachmentReference.ConstructionParams(
owner: reference.owner.forReassignmentWithContentType(
newAttachmentStream.contentType.raw,
),
sourceFilename: reference.sourceFilename,
sourceUnencryptedByteCount: reference.sourceUnencryptedByteCount,
sourceMediaSizePixels: reference.sourceMediaSizePixels,
)
attachmentStore.addReference(
newOwnerParams,
attachmentRowId: newAttachmentStream.attachment.id,
tx: tx,
)
}
references.forEach { reference in
// It's OK to point at the old owner here; it's the same message id
// or story message id etc., which is what we use for this.
self.touchOwner(reference.owner, tx: tx)
}
return newAttachmentStream
}
}
/// Builds a quoted-reply thumbnail from `downloadedAttachment` and re-points at
/// it the references of every quoted-reply attachment that has no stream yet.
///
/// The quoted-reply attachments are looked up with
/// `allQuotedReplyAttachments(forOriginalAttachmentId:)`. If all of them already
/// are streams the method returns without doing anything. Otherwise a pending
/// thumbnail is prepared via `decrypter.prepareQuotedReplyThumbnail`, and a
/// single `awaitableWriteWithRollbackIfThrows` transaction inserts it as a new
/// stream attachment (or, on an `AttachmentInsertError`, uses the id returned by
/// `AttachmentManagerImpl.handleAttachmentInsertError`), moves the collected
/// references onto it, touches their owners, and enqueues a backup upload.
///
/// - Parameter downloadedAttachment: The original attachment stream that the
///   quoted replies refer to.
/// - Throws: Errors from preparing the thumbnail; `OWSAssertionError` if the
///   pending thumbnail's orphan record is missing or the freshly inserted
///   attachment cannot be fetched back; any error thrown by the insert that is
///   not an `AttachmentInsertError`; and anything `handleAttachmentInsertError`
///   throws.
func copyThumbnailForQuotedReplyIfNeeded(_ downloadedAttachment: AttachmentStream) async throws {
let thumbnailAttachments = db.read { tx in
return self.attachmentStore.allQuotedReplyAttachments(
forOriginalAttachmentId: downloadedAttachment.attachment.id,
tx: tx,
)
}
guard thumbnailAttachments.contains(where: { $0.asStream() == nil }) else {
// All the referencing thumbnails already have their own streams; nothing to do.
return
}
let pendingThumbnailAttachment = try await self.decrypter.prepareQuotedReplyThumbnail(
originalAttachmentStream: downloadedAttachment,
)
try await db.awaitableWriteWithRollbackIfThrows { tx in
// True when the insert succeeded and therefore already attached
// `firstReference`'s owner to the new thumbnail stream.
let alreadyAssignedFirstReference: Bool
// Query again inside the write transaction (the lookup above ran in a
// separate read) and keep only the attachments that still lack a stream.
let thumbnailAttachments = attachmentStore
.allQuotedReplyAttachments(
forOriginalAttachmentId: downloadedAttachment.attachment.id,
tx: tx,
)
.filter({ $0.asStream() == nil })
// Flatten the references of all those attachments into one list.
let references = thumbnailAttachments.flatMap { attachment in
var refs = [AttachmentReference]()
self.attachmentStore.enumerateAllReferences(
toAttachmentId: attachment.id,
tx: tx,
) { ref, _ in
refs.append(ref)
}
return refs
}
// Arbitrarily pick the first thumbnail as the one we will use as the initial ref to
// the new stream. The others' references will be re-pointed to the new stream afterwards.
guard let firstReference = references.first else {
// Nothing to update.
return
}
// Remove it up front; the insert below re-creates it (same owner)
// against the new stream.
attachmentStore.removeReference(
reference: firstReference,
tx: tx,
)
let mediaName = Attachment.mediaName(
sha256ContentHash: pendingThumbnailAttachment.sha256ContentHash,
encryptionKey: pendingThumbnailAttachment.encryptionKey,
)
let streamInfo = Attachment.StreamInfo(
sha256ContentHash: pendingThumbnailAttachment.sha256ContentHash,
mediaName: mediaName,
encryptedByteCount: pendingThumbnailAttachment.encryptedByteCount,
unencryptedByteCount: pendingThumbnailAttachment.unencryptedByteCount,
contentType: pendingThumbnailAttachment.validatedContentType,
digestSHA256Ciphertext: pendingThumbnailAttachment.digestSHA256Ciphertext,
localRelativeFilePath: pendingThumbnailAttachment.localRelativeFilePath,
)
// Id of the attachment that all references end up pointing at: either the
// newly inserted thumbnail or an existing attachment the insert collided with.
let thumbnailAttachmentId: Attachment.IDType
do {
// This throws inside the `do`, but it is not an `AttachmentInsertError`,
// so the `catch` below rethrows it unchanged.
guard self.orphanedAttachmentStore.orphanAttachmentExists(with: pendingThumbnailAttachment.orphanRecordId, tx: tx) else {
throw OWSAssertionError("Attachment file deleted before creation")
}
// Only image, video and animated-image content types carry a pixel size.
let mediaSizePixels: CGSize?
switch pendingThumbnailAttachment.validatedContentType {
case .invalid, .file, .audio:
mediaSizePixels = nil
case .image(let pixelSize), .video(_, let pixelSize, _), .animatedImage(let pixelSize):
mediaSizePixels = pixelSize
}
let referenceParams = AttachmentReference.ConstructionParams(
owner: firstReference.owner,
sourceFilename: firstReference.sourceFilename,
sourceUnencryptedByteCount: pendingThumbnailAttachment.unencryptedByteCount,
sourceMediaSizePixels: mediaSizePixels,
)
var attachmentRecord = Attachment.Record.forInsertingStream(
blurHash: pendingThumbnailAttachment.blurHash,
mimeType: pendingThumbnailAttachment.mimeType,
encryptionKey: pendingThumbnailAttachment.encryptionKey,
streamInfo: streamInfo,
sha256ContentHash: pendingThumbnailAttachment.sha256ContentHash,
mediaName: mediaName,
)
let newAttachment = try self.attachmentStore.insert(
&attachmentRecord,
reference: referenceParams,
tx: tx,
)
self.orphanedBackupAttachmentScheduler.didCreateOrUpdateAttachment(
withMediaName: mediaName,
tx: tx,
)
backupAttachmentUploadScheduler.enqueueUsingHighestPriorityOwnerIfNeeded(
newAttachment,
mode: .all,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
// Make sure to clear out the pending attachment from the orphan table so it isn't deleted!
self.orphanedAttachmentCleaner.releasePendingAttachment(withId: pendingThumbnailAttachment.orphanRecordId, tx: tx)
// Resolve the attachment id by looking the owner's reference back up,
// rather than reading it from `newAttachment`.
guard
let referencedAttachment = attachmentStore.fetchAnyReferencedAttachment(
for: referenceParams.owner.id,
tx: tx,
)
else {
throw OWSAssertionError("Missing attachment we just created")
}
thumbnailAttachmentId = referencedAttachment.attachment.id
alreadyAssignedFirstReference = true
} catch let error {
// Only `AttachmentInsertError` is recoverable here; anything else is rethrown.
let existingAttachmentId: Attachment.IDType
if let error = error as? AttachmentInsertError {
existingAttachmentId = try AttachmentManagerImpl.handleAttachmentInsertError(
error,
pendingAttachmentStreamInfo: streamInfo,
pendingAttachmentEncryptionKey: pendingThumbnailAttachment.encryptionKey,
pendingAttachmentMimeType: pendingThumbnailAttachment.mimeType,
pendingAttachmentOrphanRecordId: pendingThumbnailAttachment.orphanRecordId,
pendingAttachmentLatestTransitTierInfo: nil,
pendingAttachmentOriginalTransitTierInfo: nil,
attachmentStore: attachmentStore,
orphanedAttachmentCleaner: orphanedAttachmentCleaner,
orphanedAttachmentStore: orphanedAttachmentStore,
backupAttachmentUploadScheduler: backupAttachmentUploadScheduler,
orphanedBackupAttachmentScheduler: orphanedBackupAttachmentScheduler,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
} else {
throw error
}
// Already have an attachment with the same plaintext hash or media name!
// We will instead re-point all references to this attachment.
thumbnailAttachmentId = existingAttachmentId
alreadyAssignedFirstReference = false
}
// Move all existing references to the new thumbnail stream. If the insert
// succeeded, the first reference is already attached to it and is skipped;
// otherwise every reference (including the first) is re-pointed.
let referencesToUpdate = alreadyAssignedFirstReference
? references.suffix(max(references.count - 1, 0))
: references
for reference in referencesToUpdate {
attachmentStore.removeReference(
reference: reference,
tx: tx,
)
let newOwnerParams = AttachmentReference.ConstructionParams(
owner: reference.owner.forReassignmentWithContentType(pendingThumbnailAttachment.validatedContentType.raw),
sourceFilename: reference.sourceFilename,
sourceUnencryptedByteCount: reference.sourceUnencryptedByteCount,
sourceMediaSizePixels: reference.sourceMediaSizePixels,
)
attachmentStore.addReference(
newOwnerParams,
attachmentRowId: thumbnailAttachmentId,
tx: tx,
)
}
references.forEach { reference in
// It's OK to point at the old owner here; it's the same message id
// or story message id etc., which is what we use for this.
self.touchOwner(reference.owner, tx: tx)
}
if let thumbnailAttachment = attachmentStore.fetch(id: thumbnailAttachmentId, tx: tx)?.asStream() {
// Schedule upload, if needed.
// NOTE(review): the insert-success path above already made this same
// enqueue call (and posted the same notification) for the new attachment;
// presumably it is repeated here because all references have now been
// re-pointed. Confirm whether both calls are needed.
backupAttachmentUploadScheduler.enqueueUsingHighestPriorityOwnerIfNeeded(
thumbnailAttachment.attachment,
mode: .all,
tx: tx,
)
tx.addSyncCompletion {
NotificationCenter.default.post(name: .startBackupAttachmentUploadQueue, object: nil)
}
}
}
}
/// Touches the owner of every reference that points at `attachmentId`.
///
/// - Parameters:
///   - attachmentId: The attachment whose references are enumerated.
///   - tx: Write transaction used for both the enumeration and the touches.
func touchAllOwners(attachmentId: Attachment.IDType, tx: DBWriteTransaction) {
    attachmentStore.enumerateAllReferences(toAttachmentId: attachmentId, tx: tx) { ownerReference, _ in
        self.touchOwner(ownerReference.owner, tx: tx)
    }
}
/// Touches the database object that owns an attachment reference.
///
/// Message owners touch their interaction (without reindexing) and story
/// message owners touch their story message. Thread owners are currently
/// left alone, and an owner whose row can no longer be fetched is ignored.
///
/// - Parameters:
///   - owner: The owner of the attachment reference.
///   - tx: Write transaction used for the fetch and the touch.
func touchOwner(_ owner: AttachmentReference.Owner, tx: DBWriteTransaction) {
    switch owner {
    case .thread:
        // TODO: perhaps a mechanism to update a thread once wallpaper is loaded?
        break
    case .message(let source):
        if let ownerInteraction = interactionStore.fetchInteraction(rowId: source.messageRowId, tx: tx) {
            db.touch(interaction: ownerInteraction, shouldReindex: false, tx: tx)
        }
    case .storyMessage(let source):
        if let ownerStory = storyStore.fetchStoryMessage(rowId: source.storyMessageRowId, tx: tx) {
            db.touch(storyMessage: ownerStory, tx: tx)
        }
    }
}
}
}
// MARK: - Shims

extension AttachmentDownloadManagerImpl {
/// Namespace for the protocol abstractions of external dependencies.
public enum Shims {
/// Alias for the sticker-manager shim protocol.
public typealias StickerManager = _AttachmentDownloadManagerImpl_StickerManagerShim
}
/// Namespace for the concrete implementations of the shim protocols.
public enum Wrappers {
/// Alias for the sticker-manager shim implementation.
public typealias StickerManager = _AttachmentDownloadManagerImpl_StickerManagerWrapper
}
}
/// The sticker-manager operations exposed to `AttachmentDownloadManagerImpl`
/// through `Shims.StickerManager`.
public protocol _AttachmentDownloadManagerImpl_StickerManagerShim {
/// Fetches the installed sticker record for the given pack and sticker id,
/// or `nil` if there is none.
func fetchInstalledSticker(packId: Data, stickerId: UInt32, tx: DBReadTransaction) -> InstalledStickerRecord?
/// Returns the data URL for an installed sticker, or `nil` if unavailable.
/// NOTE(review): `verifyExists` presumably checks that the file is present
/// on disk; confirm against `StickerManager.stickerDataUrl`.
func stickerDataUrl(forInstalledSticker: InstalledStickerRecord, verifyExists: Bool) -> URL?
}
/// Shim implementation that forwards each call to the static `StickerManager` API.
public class _AttachmentDownloadManagerImpl_StickerManagerWrapper: _AttachmentDownloadManagerImpl_StickerManagerShim {

    public init() {}

    /// Forwards to `StickerManager.fetchInstalledSticker`, passing `tx` as its `transaction` argument.
    public func fetchInstalledSticker(packId: Data, stickerId: UInt32, tx: DBReadTransaction) -> InstalledStickerRecord? {
        StickerManager.fetchInstalledSticker(packId: packId, stickerId: stickerId, transaction: tx)
    }

    /// Forwards to `StickerManager.stickerDataUrl` with the same arguments.
    public func stickerDataUrl(forInstalledSticker sticker: InstalledStickerRecord, verifyExists: Bool) -> URL? {
        StickerManager.stickerDataUrl(forInstalledSticker: sticker, verifyExists: verifyExists)
    }
}