Path: blob/main/SignalServiceKit/Network/OWSUrlSession.swift
1 views
//
// Copyright 2020 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
/// Errors produced by `OWSURLSession` itself, as opposed to transport or HTTP errors.
public enum OWSURLSessionError: Error {
    /// The response (declared or received) exceeded the session's `maxResponseSize`.
    case responseTooLarge
}
/// Signal's wrapper around `URLSession`.
///
/// Each instance lazily creates a single `URLSession` (see `session`) whose
/// delegate is a `URLSessionDelegateBox` that forwards callbacks back to this
/// object. Per-task bookkeeping (progress stream, completion continuation,
/// accumulated response body, web socket callbacks) is kept in `taskStates`,
/// keyed by `URLSessionTask.taskIdentifier`, and removed when the task finishes.
public class OWSURLSession: OWSURLSessionProtocol {
    public typealias ProgressBlock = (_ completedByteCount: Int64, _ totalByteCount: Int64) async -> Void

    // MARK: - OWSURLSessionProtocol conformance

    public let endpoint: OWSURLSessionEndpoint

    /// When `true` (the default), responses whose status code is outside
    /// `200..<400` are thrown as errors. Snapshotted when each request starts.
    public var require2xxOr3xx: Bool {
        get {
            _require2xxOr3xx.get()
        }
        set {
            _require2xxOr3xx.set(newValue)
        }
    }

    /// When `true`, a response carrying `AppExpiry.appExpiredStatusCode` marks
    /// the current app version as expired. Defaults to `false`; snapshotted
    /// when each request starts.
    public var shouldHandleRemoteDeprecation: Bool {
        get {
            _shouldHandleRemoteDeprecation.get()
        }
        set {
            _shouldHandleRemoteDeprecation.set(newValue)
        }
    }

    /// Whether HTTP redirects are followed (default `true`). Disabling
    /// redirects while a `customRedirectHandler` is installed is asserted against.
    public var allowRedirects: Bool {
        get {
            _allowRedirects.get()
        }
        set {
            owsAssertDebug(customRedirectHandler == nil || newValue)
            _allowRedirects.set(newValue)
        }
    }

    /// Optional hook that rewrites each redirect request; returning `nil`
    /// declines the redirect. Only consulted while `allowRedirects` is `true`.
    public var customRedirectHandler: ((URLRequest) -> URLRequest?)? {
        get {
            _customRedirectHandler.get()
        }
        set {
            owsAssertDebug(newValue == nil || allowRedirects)
            _customRedirectHandler.set(newValue)
        }
    }

    /// Invoked with the underlying error whenever a tracked task fails.
    private let onFailureCallback: ((any Error) -> Void)?

    // Note: not all protocol methods can be made visible to objc, but those
    // that can be are declared so here. Objc callers must use this implementation
    // directly and not touch the protocol.

    public static let defaultSecurityPolicy = HttpSecurityPolicy.systemDefault
    public static let signalServiceSecurityPolicy = HttpSecurityPolicy.signalCaPinned

    public static var defaultConfigurationWithCaching: URLSessionConfiguration {
        .ephemeral
    }

    public static var defaultConfigurationWithoutCaching: URLSessionConfiguration {
        let configuration = URLSessionConfiguration.ephemeral
        configuration.urlCache = nil
        configuration.requestCachePolicy = .reloadIgnoringLocalCacheData
        return configuration
    }

    // MARK: Default Headers

    public static var userAgentHeaderKey: String { HttpHeaders.userAgentHeaderKey }
    public static var userAgentHeaderValueSignalIos: String { HttpHeaders.userAgentHeaderValueSignalIos }
    public static var acceptLanguageHeaderKey: String { HttpHeaders.acceptLanguageHeaderKey }
    public static var acceptLanguageHeaderValue: String { HttpHeaders.acceptLanguageHeaderValue }

    // MARK: Initializers

    /// - Parameters:
    ///   - endpoint: Base URL, fronting info, security policy, and extra headers.
    ///   - configuration: Session configuration; mutated to install the Signal
    ///     proxy when `canUseSignalProxy` is `true`.
    ///   - maxResponseSize: Optional cap on response size; larger responses
    ///     fail with `OWSURLSessionError.responseTooLarge`.
    ///   - canUseSignalProxy: Whether to route through `SignalProxy`.
    ///   - onFailureCallback: Called with the error whenever a tracked task fails.
    public required init(
        endpoint: OWSURLSessionEndpoint,
        configuration: URLSessionConfiguration,
        maxResponseSize: UInt64?,
        canUseSignalProxy: Bool,
        onFailureCallback: ((any Error) -> Void)?,
    ) {
        if canUseSignalProxy {
            configuration.connectionProxyDictionary = SignalProxy.connectionProxyDictionary
        }
        self.endpoint = endpoint
        self.configuration = configuration
        self.maxResponseSize = maxResponseSize
        self.canUseSignalProxy = canUseSignalProxy
        self.onFailureCallback = onFailureCallback
        // Ensure this is set so that we don't try to create it in deinit().
        _ = self.delegateBox
    }

    public convenience init(
        securityPolicy: HttpSecurityPolicy,
        configuration: URLSessionConfiguration,
    ) {
        self.init(
            endpoint: OWSURLSessionEndpoint(
                baseUrl: nil,
                frontingInfo: nil,
                securityPolicy: securityPolicy,
                extraHeaders: [:],
            ),
            configuration: configuration,
            maxResponseSize: nil,
            canUseSignalProxy: false,
        )
    }

    public convenience init(
        baseUrl: URL? = nil,
        securityPolicy: HttpSecurityPolicy,
        configuration: URLSessionConfiguration,
        extraHeaders: HttpHeaders = HttpHeaders(),
        maxResponseSize: UInt64? = nil,
        canUseSignalProxy: Bool = false,
    ) {
        self.init(
            endpoint: OWSURLSessionEndpoint(
                baseUrl: baseUrl,
                frontingInfo: nil,
                securityPolicy: securityPolicy,
                extraHeaders: extraHeaders,
            ),
            configuration: configuration,
            maxResponseSize: maxResponseSize,
            canUseSignalProxy: canUseSignalProxy,
        )
    }

    // MARK: Tasks

    /// Uploads `requestData` as the body of `request`. Fails if the app has expired.
    public func performUpload(
        request: URLRequest,
        requestData: Data,
        progressBlock: ProgressBlock,
    ) async throws -> HTTPResponse {
        return try await performUpload(
            request: request,
            ignoreAppExpiry: false,
            progressBlock: progressBlock,
            // Build the task from the *prepared* request we are handed, not the raw one.
            taskBlock: { preparedRequest in self.session.uploadTask(with: preparedRequest, from: requestData) },
        )
    }

    /// Uploads the contents of `fileUrl` as the body of `request`.
    public func performUpload(
        request: URLRequest,
        fileUrl: URL,
        ignoreAppExpiry: Bool,
        progressBlock: ProgressBlock,
    ) async throws -> HTTPResponse {
        return try await performUpload(
            request: request,
            ignoreAppExpiry: ignoreAppExpiry,
            progressBlock: progressBlock,
            // Build the task from the *prepared* request we are handed, not the raw one.
            taskBlock: { preparedRequest in self.session.uploadTask(with: preparedRequest, fromFile: fileUrl) },
        )
    }

    /// Performs `request` as a plain data task and returns the buffered response.
    ///
    /// - Throws: `AppExpiredError` (unless `ignoreAppExpiry`), `OWSAssertionError`
    ///   if the request has no URL, or whatever the task / response handling throws.
    public func performRequest(request: URLRequest, ignoreAppExpiry: Bool) async throws -> HTTPResponse {
        if !ignoreAppExpiry, DependenciesBridge.shared.appExpiry.isExpired(now: Date()) {
            throw AppExpiredError()
        }
        let request = prepareRequest(request: request)
        // Throw rather than crash on a URL-less request, like performDownload does.
        guard let requestUrl = request.url else {
            throw OWSAssertionError("Request missing url.")
        }
        let requestConfig = self.requestConfig(requestUrl: requestUrl)
        let task = session.dataTask(with: request)
        let (urlResponse, responseData) = try await runTask(
            task,
            taskState: { DataTaskState(progress: $0, completion: $1) },
            progressBlock: { _, _ in },
        )
        return try await handleDataResult(
            urlResponse: urlResponse,
            responseData: responseData,
            originalRequest: task.originalRequest,
            requestConfig: requestConfig,
        )
    }

    /// Downloads the body of `request` to a temporary file.
    public func performDownload(
        request: URLRequest,
        progressBlock: ProgressBlock,
    ) async throws -> OWSUrlDownloadResponse {
        let request = prepareRequest(request: request)
        guard let requestUrl = request.url else {
            throw OWSAssertionError("Request missing url.")
        }
        return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
            // Don't use a completion block or the delegate will be ignored for download tasks.
            return self.session.downloadTask(with: request)
        }
    }

    /// Resumes a previously interrupted download from `resumeData`.
    public func performDownload(
        requestUrl: URL,
        resumeData: Data,
        progressBlock: ProgressBlock,
    ) async throws -> OWSUrlDownloadResponse {
        return try await performDownload(requestUrl: requestUrl, progressBlock: progressBlock) {
            // Don't use a completion block or the delegate will be ignored for download tasks.
            return self.session.downloadTask(withResumeData: resumeData)
        }
    }

    /// Creates (but does not resume) a web socket task whose open/close events
    /// are reported through `didOpenBlock` / `didCloseBlock`.
    public func webSocketTask(requestUrl: URL, didOpenBlock: @escaping (String?) -> Void, didCloseBlock: @escaping (Error) -> Void) -> URLSessionWebSocketTask {
        // We can't pass a URLRequest here since it prevents the proxy from
        // operating correctly. See `SSKWebSocketNative.init(...)` for more details
        // and an example of passing URLRequest options via this web socket.
        let task = session.webSocketTask(with: requestUrl)
        addTask(task, taskState: WebSocketTaskState(openBlock: didOpenBlock, closeBlock: didCloseBlock))
        return task
    }

    // MARK: - Internal Implementation

    private static let operationQueue: OperationQueue = {
        let queue = OperationQueue()
        queue.underlyingQueue = .global()
        return queue
    }()

    // MARK: Backing Vars

    private let _require2xxOr3xx = AtomicBool(true, lock: .sharedGlobal)
    private let _shouldHandleRemoteDeprecation = AtomicBool(false, lock: .sharedGlobal)
    private let _allowRedirects = AtomicBool(true, lock: .sharedGlobal)
    private let _customRedirectHandler = AtomicOptional<(URLRequest) -> URLRequest?>(nil, lock: .sharedGlobal)

    // MARK: Internal vars

    private let configuration: URLSessionConfiguration

    /// Backing storage for `session`; nil until the first task is created.
    private let _session = AtomicValue<URLSession?>(nil, lock: .init())

    /// The underlying session, created on first access and cached in `_session`.
    private var session: URLSession {
        return _session.map {
            return $0 ?? URLSession(configuration: configuration, delegate: delegateBox, delegateQueue: Self.operationQueue)
        }!
    }

    private let maxResponseSize: UInt64?
    private let canUseSignalProxy: Bool

    // MARK: Deinit

    deinit {
        // From NSURLSession.h
        // If you do not invalidate the session by calling the invalidateAndCancel() or
        // finishTasksAndInvalidate() method, your app leaks memory until it exits
        //
        // Even though there will be no reference cycle, underlying NSURLSession metadata
        // is malloced and kept around as a root leak.
        //
        // Only invalidate a session that was actually created: reading `session`
        // here would construct a brand-new URLSession just to tear it down.
        _session.get()?.invalidateAndCancel()
    }

    // MARK: Configuration

    /// Per-request snapshot of the session's mutable settings.
    private struct RequestConfig {
        let requestUrl: URL
        let require2xxOr3xx: Bool
        let shouldHandleRemoteDeprecation: Bool
    }

    private func requestConfig(requestUrl: URL) -> RequestConfig {
        // Snapshot session state at time request is made.
        return RequestConfig(
            requestUrl: requestUrl,
            require2xxOr3xx: require2xxOr3xx,
            shouldHandleRemoteDeprecation: shouldHandleRemoteDeprecation,
        )
    }

    private func handleDataResult(urlResponse: URLResponse?, responseData: Data, originalRequest: URLRequest?, requestConfig: RequestConfig) async throws -> HTTPResponse {
        let httpUrlResponse = try await handleResult(urlResponse: urlResponse, responseData: responseData, originalRequest: originalRequest, requestConfig: requestConfig)
        return HTTPResponse(requestUrl: requestConfig.requestUrl, httpUrlResponse: httpUrlResponse, bodyData: responseData)
    }

    private func handleDownloadResult(urlResponse: URLResponse?, downloadUrl: URL, originalRequest: URLRequest?, requestConfig: RequestConfig) async throws -> OWSUrlDownloadResponse {
        let httpUrlResponse = try await handleResult(urlResponse: urlResponse, responseData: nil, originalRequest: originalRequest, requestConfig: requestConfig)
        return OWSUrlDownloadResponse(httpUrlResponse: httpUrlResponse, downloadUrl: downloadUrl)
    }

    /// Maps a transport-level error to `OWSHTTPError`, classifying network
    /// failures and timeouts as `.networkFailure`.
    private func handleError(_ error: any Error, originalRequest: URLRequest?, requestConfig: RequestConfig) -> OWSHTTPError {
        if error.isNetworkFailureOrTimeout {
            return .networkFailure(.wrappedFailure(error))
        }
        #if TESTABLE_BUILD
        if let originalRequest {
            HTTPUtils.logCurl(for: originalRequest)
        }
        #endif
        return .wrappedFailure(error)
    }

    /// Validates a completed response: optionally handles remote deprecation,
    /// requires an `HTTPURLResponse`, and (if configured) a 2xx/3xx status.
    private func handleResult(urlResponse: URLResponse?, responseData: Data?, originalRequest: URLRequest?, requestConfig: RequestConfig) async throws -> HTTPURLResponse {
        if requestConfig.shouldHandleRemoteDeprecation {
            await handleRemoteDeprecation(inResponse: urlResponse)
        }
        guard let httpUrlResponse = urlResponse as? HTTPURLResponse else {
            throw OWSAssertionError("Invalid response: \(type(of: urlResponse)).")
        }
        if requestConfig.require2xxOr3xx {
            let statusCode = httpUrlResponse.statusCode
            guard statusCode >= 200, statusCode < 400 else {
                #if TESTABLE_BUILD
                if let originalRequest {
                    HTTPUtils.logCurl(for: originalRequest)
                }
                #endif
                if statusCode > 0 {
                    let requestUrl = requestConfig.requestUrl
                    let responseHeaders = HttpHeaders(response: httpUrlResponse)
                    throw OWSHTTPError.serviceResponse(.init(
                        requestUrl: requestUrl,
                        responseStatus: statusCode,
                        responseHeaders: responseHeaders,
                        responseData: responseData,
                    ))
                } else {
                    owsFailDebug("Missing status code.")
                    throw OWSHTTPError.networkFailure(.invalidResponseStatus)
                }
            }
        }
        return httpUrlResponse
    }

    private func handleRemoteDeprecation(inResponse response: URLResponse?) async {
        guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == AppExpiry.appExpiredStatusCode else {
            return
        }
        let appExpiry = DependenciesBridge.shared.appExpiry
        let db = DependenciesBridge.shared.db
        await appExpiry.setHasAppExpiredAtCurrentVersion(db: db)
    }

    /// `true` when either the bytes already received or the declared expected
    /// size exceed `maxResponseSize`. Always `false` when there is no limit.
    private func isResponseTooLarge(bytesReceived: Int64, bytesExpected: Int64) -> Bool {
        if let maxResponseSize {
            if bytesReceived > maxResponseSize {
                return true
            }
            if bytesExpected != NSURLSessionTransferSizeUnknown, bytesExpected > maxResponseSize {
                return true
            }
        }
        return false
    }

    // MARK: Request building

    // Ensure certain invariants for all requests.
    private func prepareRequest(request: URLRequest) -> URLRequest {
        var request = request
        request.httpShouldHandleCookies = false
        request = HttpHeaders.fillInMissingDefaultHeaders(request: request)
        // Only requests to Signal services require CC.
        // If frontingHost is nil, this instance of OWSURLSession does not perform CC.
        if let frontingInfo = endpoint.frontingInfo, let urlString = request.url?.absoluteString.nilIfEmpty {
            owsAssertDebug(frontingInfo.isFrontedUrl(urlString), "Unfronted URL: \(urlString)")
        }
        return request
    }

    // MARK: - Issuing Requests

    /// Sends a `TSRequest`: applies auth, serializes the body (JSON-encoding
    /// parameters), and performs it as an upload while holding a background task.
    public func performRequest(_ rawRequest: TSRequest) async throws -> HTTPResponse {
        let appExpiry = DependenciesBridge.shared.appExpiry
        guard !appExpiry.isExpired(now: Date()) else {
            throw AppExpiredError()
        }

        var httpHeaders = rawRequest.headers
        try rawRequest.applyAuth(to: &httpHeaders, socketAuth: nil)

        let method: HTTPMethod
        do {
            method = try HTTPMethod.method(for: rawRequest.method)
        } catch {
            owsFailDebug("Invalid HTTP method: \(rawRequest.method)", logger: rawRequest.logger)
            throw OWSHTTPError.invalidRequest
        }

        let requestBody: Data
        switch rawRequest.body {
        case .data(let bodyData):
            requestBody = bodyData
        case .parameters(let bodyParameters) where !bodyParameters.isEmpty:
            do {
                requestBody = try TSRequest.Body.encodedParameters(bodyParameters)
            } catch {
                owsFailDebug("Could not serialize JSON parameters: \(error).", logger: rawRequest.logger)
                throw OWSHTTPError.invalidRequest
            }
            // If we're going to use the json serialized parameters as our body, we should overwrite
            // the Content-Type on the request.
            httpHeaders.addHeader("Content-Type", value: "application/json", overwriteOnConflict: true)
        case .parameters:
            requestBody = Data()
        }

        var request: URLRequest
        do {
            request = try self.endpoint.buildRequest(
                rawRequest.url.absoluteString,
                method: method,
                headers: httpHeaders,
            )
        } catch {
            owsFailDebug("Missing or invalid request: \(rawRequest.url).", logger: rawRequest.logger)
            throw OWSHTTPError.invalidRequest
        }

        let backgroundTask = OWSBackgroundTask(label: "\(#function)")
        defer {
            backgroundTask.end()
        }

        request.timeoutInterval = rawRequest.timeoutInterval

        do {
            rawRequest.logger.info("Sending… -> \(rawRequest)")
            let response = try await performUpload(request: request, requestData: requestBody, progressBlock: { _, _ in })
            rawRequest.logger.info("HTTP \(response.responseStatusCode) <- \(rawRequest)")
            return response
        } catch {
            if let statusCode = error.httpStatusCode {
                rawRequest.logger.warn("HTTP \(statusCode) <- \(rawRequest)")
            } else {
                rawRequest.logger.warn("Failure. <- \(rawRequest): \(error)")
            }
            throw error
        }
    }

    /// Shared upload path.
    ///
    /// - Parameter taskBlock: Builds the upload task from the *prepared*
    ///   request, so `prepareRequest(request:)`'s invariants (no cookie
    ///   handling, default headers) apply to uploads as well.
    /// - Throws: `AppExpiredError`, `OWSAssertionError` for a URL-less request,
    ///   an `OWSHTTPError` wrapping transport failures, or response-handling errors.
    private func performUpload(
        request: URLRequest,
        ignoreAppExpiry: Bool,
        progressBlock: ProgressBlock,
        taskBlock: (URLRequest) -> URLSessionUploadTask,
    ) async throws -> HTTPResponse {
        if !ignoreAppExpiry, DependenciesBridge.shared.appExpiry.isExpired(now: Date()) {
            throw AppExpiredError()
        }
        let request = prepareRequest(request: request)
        guard let requestUrl = request.url else {
            throw OWSAssertionError("Request missing url.")
        }
        let requestConfig = self.requestConfig(requestUrl: requestUrl)
        let task = taskBlock(request)
        let urlResponse: URLResponse?
        let responseData: Data
        do {
            (urlResponse, responseData) = try await runTask(
                task,
                taskState: { DataTaskState(progress: $0, completion: $1) },
                progressBlock: progressBlock,
            )
        } catch {
            throw handleError(error, originalRequest: task.originalRequest, requestConfig: requestConfig)
        }
        return try await handleDataResult(
            urlResponse: urlResponse,
            responseData: responseData,
            originalRequest: task.originalRequest,
            requestConfig: requestConfig,
        )
    }

    /// Shared download path; `taskBlock` creates the download task.
    private func performDownload(
        requestUrl: URL,
        progressBlock: ProgressBlock,
        taskBlock: () -> URLSessionDownloadTask,
    ) async throws -> OWSUrlDownloadResponse {
        let appExpiry = DependenciesBridge.shared.appExpiry
        if appExpiry.isExpired(now: Date()) {
            throw AppExpiredError()
        }
        let requestConfig = self.requestConfig(requestUrl: requestUrl)
        let task = taskBlock()
        let (urlResponse, downloadUrl) = try await runTask(
            task,
            taskState: { DownloadTaskState(progress: $0, completion: $1) },
            progressBlock: progressBlock,
        )
        return try await handleDownloadResult(
            urlResponse: urlResponse,
            downloadUrl: downloadUrl,
            originalRequest: task.originalRequest,
            requestConfig: requestConfig,
        )
    }

    /// Registers `task`, resumes it, forwards progress to `progressBlock`
    /// until the progress stream finishes, then awaits the completion value.
    /// Swift task cancellation cancels the underlying URL task.
    private func runTask<T>(
        _ task: URLSessionTask,
        taskState: (TaskState.ProgressContinuation, DeferredContinuation<T>) -> some TaskState,
        progressBlock: ProgressBlock,
    ) async throws -> T {
        // It's possible for operation and onCancel to race one another, so we use
        // a counter to ensure that cancellation happens after addTask is invoked.
        // (You can trigger this by sending a request from a canceled Task.)
        let cancelState = AtomicUInt(lock: .init())
        return try await withTaskCancellationHandler(
            operation: {
                let completion = DeferredContinuation<T>()
                let progressStream = AsyncStream(bufferingPolicy: .bufferingNewest(1)) { continuation in
                    self.addTask(task, taskState: taskState(continuation, completion))
                    // If cancel was already called, cancel it now.
                    if cancelState.increment() == 2 {
                        task.cancel()
                    } else {
                        task.resume()
                    }
                }
                for await progressUpdate in progressStream {
                    await progressBlock(progressUpdate.completedByteCount, progressUpdate.totalByteCount)
                }
                return try await completion.wait()
            },
            onCancel: {
                // If the task was already added, cancel it now.
                if cancelState.increment() == 2 {
                    task.cancel()
                }
            },
        )
    }

    // MARK: - TaskState

    private let taskStates = AtomicValue([TaskIdentifier: TaskState](), lock: .init())

    private lazy var delegateBox = URLSessionDelegateBox(delegate: self)

    typealias TaskIdentifier = Int

    /// Mutates `taskStates` under its lock; the delegate box retains this
    /// session exactly while any task state is outstanding.
    private func updateTaskStates<T>(block: (inout [TaskIdentifier: TaskState]) throws -> T) rethrows -> T {
        return try self.taskStates.update {
            let result = try block(&$0)
            delegateBox.isRetaining = !$0.isEmpty
            return result
        }
    }

    private func addTask(_ task: URLSessionTask, taskState: TaskState) {
        updateTaskStates {
            owsAssertDebug($0[task.taskIdentifier] == nil)
            $0[task.taskIdentifier] = taskState
        }
    }

    private func progress(forTask task: URLSessionTask) -> TaskState.ProgressContinuation? {
        return updateTaskStates {
            return $0[task.taskIdentifier]?.progress
        }
    }

    private func dataTaskState(forTask task: URLSessionTask) -> DataTaskState? {
        return updateTaskStates {
            return $0[task.taskIdentifier] as? DataTaskState
        }
    }

    private func webSocketState(forTask task: URLSessionTask) -> WebSocketTaskState? {
        return updateTaskStates {
            return $0[task.taskIdentifier] as? WebSocketTaskState
        }
    }

    /// Removes and returns the state for `task`, or nil if it was already removed.
    private func removeCompletedTaskState(_ task: URLSessionTask) -> TaskState? {
        return updateTaskStates {
            guard let taskState = $0[task.taskIdentifier] else {
                // This isn't necessarily an error or bug.
                // A task might "succeed" after it "fails" in certain edge cases,
                // although we make a best effort to avoid them.
                Logger.warn("Missing TaskState.")
                return nil
            }
            $0[task.taskIdentifier] = nil
            return taskState
        }
    }

    private func downloadTaskDidSucceed(_ task: URLSessionTask, downloadUrl: URL) {
        guard let taskState = removeCompletedTaskState(task) as? DownloadTaskState else {
            owsFailDebug("Missing TaskState.")
            return
        }
        taskState.progress?.finish()
        taskState.completion.resume(with: .success((task.response, downloadUrl)))
    }

    private func dataTaskDidSucceed(_ task: URLSessionTask) {
        guard let taskState = removeCompletedTaskState(task) as? DataTaskState else {
            owsFailDebug("Missing TaskState.")
            return
        }
        let responseData = taskState.pendingData.get()
        taskState.progress?.finish()
        taskState.completion.resume(with: .success((task.response, responseData)))
    }

    /// Fails the task's state with `error`, cancels the task, and notifies
    /// `onFailureCallback`. A no-op (besides a warning) if the state is gone.
    private func taskDidFail(_ task: URLSessionTask, error: Error) {
        guard let taskState = removeCompletedTaskState(task) else {
            Logger.warn("Missing TaskState.")
            return
        }
        taskState.reject(error: error, task: task)
        task.cancel()
        onFailureCallback?(error)
    }

    // MARK: -

    public typealias URLAuthenticationChallengeCompletion = (URLSession.AuthChallengeDisposition, URLCredential?) -> Void

    /// Evaluates server-trust challenges against `endpoint.securityPolicy`;
    /// every other challenge kind gets default handling.
    fileprivate func urlSession(
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping URLAuthenticationChallengeCompletion,
    ) {
        let protectionSpace = challenge.protectionSpace
        guard
            protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
            let serverTrust = protectionSpace.serverTrust
        else {
            completionHandler(.performDefaultHandling, nil)
            return
        }
        if endpoint.securityPolicy.evaluate(serverTrust: serverTrust, domain: protectionSpace.host) {
            completionHandler(.useCredential, URLCredential(trust: serverTrust))
        } else {
            completionHandler(.cancelAuthenticationChallenge, nil)
        }
    }
}
// MARK: - Forwarded Delegate Methods
extension OWSURLSession {
    /// Routes task completion. Any error fails the task; a successful data
    /// (or upload) task delivers its buffered body. Other successful task
    /// kinds are not handled here.
    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        guard let error else {
            if let dataTask = task as? URLSessionDataTask {
                dataTaskDidSucceed(dataTask)
            }
            return
        }
        taskDidFail(task, error: error)
    }

    /// Session-level authentication challenge; shares the task-level logic.
    func urlSession(
        _ session: URLSession,
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping URLAuthenticationChallengeCompletion,
    ) {
        urlSession(didReceive: challenge, completionHandler: completionHandler)
    }

    /// Follows a redirect only when `allowRedirects` is set, optionally
    /// passing the new request through `customRedirectHandler` first.
    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        willPerformHTTPRedirection response: HTTPURLResponse,
        newRequest: URLRequest,
        completionHandler: @escaping (URLRequest?) -> Void,
    ) {
        guard allowRedirects else {
            completionHandler(nil)
            return
        }
        guard let redirectHandler = customRedirectHandler else {
            completionHandler(newRequest)
            return
        }
        completionHandler(redirectHandler(newRequest))
    }

    /// Task-level authentication challenge; shares the session-level logic.
    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping URLAuthenticationChallengeCompletion,
    ) {
        urlSession(didReceive: challenge, completionHandler: completionHandler)
    }

    /// Publishes upload progress to the task's progress stream, if any.
    func urlSession(_ session: URLSession, task: URLSessionTask, didSendBodyData bytesSent: Int64, totalBytesSent: Int64, totalBytesExpectedToSend: Int64) {
        // TODO: We could check for NSURLSessionTransferSizeUnknown here.
        let progressContinuation = progress(forTask: task)
        progressContinuation?.yield((totalBytesSent, totalBytesExpectedToSend))
    }

    /// Enforces `maxResponseSize` on the finished file, then moves it to a
    /// temporary location of our own and completes the download.
    func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        if let sizeLimit = maxResponseSize {
            let fileSize: UInt64
            do {
                fileSize = try OWSFileSystem.fileSize(of: location)
            } catch {
                taskDidFail(downloadTask, error: error)
                return
            }
            if fileSize > sizeLimit {
                taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
                return
            }
        }

        // Download locations are cleaned up quickly, so we
        // need to move the file synchronously.
        let destinationUrl = OWSFileSystem.temporaryFileUrl(fileExtension: nil, isAvailableWhileDeviceLocked: true)
        do {
            try OWSFileSystem.moveFile(from: location, to: destinationUrl)
        } catch {
            owsFailDebugUnlessNetworkFailure(error)
            taskDidFail(downloadTask, error: error)
            return
        }
        downloadTaskDidSucceed(downloadTask, downloadUrl: destinationUrl)
    }

    /// Fails an oversized download early; otherwise publishes progress.
    func urlSession(
        _ session: URLSession,
        downloadTask: URLSessionDownloadTask,
        didWriteData bytesWritten: Int64,
        totalBytesWritten: Int64,
        totalBytesExpectedToWrite: Int64,
    ) {
        guard !isResponseTooLarge(bytesReceived: totalBytesWritten, bytesExpected: totalBytesExpectedToWrite) else {
            taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
            return
        }
        let progressContinuation = progress(forTask: downloadTask)
        progressContinuation?.yield((totalBytesWritten, totalBytesExpectedToWrite))
    }

    /// Same size check as `didWriteData`, applied to a resumed download.
    func urlSession(
        _ session: URLSession,
        downloadTask: URLSessionDownloadTask,
        didResumeAtOffset fileOffset: Int64,
        expectedTotalBytes: Int64,
    ) {
        guard !isResponseTooLarge(bytesReceived: fileOffset, bytesExpected: expectedTotalBytes) else {
            taskDidFail(downloadTask, error: OWSURLSessionError.responseTooLarge)
            return
        }
        let progressContinuation = progress(forTask: downloadTask)
        progressContinuation?.yield((fileOffset, expectedTotalBytes))
    }

    /// Cancels a data task up front when its declared length exceeds the limit.
    func urlSession(
        _ session: URLSession,
        dataTask: URLSessionDataTask,
        didReceive response: URLResponse,
        completionHandler: @escaping (URLSession.ResponseDisposition) -> Void,
    ) {
        let disposition: URLSession.ResponseDisposition
        if isResponseTooLarge(bytesReceived: 0, bytesExpected: response.expectedContentLength) {
            taskDidFail(dataTask, error: OWSURLSessionError.responseTooLarge)
            disposition = .cancel
        } else {
            disposition = .allow
        }
        completionHandler(disposition)
    }

    /// Appends a body chunk to the task's buffer, failing the task once the
    /// received byte count exceeds the limit.
    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        guard !isResponseTooLarge(bytesReceived: dataTask.countOfBytesReceived, bytesExpected: dataTask.countOfBytesExpectedToReceive) else {
            taskDidFail(dataTask, error: OWSURLSessionError.responseTooLarge)
            return
        }
        dataTaskState(forTask: dataTask)?.pendingData.update { $0.append(data) }
    }

    /// Reports a successful web socket handshake to the registered open block.
    func urlSession(_ session: URLSession, webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol: String?) {
        let openedState = webSocketState(forTask: webSocketTask)
        openedState?.openBlock(didOpenWithProtocol)
    }

    /// Removes the web socket's state and reports the close code and reason.
    func urlSession(_ session: URLSession, webSocketTask: URLSessionWebSocketTask, didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) {
        guard let closedState = removeCompletedTaskState(webSocketTask) as? WebSocketTaskState else {
            return
        }
        let closeError = WebSocketError.closeError(statusCode: closeCode.rawValue, closeReason: reason)
        closedState.closeBlock(closeError)
    }
}
// MARK: - TaskState
/// Bookkeeping that `OWSURLSession` keeps for each in-flight task.
private protocol TaskState {
    /// Fails whatever is waiting on this task. `task` is the task that
    /// failed, so conformers may inspect e.g. its response.
    func reject(error: any Error, task: URLSessionTask)

    /// Continuation for the task's progress stream, if it publishes progress.
    var progress: ProgressContinuation? { get }

    /// Progress updates are `(completedByteCount, totalByteCount)` pairs.
    typealias ProgressContinuation = AsyncStream<(completedByteCount: Int64, totalByteCount: Int64)>.Continuation
}
// MARK: - DownloadTaskState
/// State for an in-flight download task: where progress goes, and the
/// continuation that eventually receives `(response, downloaded file URL)`.
private class DownloadTaskState: TaskState {
    let completion: DeferredContinuation<(URLResponse?, URL)>
    let progress: ProgressContinuation?

    init(progress: ProgressContinuation, completion: DeferredContinuation<(URLResponse?, URL)>) {
        self.completion = completion
        self.progress = progress
    }

    /// Ends the progress stream, then fails the awaiting caller with `error`.
    func reject(error: any Error, task: URLSessionTask) {
        progress?.finish()
        completion.resume(with: .failure(error))
    }
}
// MARK: - DataTaskState (& UploadTaskState)
/// State for an in-flight data task. Also used for upload tasks, since
/// `URLSessionUploadTask` is a subclass of `URLSessionDataTask`.
///
/// Response body chunks are accumulated in `pendingData` as they arrive.
private class DataTaskState: TaskState {
    let pendingData = AtomicValue<Data>(Data(), lock: .init())
    let completion: DeferredContinuation<(URLResponse?, Data)>
    let progress: ProgressContinuation?

    init(progress: ProgressContinuation?, completion: DeferredContinuation<(URLResponse?, Data)>) {
        self.completion = completion
        self.progress = progress
    }

    /// Ends the progress stream, then fails the awaiting caller with `error`.
    func reject(error: any Error, task: URLSessionTask) {
        progress?.finish()
        completion.resume(with: .failure(error))
    }
}
// MARK: - WebSocketTaskState
/// State for a web socket task: callbacks for the open and close events.
/// Web sockets never publish progress.
private class WebSocketTaskState: TaskState {
    typealias OpenBlock = (String?) -> Void
    typealias CloseBlock = (Error) -> Void

    let openBlock: OpenBlock
    let closeBlock: CloseBlock

    var progress: ProgressContinuation? {
        return nil
    }

    init(openBlock: @escaping OpenBlock, closeBlock: @escaping CloseBlock) {
        self.openBlock = openBlock
        self.closeBlock = closeBlock
    }

    /// Reports the failure through `closeBlock`, translating a failed
    /// handshake into `WebSocketError.httpError`.
    func reject(error: any Error, task: URLSessionTask) {
        // We only want to return HTTP errors during the initial web socket
        // upgrade. Once we've switched protocols, the HTTP response is no longer
        // relevant but the property remains defined on the task. We use
        // `badServerResponse` to distinguish errors during the initial handshake
        // from other unexpected errors that occur later (eg losing internet).
        let reportedError: Error
        if case URLError.badServerResponse = error, let handshakeResponse = task.response as? HTTPURLResponse {
            let retryAfter = HttpHeaders(response: handshakeResponse).retryAfterDate
            reportedError = WebSocketError.httpError(statusCode: handshakeResponse.statusCode, retryAfter: retryAfter)
        } else {
            reportedError = error
        }
        closeBlock(reportedError)
    }
}
// URLSession maintains a strong reference to its delegate until it is explicitly invalidated.
// OWSURLSession acts (through this box) as the delegate of its own URLSession, and may be
// retained by any number of owners. We don't really know when to invalidate our session,
// because a caller may decide to reuse a session at any time.
//
// So here's the plan:
// - While we have any outstanding tasks, a strong reference cycle is maintained. Callers
//   awaiting a request don't need to hold on to the session until the request completes.
// i.e. OWSURLSession --(session)--> URLSession --(delegate)--> URLSessionDelegateBox
// ^-----------------------(strongReference)-------------------|
//
// - Once all outstanding tasks have been resolved, the box breaks its reference. If there are no
// external references to the OWSURLSession, then everything cleans itself up.
// i.e. OWSURLSession --(session)--> URLSession --(delegate)--> URLSessionDelegateBox
// x-----(weakDelegate)-----|
//
private class URLSessionDelegateBox: NSObject {
    /// Always points at the owning session while it is alive.
    private weak var weakDelegate: OWSURLSession?
    /// Non-nil only while `isRetaining` is `true`.
    private var strongReference: OWSURLSession?

    init(delegate: OWSURLSession) {
        weakDelegate = delegate
    }

    /// Setting `true` promotes the weak delegate reference to a strong one
    /// (creating the cycle); setting `false` drops the strong reference.
    var isRetaining: Bool {
        get {
            return strongReference != nil
        }
        set {
            if newValue {
                strongReference = weakDelegate
            } else {
                strongReference = nil
            }
        }
    }
}
// MARK: -
extension URLSessionDelegateBox: URLSessionDelegate, URLSessionTaskDelegate, URLSessionDownloadDelegate, URLSessionDataDelegate {
    // Every delegate callback below is forwarded explicitly to `weakDelegate`.
    // If all goes according to plan, weakDelegate will only go nil once everything is being dealloced.
    // Callbacks that hand us a completion handler must still invoke it in that case,
    // because URLSession expects every delegate completion handler to be called; otherwise
    // the task may be left waiting on a decision that never comes.

    // `didFinishDownloadingTo` is the only non-optional method among these conformances.
    func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        weakDelegate?.urlSession(session, downloadTask: downloadTask, didFinishDownloadingTo: location)
    }

    func urlSession(
        _ session: URLSession,
        downloadTask: URLSessionDownloadTask,
        didWriteData bytesWritten: Int64,
        totalBytesWritten: Int64,
        totalBytesExpectedToWrite: Int64,
    ) {
        weakDelegate?.urlSession(
            session,
            downloadTask: downloadTask,
            didWriteData: bytesWritten,
            totalBytesWritten: totalBytesWritten,
            totalBytesExpectedToWrite: totalBytesExpectedToWrite,
        )
    }

    func urlSession(
        _ session: URLSession,
        downloadTask: URLSessionDownloadTask,
        didResumeAtOffset fileOffset: Int64,
        expectedTotalBytes: Int64,
    ) {
        weakDelegate?.urlSession(
            session,
            downloadTask: downloadTask,
            didResumeAtOffset: fileOffset,
            expectedTotalBytes: expectedTotalBytes,
        )
    }

    public typealias URLAuthenticationChallengeCompletion = OWSURLSession.URLAuthenticationChallengeCompletion

    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping URLAuthenticationChallengeCompletion,
    ) {
        guard let delegate = weakDelegate else {
            // Cancel rather than fall back to default handling, which would
            // skip the owning session's security policy evaluation.
            completionHandler(.cancelAuthenticationChallenge, nil)
            return
        }
        delegate.urlSession(
            session,
            task: task,
            didReceive: challenge,
            completionHandler: completionHandler,
        )
    }

    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        didSendBodyData bytesSent: Int64,
        totalBytesSent: Int64,
        totalBytesExpectedToSend: Int64,
    ) {
        weakDelegate?.urlSession(
            session,
            task: task,
            didSendBodyData: bytesSent,
            totalBytesSent: totalBytesSent,
            totalBytesExpectedToSend: totalBytesExpectedToSend,
        )
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        weakDelegate?.urlSession(
            session,
            task: task,
            didCompleteWithError: error,
        )
    }

    func urlSession(
        _ session: URLSession,
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping URLAuthenticationChallengeCompletion,
    ) {
        guard let delegate = weakDelegate else {
            // Cancel rather than fall back to default handling, which would
            // skip the owning session's security policy evaluation.
            completionHandler(.cancelAuthenticationChallenge, nil)
            return
        }
        delegate.urlSession(
            session,
            didReceive: challenge,
            completionHandler: completionHandler,
        )
    }

    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        willPerformHTTPRedirection response: HTTPURLResponse,
        newRequest: URLRequest,
        completionHandler: @escaping (URLRequest?) -> Void,
    ) {
        guard let delegate = weakDelegate else {
            // Decline the redirect, as OWSURLSession does when redirects are disabled.
            completionHandler(nil)
            return
        }
        delegate.urlSession(
            session,
            task: task,
            willPerformHTTPRedirection: response,
            newRequest: newRequest,
            completionHandler: completionHandler,
        )
    }

    func urlSession(
        _ session: URLSession,
        dataTask: URLSessionDataTask,
        didReceive response: URLResponse,
        completionHandler: @escaping (URLSession.ResponseDisposition) -> Void,
    ) {
        guard let delegate = weakDelegate else {
            completionHandler(.cancel)
            return
        }
        delegate.urlSession(
            session,
            dataTask: dataTask,
            didReceive: response,
            completionHandler: completionHandler,
        )
    }

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        weakDelegate?.urlSession(session, dataTask: dataTask, didReceive: data)
    }
}
extension URLSessionDelegateBox: URLSessionWebSocketDelegate {
    /// Forwards the handshake-completed event to the owning session, if still alive.
    public func urlSession(_ session: URLSession, webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol: String?) {
        guard let delegate = weakDelegate else {
            return
        }
        delegate.urlSession(session, webSocketTask: webSocketTask, didOpenWithProtocol: didOpenWithProtocol)
    }

    /// Forwards the close event to the owning session, if still alive.
    public func urlSession(_ session: URLSession, webSocketTask: URLSessionWebSocketTask, didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) {
        guard let delegate = weakDelegate else {
            return
        }
        delegate.urlSession(session, webSocketTask: webSocketTask, didCloseWith: closeCode, reason: reason)
    }
}
}