Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Network/SSKWebSocket.swift
1 views
//
// Copyright 2019 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation

/// Connection state reported by an `SSKWebSocket`.
///
/// The raw values are written out explicitly; they are the same values Swift
/// assigns implicitly (0, 1, 2), so the Objective-C representation is unchanged.
@objc
public enum SSKWebSocketState: UInt {
    /// The socket is currently connected.
    case open = 0
    /// The socket has not connected (yet).
    case connecting = 1
    /// The socket is no longer connected.
    case disconnected = 2
}

// MARK: -

extension SSKWebSocketState: CustomStringConvertible {
    /// Type-qualified name of the case (e.g. `"SSKWebSocketState.open"`), suitable for logs.
    public var description: String {
        // Exhaustive on purpose: adding a case must fail to compile here.
        switch self {
        case .open: "SSKWebSocketState.open"
        case .connecting: "SSKWebSocketState.connecting"
        case .disconnected: "SSKWebSocketState.disconnected"
        }
    }
}

// MARK: -

/// Abstraction over a single web socket connection.
///
/// `SSKWebSocketNative` in this file is the concrete implementation; the
/// per-member notes below describe what that implementation does.
public protocol SSKWebSocket: AnyObject {

    /// Receiver of connect / disconnect-or-fail / data callbacks.
    /// (`SSKWebSocketNative` stores this as a `weak` reference.)
    var delegate: SSKWebSocketDelegate? { get set }

    /// Identifier for this socket instance.
    /// (`SSKWebSocketNative` takes it from a shared, incrementing counter.)
    var id: UInt { get }

    /// Current connection state.
    var state: SSKWebSocketState { get }

    /// Starts the connection.
    /// (`SSKWebSocketNative` asserts in debug builds that this is called only once per socket.)
    func connect()
    /// Disconnect with a provided closure code.
    /// If no code is provided, no code will be sent (equivalent to 1005 "noStatusReceived").
    /// For a normal closure, use `URLSessionWebSocketTask.CloseCode.normalClosure`
    func disconnect(code: URLSessionWebSocketTask.CloseCode?)

    /// Sends `data` over the socket.
    /// (`SSKWebSocketNative` sends it as a binary `.data` message.)
    func write(data: Data)

    /// Sends a ping over the socket.
    func writePing()
}

// MARK: -

/// Error values for web socket failures.
public enum WebSocketError: Error {
    // From RFC 6455: https://www.rfc-editor.org/rfc/rfc6455#section-7.4.1
    /// Close status code 1000, which RFC 6455 defines as a normal closure.
    public static let normalClosure: Int = 1000

    /// Failure carrying an HTTP status code and an optional retry date.
    /// NOTE(review): `retryAfter` presumably mirrors a `Retry-After` response header — confirm where this is thrown.
    case httpError(statusCode: Int, retryAfter: Date?)
    /// Failure carrying a close status code and the optional raw close-reason bytes.
    /// NOTE(review): `statusCode` is presumably an RFC 6455 close code (compare `normalClosure`) — confirm where this is thrown.
    case closeError(statusCode: Int, closeReason: Data?)
}

// MARK: -

public extension SSKWebSocket {
    /// Replies to `request` over this socket.
    ///
    /// Builds a response proto carrying `request.requestID`, `status` and `message`, wraps it in
    /// a web socket message of type `.response`, serializes that, and passes the bytes to
    /// `write(data:)`.
    ///
    /// - Parameters:
    ///   - request: The request being answered; only its `requestID` is read.
    ///   - status: Status value to place in the response.
    ///   - message: Message string to place in the response.
    /// - Throws: Any error thrown by the proto builders' `build()` / `buildSerializedData()`.
    ///   Nothing is written to the socket in that case.
    func sendResponse(
        for request: WebSocketProtoWebSocketRequestMessage,
        status: UInt32,
        message: String,
    ) throws {
        // Inner payload: the response itself.
        let replyBuilder = WebSocketProtoWebSocketResponseMessage.builder(
            requestID: request.requestID,
            status: status,
        )
        replyBuilder.setMessage(message)
        let reply = try replyBuilder.build()

        // Outer envelope: a web socket message tagged as a response.
        let envelopeBuilder = WebSocketProtoWebSocketMessage.builder()
        envelopeBuilder.setType(.response)
        envelopeBuilder.setResponse(reply)

        try write(data: envelopeBuilder.buildSerializedData())
    }
}

// MARK: -

/// Callbacks emitted by an `SSKWebSocket`.
///
/// `SSKWebSocketNative` dispatches all of these asynchronously on the
/// `callbackScheduler` it was created with.
public protocol SSKWebSocketDelegate: AnyObject {
    /// The socket has opened.
    func websocketDidConnect(socket: SSKWebSocket)

    /// The socket closed or hit an error.
    /// (`SSKWebSocketNative` reports at most one error per socket, and none after `disconnect(code:)` has been called.)
    func websocketDidDisconnectOrFail(socket: SSKWebSocket, error: Error)

    /// A message carrying `data` arrived on the socket.
    /// (`SSKWebSocketNative` forwards binary frames only.)
    func websocket(_ socket: SSKWebSocket, didReceiveData data: Data)
}

// MARK: -

/// Describes a web socket to open: the target service, the URL path and query, and extra headers.
public struct WebSocketRequest {
    /// The Signal service associated with this request.
    public let signalService: SignalServiceType

    /// Path component of the socket URL.
    public let urlPath: String
    /// Optional query items added to the socket URL.
    public let urlQueryItems: [URLQueryItem]?

    /// Extra headers that should be sent along with the request.
    public let extraHeaders: HttpHeaders

    /// Builds the `URLRequest` for this web socket against `endpoint`.
    ///
    /// The path and query items are combined into a URL string, which `endpoint` then turns
    /// into a `GET` request with the scheme overridden to `wss` and `extraHeaders` attached.
    ///
    /// - Parameter endpoint: The endpoint used to build the final request.
    /// - Returns: The request, or `nil` if the URL string can't be formed or `endpoint` throws
    ///   while building the request (both cases are logged).
    public func build(for endpoint: OWSURLSessionEndpoint) -> URLRequest? {
        var components = URLComponents()
        components.path = urlPath
        components.queryItems = urlQueryItems

        guard let relativeUrlString = components.string else {
            owsFailBeta("Couldn't build URL for web socket: \(urlPath)")
            return nil
        }

        let builtRequest: URLRequest?
        do {
            builtRequest = try endpoint.buildRequest(
                relativeUrlString,
                overrideUrlScheme: "wss",
                method: .get,
                headers: extraHeaders,
            )
        } catch {
            Logger.warn("Couldn't build web socket request: \(error)")
            builtRequest = nil
        }
        return builtRequest
    }
}

/// Creates `SSKWebSocket` instances.
public protocol WebSocketFactory {
    /// Builds a socket for `request`.
    ///
    /// - Parameters:
    ///   - request: Describes the socket to open.
    ///   - callbackScheduler: Scheduler handed to the socket.
    ///     (`SSKWebSocketNative` dispatches its delegate callbacks on it.)
    /// - Returns: The socket, or `nil` if one couldn't be built.
    func buildSocket(request: WebSocketRequest, callbackScheduler: Scheduler) -> SSKWebSocket?
}

// MARK: -

#if TESTABLE_BUILD

/// `WebSocketFactory` that never produces a socket.
/// Only compiled when `TESTABLE_BUILD` is set.
@objc
public class WebSocketFactoryMock: NSObject, WebSocketFactory {
    /// Always fails: triggers `owsFailDebug` and returns `nil`.
    public func buildSocket(request: WebSocketRequest, callbackScheduler: Scheduler) -> SSKWebSocket? {
        owsFailDebug("Cannot build websocket.")
        return nil
    }
}

#endif

// MARK: -

/// `WebSocketFactory` that vends `SSKWebSocketNative` sockets.
@objc
public class WebSocketFactoryNative: NSObject, WebSocketFactory {
    /// Creates an `SSKWebSocketNative` for `request`, using the signal service taken from
    /// `SSKEnvironment.shared`.
    ///
    /// - Returns: The socket, or `nil` if `SSKWebSocketNative`'s failable initializer fails.
    public func buildSocket(request: WebSocketRequest, callbackScheduler: Scheduler) -> SSKWebSocket? {
        let signalService = SSKEnvironment.shared.signalServiceRef
        return SSKWebSocketNative(
            request: request,
            signalService: signalService,
            callbackScheduler: callbackScheduler,
        )
    }
}

// MARK: -

/// `SSKWebSocket` implementation backed by a `URLSessionWebSocketTask`.
///
/// The task and the boolean flags below are only read or written while holding `lock`.
/// Delegate callbacks are dispatched asynchronously on `callbackScheduler`.
public class SSKWebSocketNative: SSKWebSocket {

    /// Shared counter that hands out a value for each instance's `id`.
    private static let idCounter = AtomicUInt(lock: .sharedGlobal)
    public let id = SSKWebSocketNative.idCounter.increment()

    /// URL the web socket task is created with in `connect()`.
    private let requestUrl: URL
    /// Scheduler on which every delegate callback is dispatched.
    private let callbackScheduler: Scheduler
    /// Session that creates the web socket task.
    private let urlSession: OWSURLSessionProtocol

    /// Prepares a socket for `request`. The web socket task itself is only created in `connect()`.
    ///
    /// - Parameters:
    ///   - request: Describes the socket to open.
    ///   - signalService: Builds the endpoint and URL session for the request's service.
    ///   - callbackScheduler: Scheduler on which delegate callbacks are dispatched.
    /// - Returns: `nil` if no `URLRequest` can be built for `request`, or if the built request
    ///   has no URL.
    public init?(
        request: WebSocketRequest,
        signalService: OWSSignalServiceProtocol,
        callbackScheduler: Scheduler,
    ) {
        let signalServiceInfo = request.signalService.signalServiceInfo()

        let endpoint = signalService.buildUrlEndpoint(for: signalServiceInfo)

        guard let urlRequest = request.build(for: endpoint) else {
            return nil
        }

        // This initializer is already failable, so fail gracefully rather than
        // force-unwrapping (and crashing on) a request without a URL.
        guard let requestUrl = urlRequest.url else {
            owsFailDebug("Web socket request is missing a URL.")
            return nil
        }

        let configuration = OWSURLSession.defaultConfigurationWithoutCaching

        // For some reason, `URLSessionWebSocketTask` will only respect the proxy
        // configuration if started with a URL and not a URLRequest. As a temporary
        // workaround, port header information from the request to the session.
        configuration.httpAdditionalHeaders = urlRequest.allHTTPHeaderFields

        self.urlSession = signalService.buildUrlSession(
            for: signalServiceInfo,
            endpoint: endpoint,
            configuration: configuration,
            maxResponseSize: nil,
        )
        self.requestUrl = requestUrl
        self.callbackScheduler = callbackScheduler
    }

    // MARK: - SSKWebSocket

    public weak var delegate: SSKWebSocketDelegate?

    /// Guards `webSocketTask` and the flags below.
    private var lock = UnfairLock()

    /// The live task. `nil` before `connect()`, after `didClose(error:)`, and after `disconnect(code:)`.
    private var webSocketTask: URLSessionWebSocketTask?
    /// Set once the socket has opened; never cleared.
    private var hasEverConnected = false
    /// `true` between the open callback and the close callback.
    private var isConnected = false
    /// Cleared by the first reported error and by `disconnect(code:)`, so the delegate hears
    /// about at most one failure and none after a requested disconnect.
    private var shouldReportError = true
    /// `true` while a ping has been sent and its pong handler hasn't run yet.
    private var hasUnansweredPing = false

    // This method is thread-safe.
    /// `.open` while connected, `.disconnected` if the socket has opened before but isn't
    /// connected now, and `.connecting` otherwise.
    public var state: SSKWebSocketState {
        lock.withLock {
            if isConnected {
                return .open
            }
            if hasEverConnected {
                return .disconnected
            }
            return .connecting
        }
    }

    /// Creates the web socket task for `requestUrl` and resumes it.
    ///
    /// Intended to be called once per socket: a debug assertion fires on a repeat call, and if a
    /// task already exists the call does nothing.
    public func connect() {
        let taskToResume = lock.withLock { () -> URLSessionWebSocketTask? in
            owsAssertDebug(webSocketTask == nil && !hasEverConnected, "Must connect only once.")
            guard webSocketTask == nil else {
                return nil
            }
            webSocketTask = urlSession.webSocketTask(
                requestUrl: requestUrl,
                didOpenBlock: { [weak self] _ in self?.didOpen() },
                didCloseBlock: { [weak self] error in self?.didClose(error: error) },
            )
            return webSocketTask
        }
        // The task is resumed after the lock has been released.
        taskToResume?.resume()
    }

    /// Open callback: marks the socket connected, notifies the delegate, and starts receiving.
    private func didOpen() {
        lock.withLock {
            isConnected = true
            hasEverConnected = true

            callbackScheduler.async {
                self.delegate?.websocketDidConnect(socket: self)
            }
        }
        listenForNextMessage()
    }

    /// Close callback: marks the socket disconnected, drops the task, and reports `error`
    /// (subject to `shouldReportError`).
    private func didClose(error: Error) {
        lock.withLock {
            isConnected = false
            webSocketTask = nil
            reportErrorWithLock(error, context: "close")
        }
    }

    /// Asks the current task (if any) for its next message, hopping to a global queue first.
    /// Each call requests exactly one message; `receivedMessage(_:)` re-arms it after a success.
    private func listenForNextMessage() {
        DispatchQueue.global().async {
            self.lock.withLock { self.webSocketTask }?.receive { [weak self] result in
                self?.receivedMessage(result)
            }
        }
    }

    /// Handles the outcome of one `receive` call.
    ///
    /// Binary messages are forwarded to the delegate; any other message type triggers a debug
    /// failure. After a successful receive the next one is requested. After a failure no further
    /// receive is requested and the error is reported after a delay (see below).
    private func receivedMessage(_ result: Result<URLSessionWebSocketTask.Message, Error>) {
        switch result {
        case .success(let message):
            switch message {
            case .data(let data):
                callbackScheduler.async {
                    self.delegate?.websocket(self, didReceiveData: data)
                }
            case .string:
                owsFailDebug("We only expect binary frames.")
            @unknown default:
                owsFailDebug("We only expect binary frames.")
            }
            listenForNextMessage()

        case .failure(let error):
            // For some sockets, we read messages until the server closes the
            // connection (and we inspect the close code to determine whether or not
            // it's a graceful teardown). As a result, we expect to receive the final
            // message and close frame in quick succession.
            //
            // We receive messages by repeatedly calling `receive` until we get an
            // error. Unfortunately, this process might see that the stream has been
            // closed before we've had a chance to process the real close frame.
            //
            // The Good Case:
            //   - receivedMessage(<final message>)
            //   - didClose(<close reason>)
            //   - receivedMessage(<socket closed error>)
            //
            // The Bad Case:
            //   - receivedMessage(<final message>)
            //   - receivedMessage(<socket closed error>)
            //   - didClose(<close reason>)
            //
            // (Note that the underlying web socket processes the incoming frames in
            // order, so it's not possible to receive didClose before the final
            // message. The didClose frame waits until the callback for the final
            // message has finished executing.)
            //
            // In theory, we should be able to drop this `receive` error on the floor
            // -- we always expect to learn that the socket has been closed via one of
            // the other URLSession callbacks. However, to guard against the
            // possibility that those might not happen, report the error after a short
            // delay. The delay should be long enough that it never jumps in front of
            // the close callback -- it's a last resort.
            DispatchQueue.global().asyncAfter(deadline: .now() + 10) { [weak self] in
                self?.reportReceivedMessageError(error)
            }
            // Don't try to listen again.
        }
    }

    /// Last-resort report of a `receive` failure, run after the delay in `receivedMessage(_:)`.
    ///
    /// By now another path is expected to have reported the closure (or the user disconnected),
    /// so `shouldReportError` should already be `false`; the debug assertion checks that.
    private func reportReceivedMessageError(_ error: Error) {
        lock.withLock {
            owsAssertDebug(!shouldReportError, "We shouldn't learn that the socket has closed from a receive error.")
            reportErrorWithLock(error, context: "read")
        }
    }

    /// Cancels the task (with `code`, if one is given) and suppresses any later error report.
    ///
    /// Does nothing beyond suppressing error reports if there is no task.
    public func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        let taskToCancel = lock.withLock { () -> URLSessionWebSocketTask? in
            // The user requested a cancellation, so don't report an error
            shouldReportError = false
            let result = webSocketTask
            webSocketTask = nil
            return result
        }
        if let code {
            taskToCancel?.cancel(with: code, reason: nil)
        } else {
            taskToCancel?.cancel()
        }
    }

    /// Sends `data` as a binary message.
    ///
    /// If there is no task, a "Missing webSocketTask." error is reported instead. A send
    /// failure is reported through the same path.
    public func write(data: Data) {
        let taskToSendTo = lock.withLock { () -> URLSessionWebSocketTask? in
            owsAssertDebug(hasEverConnected, "Must connect before sending to web socket.")
            guard let webSocketTask else {
                reportErrorWithLock(OWSGenericError("Missing webSocketTask."), context: "write")
                return nil
            }
            return webSocketTask
        }
        taskToSendTo?.send(.data(data)) { [weak self] error in
            self?.reportError(error, context: "write")
        }
    }

    /// Sends a ping, unless the previous ping is still unanswered.
    ///
    /// If there is no task, or the previous ping's pong handler hasn't run yet, an error is
    /// reported and no ping is sent.
    public func writePing() {
        let taskToPing = lock.withLock { () -> URLSessionWebSocketTask? in
            owsAssertDebug(hasEverConnected, "Must connect before sending a ping.")
            guard let webSocketTask else {
                reportErrorWithLock(OWSGenericError("Missing webSocketTask."), context: "ping")
                return nil
            }
            guard !hasUnansweredPing else {
                reportErrorWithLock(OWSGenericError("Ping didn't get a response."), context: "ping")
                return nil
            }
            hasUnansweredPing = true
            return webSocketTask
        }
        taskToPing?.sendPing(pongReceiveHandler: { [weak self] error in
            self?.receivedPong(error)
        })
    }

    /// Pong handler: clears the unanswered-ping flag and reports `error` if there is one.
    private func receivedPong(_ error: Error?) {
        lock.withLock {
            hasUnansweredPing = false
            reportErrorWithLock(error, context: "pong")
        }
    }

    /// Acquires `lock` and forwards to `reportErrorWithLock(_:context:)`.
    private func reportError(_ error: Error?, context: String) {
        lock.withLock {
            reportErrorWithLock(error, context: context)
        }
    }

    /// Notifies the delegate of `error`, at most once per socket. Must be called with `lock` held.
    ///
    /// Does nothing when `error` is `nil` or when `shouldReportError` has already been cleared.
    ///
    /// - Parameters:
    ///   - error: The failure to report, if any.
    ///   - context: Label for the operation that produced the error; not used in the body.
    private func reportErrorWithLock(_ error: Error?, context: String) {
        lock.assertOwner()

        guard let error else {
            return
        }

        guard shouldReportError else {
            return
        }
        shouldReportError = false

        callbackScheduler.async {
            self.delegate?.websocketDidDisconnectOrFail(socket: self, error: error)
        }
    }
}