Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
SeleniumHQ
GitHub Repository: SeleniumHQ/Selenium
Path: blob/trunk/dotnet/src/webdriver/BiDi/Broker.cs
4000 views
// <copyright file="Broker.cs" company="Selenium Committers">
// Licensed to the Software Freedom Conservancy (SFC) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The SFC licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// </copyright>

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using OpenQA.Selenium.Internal.Logging;

namespace OpenQA.Selenium.BiDi;

/// <summary>
/// Mediates between the local BiDi session and the remote end: sends serialized commands through the
/// transport, matches incoming responses to pending commands by id, and dispatches incoming events
/// to the handlers registered for the event name.
/// </summary>
internal sealed class Broker : IAsyncDisposable
{
    private readonly ILogger _logger = Internal.Logging.Log.GetLogger<Broker>();

    private readonly BiDi _bidi;
    private readonly ITransport _transport;

    // Commands that have been registered and are still awaiting a response, keyed by command id.
    private readonly ConcurrentDictionary<long, CommandInfo> _pendingCommands = new();

    // Events deserialized by the receive loop, consumed by the event emitter task.
    private readonly Channel<(string Method, EventArgs Params)> _pendingEvents = Channel.CreateUnbounded<(string Method, EventArgs Params)>(new()
    {
        SingleReader = true,
        SingleWriter = true
    });

    // Written by SubscribeAsync and read by the receive loop, which can run at the same time,
    // hence a concurrent map rather than a plain Dictionary.
    private readonly ConcurrentDictionary<string, JsonTypeInfo> _eventTypesMap = new();

    // Each handler list is guarded by a lock on the list instance itself.
    private readonly ConcurrentDictionary<string, List<EventHandler>> _eventHandlers = new();

    private long _currentCommandId;

    // 0 = alive, 1 = DisposeAsync has already started; makes disposal idempotent.
    private int _disposed;

    private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default);

    private Task? _receivingMessageTask;
    private Task? _eventEmitterTask;
    private CancellationTokenSource? _receiveMessagesCancellationTokenSource;

    /// <summary>
    /// Creates a broker for the given BiDi session that communicates over a WebSocket transport.
    /// </summary>
    /// <param name="bidi">The BiDi session attached to dispatched event arguments and used for subscribe/unsubscribe commands.</param>
    /// <param name="url">The address passed to the WebSocket transport.</param>
    internal Broker(BiDi bidi, Uri url)
    {
        _bidi = bidi;
        _transport = new WebSocketTransport(url);
    }

    /// <summary>
    /// Connects the transport and starts the background message-receiving and event-emitting tasks.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the transport connection attempt.</param>
    public async Task ConnectAsync(CancellationToken cancellationToken)
    {
        await _transport.ConnectAsync(cancellationToken).ConfigureAwait(false);

        _receiveMessagesCancellationTokenSource = new CancellationTokenSource();
        var receiveToken = _receiveMessagesCancellationTokenSource.Token;

        _receivingMessageTask = _myTaskFactory.StartNew(async () => await ReceiveMessagesAsync(receiveToken).ConfigureAwait(false), TaskCreationOptions.LongRunning).Unwrap();
        _eventEmitterTask = _myTaskFactory.StartNew(ProcessEventsAwaiterAsync).Unwrap();
    }

    /// <summary>
    /// Receive loop: reads messages from the transport until cancellation is requested and processes each one.
    /// A failure while processing a single message is logged and does not stop the loop; a failure while
    /// receiving is logged and rethrown, which ends the loop.
    /// </summary>
    private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
    {
        try
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                var data = await _transport.ReceiveAsync(cancellationToken).ConfigureAwait(false);

                try
                {
                    ProcessReceivedMessage(data);
                }
                catch (Exception ex)
                {
                    if (_logger.IsEnabled(LogEventLevel.Error))
                    {
                        _logger.Error($"Unhandled error occurred while processing remote message: {ex}");
                    }
                }
            }
        }
        catch (Exception ex) when (ex is not OperationCanceledException)
        {
            if (_logger.IsEnabled(LogEventLevel.Error))
            {
                _logger.Error($"Unhandled error occurred while receiving remote messages: {ex}");
            }

            throw;
        }
    }

    /// <summary>
    /// Event emitter loop: drains queued events until the channel is completed and invokes,
    /// one after another, every handler registered for the event's method name.
    /// </summary>
    private async Task ProcessEventsAwaiterAsync()
    {
        var reader = _pendingEvents.Reader;
        while (await reader.WaitToReadAsync().ConfigureAwait(false))
        {
            while (reader.TryRead(out var result))
            {
                if (!_eventHandlers.TryGetValue(result.Method, out var eventHandlers))
                {
                    continue;
                }

                // Snapshot under the lock so that concurrent Subscribe/Unsubscribe calls
                // cannot mutate the list while it is being copied or iterated.
                EventHandler[] handlersSnapshot;
                lock (eventHandlers)
                {
                    handlersSnapshot = eventHandlers.ToArray();
                }

                foreach (var handler in handlersSnapshot)
                {
                    // Guard each handler separately: one failing handler must not
                    // prevent the remaining handlers from seeing the event.
                    try
                    {
                        var args = result.Params;

                        args.BiDi = _bidi;

                        await handler.InvokeAsync(args).ConfigureAwait(false);
                    }
                    catch (Exception ex)
                    {
                        if (_logger.IsEnabled(LogEventLevel.Error))
                        {
                            _logger.Error($"Unhandled error processing BiDi event handler: {ex}");
                        }
                    }
                }
            }
        }
    }

    /// <summary>
    /// Assigns the next command id, sends the serialized command and waits for the matching response.
    /// </summary>
    /// <param name="command">The command to send; its <c>Id</c> is overwritten with a freshly generated id.</param>
    /// <param name="options">Optional per-command options; when no timeout is supplied, 30 seconds is used.</param>
    /// <param name="jsonCommandTypeInfo">Serialization metadata for the command.</param>
    /// <param name="jsonResultTypeInfo">Deserialization metadata for the expected result.</param>
    /// <param name="cancellationToken">Token that cancels sending and waiting for the response.</param>
    /// <returns>The deserialized command result.</returns>
    /// <remarks>
    /// The returned task is cancelled when the timeout elapses or <paramref name="cancellationToken"/> is cancelled,
    /// and faults with <c>BiDiException</c> when the remote end answers with an error message.
    /// </remarks>
    public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand command, CommandOptions? options, JsonTypeInfo<TCommand> jsonCommandTypeInfo, JsonTypeInfo<TResult> jsonResultTypeInfo, CancellationToken cancellationToken)
        where TCommand : Command
        where TResult : EmptyResult
    {
        command.Id = Interlocked.Increment(ref _currentCommandId);

        var tcs = new TaskCompletionSource<EmptyResult>(TaskCreationOptions.RunContinuationsAsynchronously);

        using var cts = cancellationToken.CanBeCanceled
            ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)
            : new CancellationTokenSource();

        var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30);
        cts.CancelAfter(timeout);

        // Disposed (before cts) when the method exits, so the callback cannot outlive this call.
        using var cancellationRegistration = cts.Token.Register(() => tcs.TrySetCanceled(cts.Token));

        _pendingCommands[command.Id] = new CommandInfo(command.Id, tcs, jsonResultTypeInfo);

        try
        {
            var data = JsonSerializer.SerializeToUtf8Bytes(command, jsonCommandTypeInfo);

            await _transport.SendAsync(data, cts.Token).ConfigureAwait(false);

            return (TResult)await tcs.Task.ConfigureAwait(false);
        }
        finally
        {
            // Without this, a command that timed out, was cancelled, or failed to serialize/send
            // would stay in the pending map forever. Removing an already removed id is a no-op.
            _pendingCommands.TryRemove(command.Id, out _);
        }
    }

    /// <summary>
    /// Registers the event type mapping, subscribes to the event on the remote end and, once that
    /// succeeds, adds the handler to the local handler list for the event.
    /// </summary>
    /// <param name="eventName">The BiDi event name to subscribe to.</param>
    /// <param name="eventHandler">The handler to invoke for each received event.</param>
    /// <param name="options">Optional contexts / user contexts forwarded to the remote subscribe command.</param>
    /// <param name="jsonTypeInfo">Deserialization metadata for the event's <c>params</c> payload.</param>
    /// <param name="cancellationToken">Token forwarded to the remote subscribe command.</param>
    /// <returns>A subscription that carries the remote subscription id and the handler.</returns>
    public async Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo<TEventArgs> jsonTypeInfo, CancellationToken cancellationToken)
        where TEventArgs : EventArgs
    {
        _eventTypesMap[eventName] = jsonTypeInfo;

        var handlers = _eventHandlers.GetOrAdd(eventName, _ => []);

        var subscribeResult = await _bidi.SessionModule.SubscribeAsync([eventName], new() { Contexts = options?.Contexts, UserContexts = options?.UserContexts }, cancellationToken).ConfigureAwait(false);

        lock (handlers)
        {
            handlers.Add(eventHandler);
        }

        return new Subscription(subscribeResult.Subscription, this, eventHandler);
    }

    /// <summary>
    /// Removes the subscription's handler from the local handler list and unsubscribes on the remote end.
    /// </summary>
    /// <param name="subscription">The subscription previously returned by <c>SubscribeAsync</c>.</param>
    /// <param name="cancellationToken">Token forwarded to the remote unsubscribe command.</param>
    public async Task UnsubscribeAsync(Subscription subscription, CancellationToken cancellationToken)
    {
        // TryGetValue instead of the indexer: an unknown event name should not prevent
        // the remote unsubscribe from being attempted.
        if (_eventHandlers.TryGetValue(subscription.EventHandler.EventName, out var eventHandlers))
        {
            lock (eventHandlers)
            {
                eventHandlers.Remove(subscription.EventHandler);
            }
        }

        await _bidi.SessionModule.UnsubscribeAsync([subscription.SubscriptionId], null, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Stops the background tasks and releases the transport. Calling it more than once is a no-op.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (Interlocked.Exchange(ref _disposed, 1) != 0)
        {
            return;
        }

        // Lets the emitter task finish once the already queued events are drained.
        _pendingEvents.Writer.TryComplete();

        _receiveMessagesCancellationTokenSource?.Cancel();

        if (_eventEmitterTask is not null)
        {
            await _eventEmitterTask.ConfigureAwait(false);
        }

        _transport.Dispose();

        if (_receivingMessageTask is not null)
        {
            try
            {
                // Observe the receive task so its outcome does not surface as an unobserved task exception.
                await _receivingMessageTask.ConfigureAwait(false);
            }
            catch (Exception)
            {
                // Cancellation is the expected outcome during disposal. Any other failure has already
                // been reported by the receive loop (when error logging is enabled), and Dispose must not throw.
            }
        }

        _receiveMessagesCancellationTokenSource?.Dispose();

        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Parses one raw message and routes it by its <c>type</c>: completes the matching pending command for
    /// <c>success</c> and <c>error</c>, or queues a deserialized event for <c>event</c>.
    /// Messages with any other type are ignored.
    /// </summary>
    /// <param name="data">The UTF-8 encoded JSON message received from the transport.</param>
    /// <exception cref="JsonException">A required property is missing from the message.</exception>
    /// <exception cref="BiDiException">
    /// No pending command or event type mapping matches the message, or the remote end reported an error
    /// that is not associated with any command.
    /// </exception>
    private void ProcessReceivedMessage(byte[]? data)
    {
        long? id = default;
        string? type = default;
        string? method = default;
        string? error = default;
        string? message = default;
        Utf8JsonReader resultReader = default;
        Utf8JsonReader paramsReader = default;

        Utf8JsonReader reader = new(new ReadOnlySpan<byte>(data));
        reader.Read(); // "{"

        reader.Read(); // first property name (or "}" for an empty object)

        while (reader.TokenType == JsonTokenType.PropertyName)
        {
            string? propertyName = reader.GetString();
            reader.Read();

            switch (propertyName)
            {
                case "id":
                    // The BiDi protocol allows "id": null on error responses that are not tied to a command;
                    // GetInt64 would throw InvalidOperationException on a null token.
                    if (reader.TokenType != JsonTokenType.Null)
                    {
                        id = reader.GetInt64();
                    }

                    break;

                case "type":
                    type = reader.GetString();
                    break;

                case "method":
                    method = reader.GetString();
                    break;

                case "result":
                    resultReader = reader; // snapshot positioned at the start of the value
                    break;

                case "params":
                    paramsReader = reader; // snapshot positioned at the start of the value
                    break;

                case "error":
                    error = reader.GetString();
                    break;

                case "message":
                    message = reader.GetString();
                    break;
            }

            // Skip the children of an object/array value, then advance to the next property name.
            reader.Skip();
            reader.Read();
        }

        switch (type)
        {
            case "success":
                if (id is null) throw new JsonException("The remote end responded with 'success' message type, but missed required 'id' property.");

                if (_pendingCommands.TryRemove(id.Value, out var command))
                {
                    // Try* variants: the command may already have been cancelled by its timeout,
                    // in which case the late response is simply dropped.
                    try
                    {
                        var commandResult = JsonSerializer.Deserialize(ref resultReader, command.JsonResultTypeInfo)
                            ?? throw new JsonException("Remote end returned null command result in the 'result' property.");

                        command.TaskCompletionSource.TrySetResult((EmptyResult)commandResult);
                    }
                    catch (Exception ex)
                    {
                        command.TaskCompletionSource.TrySetException(ex);
                    }
                }
                else
                {
                    throw new BiDiException($"The remote end responded with 'success' message type, but no pending command with id {id} was found.");
                }

                break;

            case "event":
                if (method is null) throw new JsonException("The remote end responded with 'event' message type, but missed required 'method' property.");

                if (_eventTypesMap.TryGetValue(method, out var eventInfo))
                {
                    var eventArgs = (EventArgs)JsonSerializer.Deserialize(ref paramsReader, eventInfo)!;

                    eventArgs.BiDi = _bidi;

                    var messageEvent = (method, eventArgs);
                    _pendingEvents.Writer.TryWrite(messageEvent);
                }
                else
                {
                    throw new BiDiException($"The remote end responded with 'event' message type, but no event type mapping for method '{method}' was found.");
                }

                break;

            case "error":
                // A missing or null id means the error cannot be attributed to a command; surface the remote error text.
                if (id is null) throw new BiDiException($"The remote end responded with 'error' message type that is not associated with any command: {error}: {message}");

                if (_pendingCommands.TryRemove(id.Value, out var errorCommand))
                {
                    errorCommand.TaskCompletionSource.TrySetException(new BiDiException($"{error}: {message}"));
                }
                else
                {
                    throw new BiDiException($"The remote end responded with 'error' message type, but no pending command with id {id} was found.");
                }

                break;
        }
    }

    /// <summary>
    /// Bookkeeping for a command awaiting its response: the completion source to resolve and the
    /// metadata used to deserialize the result.
    /// </summary>
    private sealed class CommandInfo(long id, TaskCompletionSource<EmptyResult> taskCompletionSource, JsonTypeInfo jsonResultTypeInfo)
    {
        public long Id { get; } = id;

        public TaskCompletionSource<EmptyResult> TaskCompletionSource { get; } = taskCompletionSource;

        public JsonTypeInfo JsonResultTypeInfo { get; } = jsonResultTypeInfo;
    }
}