Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/SignalServiceKit/Network/SgxWebsocketConnection/SgxWebsocketConnection.swift
1 views
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import LibSignalClient
import SwiftProtobuf

/// Base type for a channel to an SGX enclave server that follows the `SgxClient` protocol.
///
/// A concrete connection performs the initial handshake and then encrypts every outgoing
/// request and decrypts every incoming response with the `SgxClient` produced by its
/// `SgxWebsocketConfigurator`.
///
/// This is a class rather than a protocol only so that callers can name a connection by its
/// configurator alone, e.g. `SgxWebsocketConnection<FooServerConfigurator>`, without naming an
/// implementation. A protocol with an associated type cannot be used that way. The base class
/// itself must never be instantiated: every member below traps, and its sole initializer is
/// `fileprivate`, so only code in this file can call it.
public class SgxWebsocketConnection<Configurator: SgxWebsocketConfigurator> {

    // Deliberately the only initializer, and `fileprivate`. Do not add another one:
    // instances of this base class must be impossible to create.
    fileprivate init() {}

    /// The `MrEnclave` identifying the enclave this connection targets. Subclasses must override.
    public var mrEnclave: MrEnclave { Self.subclassMustOverride() }

    /// The client used to encrypt requests and decrypt responses. Subclasses must override.
    public var client: Configurator.Client { Self.subclassMustOverride() }

    /// The credentials the connection was opened with. Subclasses must override.
    public var auth: RemoteAttestation.Auth { Self.subclassMustOverride() }

    /// Sends `request` and resolves with the single response read back for it.
    /// Subclasses must override.
    func sendRequestAndReadResponse(_ request: Configurator.Request) -> Promise<Configurator.Response> {
        Self.subclassMustOverride()
    }

    /// Sends `request` and resolves with every response read back for it.
    /// Subclasses must override.
    func sendRequestAndReadAllResponses(_ request: Configurator.Request) -> Promise<[Configurator.Response]> {
        Self.subclassMustOverride()
    }

    /// Closes the underlying channel, optionally with a specific close code.
    /// Subclasses must override.
    func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        Self.subclassMustOverride()
    }

    /// Traps with the shared "not implemented" message.
    ///
    /// The caller's source location is forwarded so that the crash report points at the
    /// member that was not overridden, not at this helper.
    private static func subclassMustOverride(file: StaticString = #file, line: UInt = #line) -> Never {
        fatalError("Concrete subclass must implement", file: file, line: line)
    }
}

/// Concrete `SgxWebsocketConnection` backed by a real web socket.
///
/// Instances are created only by `connectAndPerformHandshake(...)`, and only after the
/// handshake with the server has completed. From then on, each request is encrypted with
/// the resulting `Configurator.Client` before it is sent, and each response is decrypted
/// with it before being parsed.
public class SgxWebsocketConnectionImpl<Configurator: SgxWebsocketConfigurator>: SgxWebsocketConnection<Configurator> {

    private let webSocket: WebSocketPromise
    private let configurator: Configurator
    private let scheduler: Scheduler
    private let sessionClient: Configurator.Client
    private let sessionAuth: RemoteAttestation.Auth

    private init(
        webSocket: WebSocketPromise,
        configurator: Configurator,
        client: Configurator.Client,
        auth: RemoteAttestation.Auth,
        scheduler: Scheduler,
    ) {
        self.webSocket = webSocket
        self.configurator = configurator
        self.scheduler = scheduler
        self.sessionClient = client
        self.sessionAuth = auth
        super.init()
    }

    // MARK: - Connecting

    /// Opens a web socket to the server described by `configurator` and performs the handshake.
    ///
    /// - Throws: Synchronously, if no web socket could be obtained from `websocketFactory`.
    /// - Returns: A promise for the established connection, which is an instance of this class.
    ///   If any handshake step fails, a warning is logged, the socket is disconnected with
    ///   `.invalidFramePayloadData`, and the promise is rejected with the original error.
    static func connectAndPerformHandshake(
        configurator: Configurator,
        auth: RemoteAttestation.Auth,
        websocketFactory: WebSocketFactory,
        scheduler: Scheduler,
    ) throws -> Promise<SgxWebsocketConnection<Configurator>> {
        let socket = try buildSocket(
            configurator: configurator,
            auth: auth,
            websocketFactory: websocketFactory,
            scheduler: scheduler,
        )

        // The server speaks first: its first message is the attestation message.
        let handshake: Promise<Configurator.Client> = firstly(on: scheduler) {
            socket.waitForResponse()
        }.then(on: scheduler) { attestationMessage -> Promise<Configurator.Client> in
            try SgxWebsocketConnectionImpl<Configurator>.performAttestationHandshake(
                on: socket,
                configurator: configurator,
                attestationMessage: attestationMessage,
                scheduler: scheduler,
            )
        }

        return handshake.map(on: scheduler) { establishedClient -> SgxWebsocketConnection<Configurator> in
            return SgxWebsocketConnectionImpl<Configurator>(
                webSocket: socket,
                configurator: configurator,
                client: establishedClient,
                auth: auth,
                scheduler: scheduler,
            )
        }.recover(on: scheduler) { error -> Promise<SgxWebsocketConnection<Configurator>> in
            Logger.warn("\(type(of: configurator).loggingName): Disconnecting socket after failed handshake: \(error)")
            socket.disconnect(code: .invalidFramePayloadData)
            throw error
        }
    }

    /// Builds a client from the server's attestation message, sends the client's initial
    /// request, and completes the handshake with the server's reply.
    ///
    /// - Throws: Synchronously, if `Configurator.client(...)` rejects the attestation message.
    ///   Failures of the later steps reject the returned promise instead.
    private static func performAttestationHandshake(
        on socket: WebSocketPromise,
        configurator: Configurator,
        attestationMessage: Data,
        scheduler: Scheduler,
    ) throws -> Promise<Configurator.Client> {
        let pendingClient = try Configurator.client(
            mrenclave: configurator.mrenclave,
            attestationMessage: attestationMessage,
            currentDate: Date(),
        )
        return firstly {
            socket.send(data: pendingClient.initialRequest())
            return socket.waitForResponse()
        }.map(on: scheduler) { handshakeResponse -> Configurator.Client in
            try pendingClient.completeHandshake(handshakeResponse)
            return pendingClient
        }
    }

    /// Creates the web socket for the configurator's server, authenticating it with `auth`
    /// through the HTTP authorization header.
    ///
    /// - Throws: `OWSAssertionError` if the factory does not produce a socket.
    private static func buildSocket(
        configurator: Configurator,
        auth: RemoteAttestation.Auth,
        websocketFactory: WebSocketFactory,
        scheduler: Scheduler,
    ) throws -> WebSocketPromise {
        let enclaveHex = configurator.mrenclave.dataValue.hexadecimalString
        let authorization = HttpHeaders.authHeaderValue(username: auth.username, password: auth.password)
        let socketRequest = WebSocketRequest(
            signalService: Configurator.signalServiceType,
            urlPath: Configurator.websocketUrlPath(mrenclaveString: enclaveHex),
            urlQueryItems: nil,
            extraHeaders: [HttpHeaders.authHeaderKey: authorization],
        )
        if let socket = websocketFactory.webSocketPromise(request: socketRequest, callbackScheduler: scheduler) {
            return socket
        }
        throw OWSAssertionError("We should always be able to get a web socket from this API.")
    }

    // MARK: - SgxWebsocketConnection

    override public var mrEnclave: MrEnclave { configurator.mrenclave }

    override public var client: Configurator.Client { sessionClient }

    override public var auth: RemoteAttestation.Auth { sessionAuth }

    /// Encrypts and sends `request`, then decrypts and parses the next frame received.
    override public func sendRequestAndReadResponse(
        _ request: Configurator.Request,
    ) -> Promise<Configurator.Response> {
        return firstly(on: scheduler) { () -> Promise<Data> in
            try self.encryptAndSendRequest(request.serializedData())
            return self.webSocket.waitForResponse()
        }.map(on: scheduler) { ciphertext -> Configurator.Response in
            try self.decodeResponse(ciphertext)
        }
    }

    /// Encrypts and sends `request`, then decrypts and parses every frame that
    /// `waitForAllResponses()` yields, in the order received.
    override public func sendRequestAndReadAllResponses(
        _ request: Configurator.Request,
    ) -> Promise<[Configurator.Response]> {
        return firstly(on: scheduler) { () -> Promise<[Data]> in
            try self.encryptAndSendRequest(request.serializedData())
            return self.webSocket.waitForAllResponses()
        }.map(on: scheduler) { ciphertexts -> [Configurator.Response] in
            // Each frame is decrypted and then parsed before the next one is touched.
            try ciphertexts.map { try self.decodeResponse($0) }
        }
    }

    override public func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        webSocket.disconnect(code: code)
    }

    // MARK: - Encryption helpers

    /// Encrypts `plaintext` with the established client and writes it to the socket.
    private func encryptAndSendRequest(_ plaintext: Data) throws {
        let ciphertext = try client.establishedSend(plaintext)
        webSocket.send(data: ciphertext)
    }

    /// Decrypts one received frame with the established client and parses it as a response.
    private func decodeResponse(_ ciphertext: Data) throws -> Configurator.Response {
        let plaintext: Data = try client.establishedRecv(ciphertext)
        return try Configurator.Response(serializedBytes: plaintext)
    }
}

#if TESTABLE_BUILD

/// Test double for `SgxWebsocketConnection` whose state and behavior are supplied by the test.
///
/// Assign the `mock*` properties before reading the matching overrides, and install the
/// `on*Response(s)` closures before calling the request methods. Reading or calling an
/// unconfigured member traps on the nil unwrap. `disconnect(code:)` is a no-op unless
/// `onDisconnect` is set.
public class MockSgxWebsocketConnection<Configurator: SgxWebsocketConfigurator>: SgxWebsocketConnection<Configurator> {

    override init() {
        super.init()
    }

    // MARK: - Canned state

    public var mockEnclave: MrEnclave!
    public var mockClient: Configurator.Client!
    public var mockAuth: RemoteAttestation.Auth!

    override public var mrEnclave: MrEnclave { mockEnclave }
    override public var client: Configurator.Client { mockClient }
    override public var auth: RemoteAttestation.Auth { mockAuth }

    // MARK: - Scripted behavior

    public var onSendRequestAndReadResponse: ((Configurator.Request) -> Promise<Configurator.Response>)?
    public var onSendRequestAndReadAllResponses: ((Configurator.Request) -> Promise<[Configurator.Response]>)?
    public var onDisconnect: (() -> Void)?

    override public func sendRequestAndReadResponse(
        _ request: Configurator.Request,
    ) -> Promise<Configurator.Response> {
        // A missing handler is a test-setup error; trapping here is intentional.
        let handler = onSendRequestAndReadResponse!
        return handler(request)
    }

    override public func sendRequestAndReadAllResponses(
        _ request: Configurator.Request,
    ) -> Promise<[Configurator.Response]> {
        // A missing handler is a test-setup error; trapping here is intentional.
        let handler = onSendRequestAndReadAllResponses!
        return handler(request)
    }

    override public func disconnect(code: URLSessionWebSocketTask.CloseCode?) {
        if let handler = onDisconnect {
            handler()
        }
    }
}

#endif