Path: blob/main/SignalServiceKit/Registration/RegistrationSessionManagerImpl.swift
1 views
//
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import CoreTelephony
import Foundation
public class RegistrationSessionManagerImpl: RegistrationSessionManager {
private let dateProvider: DateProvider
private let db: any DB
private let kvStore: KeyValueStore
private let signalService: OWSSignalServiceProtocol
public init(
dateProvider: @escaping DateProvider = Date.provider,
db: any DB,
signalService: OWSSignalServiceProtocol,
) {
self.dateProvider = dateProvider
self.db = db
self.kvStore = KeyValueStore(collection: KvStore.collectionName)
self.signalService = signalService
}
// TODO: make this and other methods resilient to transient network failures by adding
// basic retrying logic.
public func restoreSession(logger: PrefixedLogger) async -> RegistrationSession? {
// Just get the most recent one, don't validate against any e164.
return await restoreSession(forE164: nil, logger: logger)
}
public func beginOrRestoreSession(e164: E164, apnsToken: String?, logger: PrefixedLogger) async -> Registration.BeginSessionResponse {
// Verify the session is still valid.
let restoredSession = await restoreSession(forE164: e164, logger: logger)
guard let restoredSession, restoredSession.e164 == e164 else {
// We only keep one session at a time, wipe any if we change the e164.
await db.awaitableWrite { self.clearPersistedSession($0) }
let (mcc, mnc) = Self.getMccMnc()
let response = await makeBeginSessionRequest(
e164: e164,
apnsToken: apnsToken,
mcc: mcc,
mnc: mnc,
logger: logger,
)
return await persistSessionFromResponse(response)
}
return .success(restoredSession)
}
public func fulfillChallenge(
for session: RegistrationSession,
fulfillment: Registration.ChallengeFulfillment,
logger: PrefixedLogger,
) async -> Registration.UpdateSessionResponse {
let response = await makeFulfillChallengeRequest(session, fulfillment, logger: logger)
return await persistSessionFromResponse(response)
}
public func requestVerificationCode(for session: RegistrationSession, transport: Registration.CodeTransport, logger: PrefixedLogger) async -> Registration.UpdateSessionResponse {
let response = await makeRequestVerificationCodeRequest(session, transport, logger: logger)
return await persistSessionFromResponse(response)
}
public func submitVerificationCode(for session: RegistrationSession, code: String, logger: PrefixedLogger) async -> Registration.UpdateSessionResponse {
let response = await makeSubmitVerificationCodeRequest(session, code: code, logger: logger)
return await persistSessionFromResponse(response)
}
public func clearPersistedSession(_ transaction: DBWriteTransaction) {
kvStore.removeValue(forKey: KvStore.sessionKey, transaction: transaction)
}
// MARK: - Session persistence
enum KvStore {
static let collectionName = "RegistrationSession"
static let sessionKey = "session"
}
private func persist(session: RegistrationSession, _ transaction: DBWriteTransaction) {
do {
try kvStore.setCodable(session, key: KvStore.sessionKey, transaction: transaction)
} catch {
owsFailDebug("Unable to encode session; will not be recoverable after app relaunch.")
}
}
private func getPersistedSession(_ transaction: DBReadTransaction) -> RegistrationSession? {
do {
return try kvStore.getCodableValue(forKey: KvStore.sessionKey, transaction: transaction)
} catch {
owsFailDebug("Unable to decode session; will not be recoverable after app relaunch.")
return nil
}
}
private func persistSessionFromResponse(_ response: Registration.BeginSessionResponse) async -> Registration.BeginSessionResponse {
switch response {
case .success(let session):
await db.awaitableWrite { self.persist(session: session, $0) }
case .invalidArgument, .retryAfter, .networkFailure, .genericError:
break
}
return response
}
private func persistSessionFromResponse(_ response: Registration.UpdateSessionResponse) async -> Registration.UpdateSessionResponse {
switch response {
case
let .success(session),
let .disallowed(session),
let .rejectedArgument(session),
let .retryAfterTimeout(session, retryAfterHeader: _),
let .transportError(session):
await db.awaitableWrite { self.persist(session: session, $0) }
case .invalidSession:
// Clear the session we've stored as it's invalid.
await db.awaitableWrite { self.clearPersistedSession($0) }
case .serverFailure, .networkFailure, .genericError:
break
}
return response
}
private func persistSessionFromResponse(_ response: FetchSessionResponse) async -> FetchSessionResponse {
switch response {
case .success(let session):
await db.awaitableWrite { self.persist(session: session, $0) }
case .sessionInvalid:
await db.awaitableWrite { self.clearPersistedSession($0) }
case .genericError:
break
}
return response
}
// MARK: - MCC/MNC
private static func getMccMnc() -> (mcc: String?, mnc: String?) {
guard
let providers = CTTelephonyNetworkInfo().serviceSubscriberCellularProviders,
let provider = providers.values.first
else {
Logger.info("Unable to get telephony info for mcc/mnc.")
return (nil, nil)
}
if providers.values.count > 1 {
Logger.info("Multiple telephony providers found; using the first for mcc/mnc.")
}
return (provider.mobileCountryCode, provider.mobileNetworkCode)
}
// MARK: - Requests
// TODO: make this and other methods resilient to transient network failures by adding
// basic retrying logic.
private func restoreSession(forE164 e164: E164?, logger: PrefixedLogger) async -> RegistrationSession? {
guard let existingSession = db.read(block: { self.getPersistedSession($0) }) else {
return nil
}
if let e164, existingSession.e164 != e164 {
// We only keep one session at a time, wipe any if we change the e164.
await db.awaitableWrite { self.clearPersistedSession($0) }
return nil
}
// Verify the session is still valid.
let fetchSessionResponse = await makeFetchSessionRequest(existingSession, logger: logger)
_ = await self.persistSessionFromResponse(fetchSessionResponse)
switch fetchSessionResponse {
case .success(let session):
return session
case .sessionInvalid, .genericError:
return nil
}
}
// MARK: Begin Session
private func makeBeginSessionRequest(
e164: E164,
apnsToken: String?,
mcc: String?,
mnc: String?,
logger: PrefixedLogger,
) async -> Registration.BeginSessionResponse {
let request = RegistrationRequestFactory.beginSessionRequest(
e164: e164,
pushToken: apnsToken,
mcc: mcc,
mnc: mnc,
logger: logger,
)
return await makeRequest(
request,
e164: e164,
handler: self.handleBeginSessionResponse(forE164:statusCode:retryAfterHeader:bodyData:),
fallbackError: .genericError,
networkFailureError: .networkFailure,
)
}
private func handleBeginSessionResponse(
forE164 e164: E164,
statusCode: Int,
retryAfterHeader: TimeInterval?,
bodyData: Data?,
) -> Registration.BeginSessionResponse {
let statusCode = RegistrationServiceResponses.BeginSessionResponseCodes(rawValue: statusCode)
switch statusCode {
case .success:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .success($0) } ?? .genericError
case .invalidArgument, .missingArgument:
return .invalidArgument
case .retry:
return .retryAfter(retryAfterHeader)
case .unexpectedError, .none:
return .genericError
}
}
// MARK: Fulfill Challenge
private func makeFulfillChallengeRequest(
_ session: RegistrationSession,
_ fulfillment: Registration.ChallengeFulfillment,
logger: PrefixedLogger,
) async -> Registration.UpdateSessionResponse {
let captchaToken: String?
let pushChallengeToken: String?
switch fulfillment {
case .captcha(let token):
captchaToken = token
pushChallengeToken = nil
case .pushChallenge(let token):
captchaToken = nil
pushChallengeToken = token
}
let request = RegistrationRequestFactory.fulfillChallengeRequest(
sessionId: session.id,
captchaToken: captchaToken,
pushChallengeToken: pushChallengeToken,
logger: logger,
)
return await makeUpdateRequest(
request,
session: session,
handler: self.handleFulfillChallengeResponse(sessionAtSendTime:statusCode:retryAfterHeader:bodyData:),
)
}
private func handleFulfillChallengeResponse(
sessionAtSendTime: RegistrationSession,
statusCode: Int,
retryAfterHeader: TimeInterval?,
bodyData: Data?,
) -> Registration.UpdateSessionResponse {
let e164 = sessionAtSendTime.e164
let statusCode = RegistrationServiceResponses.FulfillChallengeResponseCodes(rawValue: statusCode)
switch statusCode {
case .success:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .success($0) } ?? .genericError
case .notAccepted:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .rejectedArgument($0) } ?? .genericError
case .missingSession:
return .invalidSession
case .malformedRequest:
Logger.error("Malformed fulfill challenge request")
return .genericError
case .unexpectedError, .none:
return .genericError
}
}
// MARK: Request Verification Code
private func makeRequestVerificationCodeRequest(
_ session: RegistrationSession,
_ transport: Registration.CodeTransport,
logger: PrefixedLogger,
) async -> Registration.UpdateSessionResponse {
let wireTransport: RegistrationRequestFactory.VerificationCodeTransport
switch transport {
case .sms:
wireTransport = .sms
case .voice:
wireTransport = .voice
}
// In an abstract sense we should mock these out for testing, but
// for any concrete test we'd want to write the language code doesn't matter
// at all. Its serialization is already tested at a lower level.
let locale = Locale.current
let languageCode: String?
let countryCode: String?
if #available(iOS 16, *) {
languageCode = locale.language.languageCode?.identifier
countryCode = locale.region?.identifier
} else {
languageCode = locale.languageCode
countryCode = locale.regionCode
}
let request = RegistrationRequestFactory.requestVerificationCodeRequest(
sessionId: session.id,
languageCode: languageCode,
countryCode: countryCode,
transport: wireTransport,
logger: logger,
)
return await makeUpdateRequest(
request,
session: session,
handler: self.handleRequestVerificationCodeResponse(sessionAtSendTime:statusCode:retryAfterHeader:bodyData:),
)
}
private func handleRequestVerificationCodeResponse(
sessionAtSendTime: RegistrationSession,
statusCode: Int,
retryAfterHeader: TimeInterval?,
bodyData: Data?,
) -> Registration.UpdateSessionResponse {
let e164 = sessionAtSendTime.e164
let statusCode = RegistrationServiceResponses.RequestVerificationCodeResponseCodes(rawValue: statusCode)
switch statusCode {
case .success:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .success($0) } ?? .genericError
case .disallowed:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .disallowed($0) } ?? .genericError
case .retry:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .retryAfterTimeout($0, retryAfterHeader: retryAfterHeader) } ?? .genericError
case .providerFailure:
return serverFailureResponse(fromResponseBody: bodyData, sessionAtSendTime: sessionAtSendTime).map { .serverFailure($0) } ?? .genericError
case .missingSession:
return .invalidSession
case .transportError:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .transportError($0) } ?? .genericError
case .malformedRequest, .unexpectedError, .none:
return .genericError
}
}
// MARK: Submit Verification Code
private func makeSubmitVerificationCodeRequest(
_ session: RegistrationSession,
code: String,
logger: PrefixedLogger,
) async -> Registration.UpdateSessionResponse {
let request = RegistrationRequestFactory.submitVerificationCodeRequest(
sessionId: session.id,
code: code,
logger: logger,
)
return await makeUpdateRequest(
request,
session: session,
handler: self.handleSubmitVerificationCodeResponse(sessionAtSendTime:statusCode:retryAfterHeader:bodyData:),
)
}
private func handleSubmitVerificationCodeResponse(
sessionAtSendTime: RegistrationSession,
statusCode: Int,
retryAfterHeader: TimeInterval?,
bodyData: Data?,
) -> Registration.UpdateSessionResponse {
let e164 = sessionAtSendTime.e164
let statusCode = RegistrationServiceResponses.SubmitVerificationCodeResponseCodes(rawValue: statusCode)
switch statusCode {
case .success:
guard
let session = registrationSession(
fromResponseBody: bodyData,
e164: e164,
)
else {
return .genericError
}
if session.verified {
return .success(session)
} else {
return .rejectedArgument(session)
}
case .malformedRequest:
Logger.error("Verification code was invalidly formatted (not just incorrect).")
return .genericError
case .retry:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .retryAfterTimeout($0, retryAfterHeader: retryAfterHeader) } ?? .genericError
case .missingSession:
return .invalidSession
case .newCodeRequired:
guard
let session = registrationSession(
fromResponseBody: bodyData,
e164: e164,
)
else {
return .genericError
}
if session.verified {
// Unclear how this could happen but hey,
// the session is verified. Pretend that worked
// and keep going
return .success(session)
} else if session.nextVerificationAttempt != nil {
// We can submit a code, but not yet.
return .retryAfterTimeout(session, retryAfterHeader: retryAfterHeader)
} else {
// There is no code to submit.
return .disallowed(session)
}
case .unexpectedError, .none:
return .genericError
}
}
// MARK: Fetch Session
private enum FetchSessionResponse: Equatable {
case success(RegistrationSession)
/// This session is known to be invalid or timed out.
/// It should be thrown away and another session started.
case sessionInvalid
/// Some other error occurred; an error might be shown to the user
/// but the session shouldn't be discarded.
case genericError
}
private func makeFetchSessionRequest(
_ session: RegistrationSession,
logger: PrefixedLogger,
) async -> FetchSessionResponse {
let request = RegistrationRequestFactory.fetchSessionRequest(sessionId: session.id, logger: logger)
do {
let response = try await signalService.urlSessionForMainSignalService().performRequest(request)
return handleFetchSessionResponse(
sessionAtSendTime: session,
statusCode: response.responseStatusCode,
bodyData: response.responseBodyData,
)
} catch {
guard let error = error as? OWSHTTPError else {
return .genericError
}
let response = handleFetchSessionResponse(
sessionAtSendTime: session,
statusCode: error.responseStatusCode,
bodyData: error.httpResponseData,
)
return response
}
}
private func handleFetchSessionResponse(
sessionAtSendTime: RegistrationSession,
statusCode: Int,
bodyData: Data?,
) -> FetchSessionResponse {
let e164 = sessionAtSendTime.e164
let statusCode = RegistrationServiceResponses.FetchSessionResponseCodes(rawValue: statusCode)
switch statusCode {
case .success:
return registrationSession(
fromResponseBody: bodyData,
e164: e164,
).map { .success($0) } ?? .genericError
case .missingSession:
return .sessionInvalid
case .unexpectedError, .none:
return .genericError
}
}
// MARK: - Generic Request Helpers
private func registrationSession(
fromResponseBody bodyData: Data?,
e164: E164,
) -> RegistrationSession? {
guard let bodyData else {
Logger.warn("Got empty registration session response")
return nil
}
guard let session = try? JSONDecoder().decode(RegistrationServiceResponses.RegistrationSession.self, from: bodyData) else {
Logger.warn("Unable to parse registration session from response")
return nil
}
return session.toLocalSession(forE164: e164, receivedAt: dateProvider())
}
private func serverFailureResponse(
fromResponseBody bodyData: Data?,
sessionAtSendTime: RegistrationSession,
) -> Registration.ServerFailureResponse? {
guard let bodyData else {
Logger.warn("Got empty provider failure response")
return nil
}
guard let failure = try? JSONDecoder().decode(RegistrationServiceResponses.SendVerificationCodeFailedResponse.self, from: bodyData) else {
Logger.warn("Unable to parse registration session from response")
return nil
}
let reasonString: String = {
switch failure.reason {
case .none, .unknown:
return "unknown"
case .providerRejected:
return "provider rejected"
case .providerUnavailable:
return "provider unavailable"
case .illegalArgument:
return "illegal argument (rejected number)"
}
}()
Logger.error("Sending verification code failure from service provider. Permanent:\(failure.permanentFailure) Reason:\(reasonString)")
let localReason: Registration.ServerFailureResponse.Reason?
switch failure.reason {
case .unknown, .none:
localReason = nil
case .providerRejected:
localReason = .providerRejected
case .providerUnavailable:
localReason = .providerUnavailable
case .illegalArgument:
localReason = .illegalArgument
}
return Registration.ServerFailureResponse(
session: sessionAtSendTime,
isPermanent: failure.permanentFailure,
reason: localReason,
)
}
private func makeRequest<ResponseType>(
_ request: TSRequest,
e164: E164,
handler: @escaping (_ e164: E164, _ statusCode: Int, _ retryAfterHeader: TimeInterval?, _ bodyData: Data?) -> ResponseType,
fallbackError: ResponseType,
networkFailureError: ResponseType,
) async -> ResponseType {
do {
let response = try await signalService.urlSessionForMainSignalService().performRequest(request)
return handler(
e164,
response.responseStatusCode,
response.headers.retryAfterTimeInterval,
response.responseBodyData,
)
} catch {
if error.isNetworkFailureOrTimeout {
return networkFailureError
}
guard let error = error as? OWSHTTPError else {
return fallbackError
}
return handler(
e164,
error.responseStatusCode,
error.responseHeaders?.retryAfterTimeInterval,
error.httpResponseData,
)
}
}
private func makeUpdateRequest(
_ request: TSRequest,
session: RegistrationSession,
handler: @escaping (_ priorSession: RegistrationSession, _ statusCode: Int, _ retryAfterHeader: TimeInterval?, _ bodyData: Data?) -> Registration.UpdateSessionResponse,
) async -> Registration.UpdateSessionResponse {
return await makeRequest(
request,
e164: session.e164,
handler: { _, statusCode, retryAfterHeader, bodyData in
return handler(session, statusCode, retryAfterHeader, bodyData)
},
fallbackError: .genericError,
networkFailureError: .networkFailure,
)
}
}
private extension RegistrationServiceResponses.RegistrationSession {
func toLocalSession(
forE164 e164: E164,
receivedAt: Date,
) -> RegistrationSession {
let mappedChallenges = requestedInformation.compactMap(\.asLocalChallenge)
let hasUnknownChallengeRequiringAppUpdate = mappedChallenges.count != requestedInformation.count
return RegistrationSession(
id: id,
e164: e164,
receivedDate: receivedAt,
nextSMS: nextSms.map { TimeInterval($0) },
nextCall: nextCall.map { TimeInterval($0) },
nextVerificationAttempt: nextVerificationAttempt.map { TimeInterval($0) },
allowedToRequestCode: allowedToRequestCode,
requestedInformation: mappedChallenges,
hasUnknownChallengeRequiringAppUpdate: hasUnknownChallengeRequiringAppUpdate,
verified: verified,
)
}
}
private extension RegistrationServiceResponses.RegistrationSession.Challenge {
var asLocalChallenge: RegistrationSession.Challenge? {
switch self {
case .captcha: return .captcha
case .pushChallenge: return .pushChallenge
case .unknown: return nil
}
}
}