Advertisement
Hulkstance

Untitled

Oct 13th, 2022
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 14.21 KB | None | 0 0
  1. using System.Buffers;
  2. using System.Net.WebSockets;
  3. using System.Runtime.CompilerServices;
  4. using System.Threading.Channels;
  5. using Microsoft.Extensions.Logging;
  6. using Microsoft.Extensions.Logging.Abstractions;
  7.  
  8. namespace Sts.Exchange.WebSockets;
  9.  
  10. public class WebSocketClient : IDisposable
  11. {
  12.     private const int MaxMessageSize = 1_500_000;
  13.  
  14.     private readonly int _receiveChunkSize;
  15.     private readonly string _url;
  16.     private readonly ILogger<WebSocketClient> _logger;
  17.     private readonly Channel<Message> _receiveChannel;
  18.     private readonly Channel<Message> _sendChannel;
  19.     private readonly uint _numberOfConsumers;
  20.     private readonly SemaphoreSlim _semaphore = new(1, 1);
  21.  
  22.     private ClientWebSocket? _clientWebSocket;
  23.     private CancellationTokenSource? _cts;
  24.     private Task _sendTask = Task.CompletedTask;
  25.     private Task _dataTask = Task.CompletedTask;
  26.     private bool _isStopped;
  27.  
  28.     public WebSocketClient(string url, ILoggerFactory? loggerFactory = default, int receiveChunkSize = 4096, bool singleReceiver = false, bool singleSender = false, uint numberOfConsumers = 4)
  29.     {
  30.         if (string.IsNullOrWhiteSpace(url))
  31.         {
  32.             throw new ArgumentNullException(nameof(url));
  33.         }
  34.  
  35.         _url = url;
  36.         _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<WebSocketClient>();
  37.         _receiveChunkSize = receiveChunkSize;
  38.  
  39.         _receiveChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(10)
  40.         {
  41.             SingleWriter = true,
  42.             SingleReader = singleReceiver,
  43.             FullMode = BoundedChannelFullMode.DropOldest
  44.         });
  45.  
  46.         _sendChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(10)
  47.         {
  48.             SingleReader = true,
  49.             SingleWriter = singleSender,
  50.             FullMode = BoundedChannelFullMode.Wait
  51.         });
  52.  
  53.         _numberOfConsumers = numberOfConsumers;
  54.     }
  55.  
  56.     internal Task Running { get; private set; } = Task.CompletedTask;
  57.     public bool IsDisposed { get; protected set; }
  58.     public bool IsOpen => _clientWebSocket is { State: WebSocketState.Open };
  59.  
  60.     public Func<ClientWebSocket> ClientFactory { get; set; } = () => new ClientWebSocket();
  61.  
  62.     public event EventHandler? Connected;
  63.     public event EventHandler? Disconnected;
  64.     public event EventHandler<ErrorEventArgs>? Error;
  65.     public event EventHandler<MessageReceivedEventArgs>? MessageReceived;
  66.  
  67.     public void Dispose()
  68.     {
  69.         Dispose(true);
  70.         GC.SuppressFinalize(this);
  71.     }
  72.  
  73.     public async Task StartAsync()
  74.     {
  75.         DoDisposeChecks();
  76.  
  77.         // Prevent a race condition
  78.         await _semaphore.WaitAsync().ConfigureAwait(false);
  79.  
  80.         try
  81.         {
  82.             if (_cts == null)
  83.             {
  84.                 _cts = new CancellationTokenSource();
  85.  
  86.                 _clientWebSocket = null;
  87.  
  88.                 Running = Task.Run(async () =>
  89.                 {
  90.                     _logger.LogTrace("Connection task started: {Url}", _url);
  91.  
  92.                     try
  93.                     {
  94.                         await HandleConnection().ConfigureAwait(false);
  95.                     }
  96.                     catch (Exception e)
  97.                     {
  98.                         _logger.LogError(e, "Error in connection task: {Url}: ", _url);
  99.                     }
  100.  
  101.                     _logger.LogTrace("Connection task ended: {Url}", _url);
  102.                 }, _cts.Token);
  103.  
  104.                 var count = 0;
  105.                 do
  106.                 {
  107.                     // wait for _client to be not null
  108.                     if (_clientWebSocket != null || _cts.Token.WaitHandle.WaitOne(50))
  109.                     {
  110.                         break;
  111.                     }
  112.                 } while (++count < 100);
  113.             }
  114.         }
  115.         finally
  116.         {
  117.             _semaphore.Release();
  118.         }
  119.     }
  120.  
  121.     public async Task StopAsync()
  122.     {
  123.         DoDisposeChecks();
  124.  
  125.         await _semaphore.WaitAsync().ConfigureAwait(false);
  126.  
  127.         try
  128.         {
  129.             if (Running.IsCompleted)
  130.             {
  131.                 // We never started
  132.                 return;
  133.             }
  134.  
  135.             _logger.LogDebug("Stopping");
  136.  
  137.             _isStopped = true;
  138.  
  139.             try
  140.             {
  141.                 // Close the socket first, because ReceiveAsync leaves an invalid socket (state = aborted) when the token is cancelled
  142.                 if (_clientWebSocket is { State: not (WebSocketState.Aborted or WebSocketState.Closed or WebSocketState.CloseSent) })
  143.                 {
  144.                     _logger.LogInformation("CloseOutputAsync called");
  145.                     // After this call, the socket state which change to CloseSent
  146.                     await _clientWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false);
  147.                 }
  148.             }
  149.             catch
  150.             {
  151.                 // Any exception thrown here will be caused by the socket already being closed,
  152.                 // which is the state we want to put it in by calling this method, which
  153.                 // means we don't care if it was already closed and threw an exception
  154.                 // when we tried to close it again.
  155.             }
  156.  
  157.             _logger.LogInformation("cts called");
  158.  
  159.             _cts?.Cancel();
  160.  
  161.             await Running.ConfigureAwait(false);
  162.  
  163.             _cts?.Dispose();
  164.             _cts = null;
  165.  
  166.             OnDisconnected();
  167.  
  168.             _logger.LogDebug("Stopped");
  169.         }
  170.         finally
  171.         {
  172.             _semaphore.Release();
  173.         }
  174.     }
  175.  
  176.     [MethodImpl(MethodImplOptions.AggressiveInlining)]
  177.     protected void DoDisposeChecks()
  178.     {
  179.         if (IsDisposed)
  180.         {
  181.             throw new ObjectDisposedException(nameof(WebSocketClient));
  182.         }
  183.     }
  184.  
  185.     protected virtual void Dispose(bool disposing)
  186.     {
  187.         if (IsDisposed)
  188.         {
  189.             return;
  190.         }
  191.  
  192.         if (disposing)
  193.         {
  194.             _semaphore.Dispose();
  195.             _receiveChannel.Writer.TryComplete();
  196.             _sendChannel.Writer.TryComplete();
  197.         }
  198.  
  199.         IsDisposed = true;
  200.     }
  201.  
  202.     private async Task HandleConnection()
  203.     {
  204.         while (_cts is { IsCancellationRequested: false })
  205.         {
  206.             DoDisposeChecks();
  207.  
  208.             _logger.LogTrace("Connecting...");
  209.  
  210.             using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(_cts.Token);
  211.  
  212.             try
  213.             {
  214.                 await _semaphore.WaitAsync(connectionCts.Token).ConfigureAwait(false);
  215.  
  216.                 try
  217.                 {
  218.                     var clientWebSocket = ClientFactory();
  219.                     await clientWebSocket.ConnectAsync(new Uri(_url), connectionCts.Token).ConfigureAwait(false);
  220.  
  221.                     _clientWebSocket = clientWebSocket;
  222.                 }
  223.                 finally
  224.                 {
  225.                     _semaphore.Release();
  226.                 }
  227.  
  228.                 OnConnected();
  229.                 _isStopped = false;
  230.  
  231.                 using (_clientWebSocket)
  232.                 {
  233.                     var token = connectionCts.Token;
  234.                     _sendTask = ProcessSendAsync(_clientWebSocket, token);
  235.  
  236.                     if (_numberOfConsumers == 0)
  237.                     {
  238.                         _dataTask = ProcessDataAsync(token);
  239.                     }
  240.                     else
  241.                     {
  242.                         var processingTasks = Enumerable.Range(1, (int)_numberOfConsumers).Select(_ => ProcessDataAsync(token));
  243.                         _dataTask = Task.WhenAll(processingTasks);
  244.                     }
  245.  
  246.                     // Listen for messages
  247.                     byte[]? buffer = null;
  248.                     var count = 0;
  249.  
  250.                     while (true)
  251.                     {
  252.                         // Rent a buffer only we run out of space, not for every read
  253.                         buffer ??= ArrayPool<byte>.Shared.Rent(_receiveChunkSize);
  254.  
  255.                         // We need Memory<byte> overload instead of the ArraySegment one (this one allocates a Task per read).
  256.                         var result = await _clientWebSocket.ReceiveAsync(buffer.AsMemory()[count..], connectionCts.Token).ConfigureAwait(false);
  257.  
  258.                         if (result.MessageType == WebSocketMessageType.Close)
  259.                         {
  260.                             _logger.LogDebug("The remote server has closed the connection");
  261.  
  262.                             ArrayPool<byte>.Shared.Return(buffer);
  263.  
  264.                             // Some exchanges e.g. Binance are closing the remote connection when an error occurs, so we simply reconnect when that happens
  265.                             if (!_isStopped)
  266.                             {
  267.                                 break;
  268.                             }
  269.  
  270.                             // Jumps to finally block
  271.                             return;
  272.                         }
  273.  
  274.                         count += result.Count;
  275.  
  276.                         if (count > MaxMessageSize)
  277.                         {
  278.                             throw new InvalidOperationException("Maximum size of the message was exceeded.");
  279.                         }
  280.  
  281.                         if (result.EndOfMessage)
  282.                         {
  283.                             // Avoid working with strings to reduce allocation
  284.                             await _receiveChannel.Writer.WriteAsync(new Message(buffer.AsMemory(0, count)), connectionCts.Token).ConfigureAwait(false);
  285.  
  286.                             count = 0;
  287.                             buffer = null;
  288.                         }
  289.                         else if (count >= buffer.Length)
  290.                         {
  291.                             // Create the new array
  292.                             var newArray = ArrayPool<byte>.Shared.Rent(buffer.Length * 2);
  293.  
  294.                             // Copy the old array to the new array
  295.                             buffer.AsSpan().CopyTo(newArray);
  296.  
  297.                             // Return the old array
  298.                             ArrayPool<byte>.Shared.Return(buffer);
  299.  
  300.                             buffer = newArray;
  301.                         }
  302.                     }
  303.                 }
  304.             }
  305.             catch (OperationCanceledException)
  306.             {
  307.                 // operation was canceled, ignore
  308.             }
  309.             catch (WebSocketException ex)
  310.             {
  311.                 if (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
  312.                 {
  313.                     _logger.LogDebug("Prematurely closed");
  314.                 }
  315.  
  316.                 OnError(new ErrorEventArgs(ex.Message, ex));
  317.                 connectionCts.Token.WaitHandle.WaitOne(2000);
  318.             }
  319.             catch (Exception ex)
  320.             {
  321.                 OnError(new ErrorEventArgs(ex.Message, ex));
  322.             }
  323.             finally
  324.             {
  325.                 connectionCts.Cancel();
  326.  
  327.                 await Task.WhenAll(_sendTask, _dataTask).ConfigureAwait(false);
  328.             }
  329.         }
  330.     }
  331.  
  332.     protected virtual void OnConnected()
  333.     {
  334.         _logger.LogTrace("OnConnected: Connection opened (URL: {Url})", _url);
  335.         Connected?.Invoke(this, EventArgs.Empty);
  336.     }
  337.  
  338.     protected virtual void OnDisconnected()
  339.     {
  340.         _logger.LogTrace("OnDisconnected: Connection closed");
  341.         Disconnected?.Invoke(this, EventArgs.Empty);
  342.     }
  343.  
  344.     protected virtual void OnError(ErrorEventArgs e)
  345.     {
  346.         _logger.LogError(e.Exception, "OnError: {Message}", e.Message);
  347.         Error?.Invoke(this, e);
  348.     }
  349.  
  350.     protected virtual void OnMessageReceived(MessageReceivedEventArgs e)
  351.     {
  352.         MessageReceived?.Invoke(this, e);
  353.     }
  354.  
  355.     #region Send
  356.  
  357.     // Producer
  358.     public ValueTask SendAsync(Message message)
  359.     {
  360.         DoDisposeChecks();
  361.  
  362.         return _sendChannel.Writer.WriteAsync(message);
  363.     }
  364.  
  365.     // Producer
  366.     public bool Send(Message message)
  367.     {
  368.         DoDisposeChecks();
  369.  
  370.         return _sendChannel.Writer.TryWrite(message);
  371.     }
  372.  
  373.     // Consumer
  374.     private async Task ProcessSendAsync(WebSocket webSocket, CancellationToken cancellationToken)
  375.     {
  376.         try
  377.         {
  378.             while (await _sendChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
  379.             {
  380.                 while (_sendChannel.Reader.TryRead(out var message))
  381.                 {
  382.                     try
  383.                     {
  384.                         await webSocket.SendAsync(message.Buffer, WebSocketMessageType.Text, true, CancellationToken.None).ConfigureAwait(false);
  385.                     }
  386.                     catch (Exception ex)
  387.                     {
  388.                         _logger.LogError(ex, "SendAsync: {ExceptionMessage}", ex.Message);
  389.                     }
  390.                 }
  391.             }
  392.         }
  393.         catch (OperationCanceledException)
  394.         {
  395.             // operation was canceled, ignore
  396.         }
  397.         finally
  398.         {
  399.             _logger.LogTrace("Send loop finished");
  400.         }
  401.     }
  402.  
  403.     #endregion
  404.  
  405.     #region Data
  406.  
  407.     private async Task ProcessDataAsync(CancellationToken cancellationToken)
  408.     {
  409.         try
  410.         {
  411.             while (await _receiveChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
  412.             {
  413.                 while (_receiveChannel.Reader.TryRead(out var message))
  414.                 {
  415.                     try
  416.                     {
  417.                         await ProcessMessageAsync(message).ConfigureAwait(false);
  418.                     }
  419.                     catch (Exception ex)
  420.                     {
  421.                         _logger.LogError("Data loop: {Message}", ex.Message);
  422.                     }
  423.  
  424.                     message.Dispose();
  425.                 }
  426.             }
  427.         }
  428.         catch (OperationCanceledException)
  429.         {
  430.             // operation was canceled, ignore
  431.         }
  432.         finally
  433.         {
  434.             _logger.LogTrace("Data loop finished");
  435.         }
  436.     }
  437.  
  438.     private Task ProcessMessageAsync(Message message)
  439.     {
  440.         OnMessageReceived(new MessageReceivedEventArgs(message));
  441.         return Task.CompletedTask;
  442.     }
  443.  
  444.     #endregion
  445. }
  446.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement