Advertisement
Hulkstance

Untitled

Oct 12th, 2022
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 14.48 KB | None | 0 0
  1. using System.Buffers;
  2. using System.Net.WebSockets;
  3. using System.Runtime.CompilerServices;
  4. using System.Runtime.InteropServices;
  5. using System.Threading.Channels;
  6. using Microsoft.Extensions.Logging;
  7. using Microsoft.Extensions.Logging.Abstractions;
  8.  
  9. namespace WebSocketClient;
  10.  
  11. public class WebSocketClient : IDisposable
  12. {
  13.     private const int MaxMessageSize = 1_500_000;
  14.  
  15.     private readonly int _receiveChunkSize;
  16.     private readonly string _url;
  17.     private readonly ILogger<WebSocketClient> _logger;
  18.     private readonly Channel<Message> _receiveChannel;
  19.     private readonly Channel<Message> _sendChannel;
  20.     private readonly uint _numberOfConsumers;
  21.     private readonly SemaphoreSlim _semaphore = new(1, 1);
  22.  
  23.     private ClientWebSocket? _clientWebSocket;
  24.     private CancellationTokenSource? _cts;
  25.     private Task _sendTask = Task.CompletedTask;
  26.     private Task _dataTask = Task.CompletedTask;
  27.     private bool _isStopped;
  28.  
  29.     public WebSocketClient(string url, ILoggerFactory? loggerFactory = default, int receiveChunkSize = 4096, bool singleReceiver = false, bool singleSender = false, uint numberOfConsumers = 4)
  30.     {
  31.         if (string.IsNullOrWhiteSpace(url))
  32.         {
  33.             throw new ArgumentNullException(nameof(url));
  34.         }
  35.  
  36.         _url = url;
  37.         _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<WebSocketClient>();
  38.         _receiveChunkSize = receiveChunkSize;
  39.  
  40.         _receiveChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(10)
  41.         {
  42.             SingleWriter = true,
  43.             SingleReader = singleReceiver,
  44.             FullMode = BoundedChannelFullMode.DropOldest
  45.         });
  46.  
  47.         _sendChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(10)
  48.         {
  49.             SingleReader = true,
  50.             SingleWriter = singleSender,
  51.             FullMode = BoundedChannelFullMode.Wait
  52.         });
  53.  
  54.         _numberOfConsumers = numberOfConsumers;
  55.     }
  56.  
  57.     internal Task Running { get; private set; } = Task.CompletedTask;
  58.     public bool IsDisposed { get; protected set; }
  59.     public bool IsOpen => _clientWebSocket is { State: WebSocketState.Open };
  60.  
  61.     public Func<ClientWebSocket> ClientFactory { get; } = () => new ClientWebSocket();
  62.  
  63.     public event EventHandler? Connected;
  64.     public event EventHandler? Disconnected;
  65.     public event EventHandler<ErrorEventArgs>? Error;
  66.     public event EventHandler<MessageReceivedEventArgs>? MessageReceived;
  67.  
  68.     public void Dispose()
  69.     {
  70.         Dispose(true);
  71.         GC.SuppressFinalize(this);
  72.     }
  73.  
  74.     public async Task StartAsync()
  75.     {
  76.         DoDisposeChecks();
  77.  
  78.         // Prevent a race condition
  79.         await _semaphore.WaitAsync().ConfigureAwait(false);
  80.  
  81.         try
  82.         {
  83.             if (_cts == null)
  84.             {
  85.                 _cts = new CancellationTokenSource();
  86.  
  87.                 _clientWebSocket = null;
  88.  
  89.                 Running = Task.Run(async () =>
  90.                 {
  91.                     _logger.LogTrace("Connection task started: {Url}", _url);
  92.  
  93.                     try
  94.                     {
  95.                         await HandleConnection().ConfigureAwait(false);
  96.                     }
  97.                     catch (Exception e)
  98.                     {
  99.                         _logger.LogError(e, "Error in connection task: {Url}: ", _url);
  100.                     }
  101.  
  102.                     _logger.LogTrace("Connection task ended: {Url}", _url);
  103.                 }, _cts.Token);
  104.  
  105.                 var count = 0;
  106.                 do
  107.                 {
  108.                     // wait for _client to be not null
  109.                     if (_clientWebSocket != null || _cts.Token.WaitHandle.WaitOne(50))
  110.                     {
  111.                         break;
  112.                     }
  113.                 } while (++count < 100);
  114.             }
  115.         }
  116.         finally
  117.         {
  118.             _semaphore.Release();
  119.         }
  120.     }
  121.  
  122.     public async Task StopAsync()
  123.     {
  124.         DoDisposeChecks();
  125.  
  126.         await _semaphore.WaitAsync().ConfigureAwait(false);
  127.  
  128.         try
  129.         {
  130.             if (Running.IsCompleted)
  131.             {
  132.                 // We never started
  133.                 return;
  134.             }
  135.  
  136.             _logger.LogDebug("Stopping");
  137.  
  138.             _isStopped = true;
  139.  
  140.             try
  141.             {
  142.                 // Close the socket first, because ReceiveAsync leaves an invalid socket (state = aborted) when the token is cancelled
  143.                 if (_clientWebSocket is { State: not (WebSocketState.Aborted or WebSocketState.Closed or WebSocketState.CloseSent) })
  144.                 {
  145.                     _logger.LogInformation("CloseOutputAsync called");
  146.                     // After this call, the socket state which change to CloseSent
  147.                     await _clientWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).ConfigureAwait(false);
  148.                 }
  149.             }
  150.             catch
  151.             {
  152.                 // Any exception thrown here will be caused by the socket already being closed,
  153.                 // which is the state we want to put it in by calling this method, which
  154.                 // means we don't care if it was already closed and threw an exception
  155.                 // when we tried to close it again.
  156.             }
  157.  
  158.             _logger.LogInformation("cts called");
  159.  
  160.             _cts?.Cancel();
  161.  
  162.             await Running.ConfigureAwait(false);
  163.  
  164.             _cts?.Dispose();
  165.             _cts = null;
  166.  
  167.             OnDisconnected();
  168.  
  169.             _logger.LogDebug("Stopped");
  170.         }
  171.         finally
  172.         {
  173.             _semaphore.Release();
  174.         }
  175.     }
  176.  
  177.     [MethodImpl(MethodImplOptions.AggressiveInlining)]
  178.     protected void DoDisposeChecks()
  179.     {
  180.         if (IsDisposed)
  181.         {
  182.             throw new ObjectDisposedException(nameof(WebSocketClient));
  183.         }
  184.     }
  185.  
  186.     protected virtual void Dispose(bool disposing)
  187.     {
  188.         if (IsDisposed)
  189.         {
  190.             return;
  191.         }
  192.  
  193.         if (disposing)
  194.         {
  195.             _semaphore.Dispose();
  196.             _receiveChannel.Writer.TryComplete();
  197.             _sendChannel.Writer.TryComplete();
  198.         }
  199.  
  200.         IsDisposed = true;
  201.     }
  202.  
  203.     private async Task HandleConnection()
  204.     {
  205.         while (_cts is { IsCancellationRequested: false })
  206.         {
  207.             DoDisposeChecks();
  208.  
  209.             _logger.LogTrace("Connecting...");
  210.  
  211.             using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(_cts.Token);
  212.  
  213.             try
  214.             {
  215.                 await _semaphore.WaitAsync(connectionCts.Token).ConfigureAwait(false);
  216.  
  217.                 try
  218.                 {
  219.                     var clientWebSocket = ClientFactory();
  220.                     await clientWebSocket.ConnectAsync(new Uri(_url), connectionCts.Token).ConfigureAwait(false);
  221.  
  222.                     _clientWebSocket = clientWebSocket;
  223.                 }
  224.                 finally
  225.                 {
  226.                     _semaphore.Release();
  227.                 }
  228.  
  229.                 OnConnected();
  230.                 _isStopped = false;
  231.  
  232.                 using (_clientWebSocket)
  233.                 {
  234.                     var token = connectionCts.Token;
  235.                     _sendTask = ProcessSendAsync(_clientWebSocket, token);
  236.  
  237.                     if (_numberOfConsumers == 0)
  238.                     {
  239.                         _dataTask = ProcessDataAsync(token);
  240.                     }
  241.                     else
  242.                     {
  243.                         var processingTasks = Enumerable.Range(1, (int)_numberOfConsumers).Select(_ => ProcessDataAsync(token));
  244.                         _dataTask = Task.WhenAll(processingTasks);
  245.                     }
  246.  
  247.                     // Listen for messages
  248.                     byte[]? buffer = null;
  249.                     var count = 0;
  250.  
  251.                     while (true)
  252.                     {
  253.                         // Rent a buffer only we run out of space, not for every read
  254.                         buffer ??= ArrayPool<byte>.Shared.Rent(_receiveChunkSize);
  255.  
  256.                         // We need Memory<byte> overload instead of the ArraySegment one (this one allocates a Task per read).
  257.                         var result = await _clientWebSocket.ReceiveAsync(buffer.AsMemory()[count..], connectionCts.Token).ConfigureAwait(false);
  258.  
  259.                         if (result.MessageType == WebSocketMessageType.Close)
  260.                         {
  261.                             _logger.LogDebug("The remote server has closed the connection");
  262.  
  263.                             // Prevent leak
  264.                             if (MemoryMarshal.TryGetArray((ReadOnlyMemory<byte>)buffer.AsMemory(0, count), out var arraySegment))
  265.                             {
  266.                                 ArrayPool<byte>.Shared.Return(arraySegment.Array!);
  267.                             }
  268.  
  269.                             // Some exchanges e.g. Binance are closing the remote connection when an error occurs, so we simply reconnect when that happens
  270.                             if (!_isStopped)
  271.                             {
  272.                                 break;
  273.                             }
  274.  
  275.                             // Jumps to finally block
  276.                             return;
  277.                         }
  278.  
  279.                         count += result.Count;
  280.  
  281.                         if (count > MaxMessageSize)
  282.                         {
  283.                             throw new InvalidOperationException("Maximum size of the message was exceeded.");
  284.                         }
  285.  
  286.                         if (result.EndOfMessage)
  287.                         {
  288.                             // Avoid working with strings to reduce allocation
  289.                             await _receiveChannel.Writer.WriteAsync(new Message(buffer.AsMemory(0, count)), connectionCts.Token).ConfigureAwait(false);
  290.  
  291.                             count = 0;
  292.                             buffer = null;
  293.                         }
  294.                         else if (count >= buffer.Length)
  295.                         {
  296.                             // Create the new array
  297.                             var newArray = ArrayPool<byte>.Shared.Rent(buffer.Length * 2);
  298.  
  299.                             // Copy the old array to the new array
  300.                             buffer.AsSpan().CopyTo(newArray);
  301.  
  302.                             // Return the old array
  303.                             ArrayPool<byte>.Shared.Return(buffer);
  304.  
  305.                             buffer = newArray;
  306.                         }
  307.                     }
  308.                 }
  309.             }
  310.             catch (OperationCanceledException)
  311.             {
  312.                 // operation was canceled, ignore
  313.             }
  314.             catch (WebSocketException ex)
  315.             {
  316.                 if (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
  317.                 {
  318.                     _logger.LogDebug("Prematurely closed");
  319.                 }
  320.  
  321.                 OnError(new ErrorEventArgs(ex.Message, ex));
  322.                 connectionCts.Token.WaitHandle.WaitOne(2000);
  323.             }
  324.             catch (Exception ex)
  325.             {
  326.                 OnError(new ErrorEventArgs(ex.Message, ex));
  327.             }
  328.             finally
  329.             {
  330.                 connectionCts.Cancel();
  331.  
  332.                 await Task.WhenAll(_sendTask, _dataTask).ConfigureAwait(false);
  333.             }
  334.         }
  335.     }
  336.  
  337.     protected virtual void OnConnected()
  338.     {
  339.         _logger.LogTrace("OnConnected: Connection opened (URL: {Url})", _url);
  340.         Connected?.Invoke(this, EventArgs.Empty);
  341.     }
  342.  
  343.     protected virtual void OnDisconnected()
  344.     {
  345.         _logger.LogTrace("OnDisconnected: Connection closed");
  346.         Disconnected?.Invoke(this, EventArgs.Empty);
  347.     }
  348.  
  349.     protected virtual void OnError(ErrorEventArgs e)
  350.     {
  351.         _logger.LogError(e.Exception, "OnError: {Message}", e.Message);
  352.         Error?.Invoke(this, e);
  353.     }
  354.  
  355.     protected virtual void OnMessageReceived(MessageReceivedEventArgs e)
  356.     {
  357.         MessageReceived?.Invoke(this, e);
  358.     }
  359.  
  360.     #region Send
  361.  
  362.     // Producer
  363.     public ValueTask SendAsync(Message message)
  364.     {
  365.         DoDisposeChecks();
  366.  
  367.         return _sendChannel.Writer.WriteAsync(message);
  368.     }
  369.  
  370.     // Producer
  371.     public bool Send(Message message)
  372.     {
  373.         DoDisposeChecks();
  374.  
  375.         return _sendChannel.Writer.TryWrite(message);
  376.     }
  377.  
  378.     // Consumer
  379.     private async Task ProcessSendAsync(WebSocket webSocket, CancellationToken cancellationToken)
  380.     {
  381.         try
  382.         {
  383.             while (await _sendChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
  384.             {
  385.                 while (_sendChannel.Reader.TryRead(out var message))
  386.                 {
  387.                     try
  388.                     {
  389.                         await webSocket.SendAsync(message.Buffer, WebSocketMessageType.Text, true, CancellationToken.None).ConfigureAwait(false);
  390.                     }
  391.                     catch (Exception ex)
  392.                     {
  393.                         _logger.LogError(ex, "SendAsync: {ExceptionMessage}", ex.Message);
  394.                     }
  395.                 }
  396.             }
  397.         }
  398.         catch (OperationCanceledException)
  399.         {
  400.             // operation was canceled, ignore
  401.         }
  402.         finally
  403.         {
  404.             _logger.LogTrace("Send loop finished");
  405.         }
  406.     }
  407.  
  408.     #endregion
  409.  
  410.     #region Data
  411.  
  412.     private async Task ProcessDataAsync(CancellationToken cancellationToken)
  413.     {
  414.         try
  415.         {
  416.             while (await _receiveChannel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
  417.             {
  418.                 while (_receiveChannel.Reader.TryRead(out var message))
  419.                 {
  420.                     try
  421.                     {
  422.                         await ProcessMessageAsync(message).ConfigureAwait(false);
  423.                     }
  424.                     catch (Exception ex)
  425.                     {
  426.                         _logger.LogError("Data loop: {Message}", ex.Message);
  427.                     }
  428.  
  429.                     message.Dispose();
  430.                 }
  431.             }
  432.         }
  433.         catch (OperationCanceledException)
  434.         {
  435.             // operation was canceled, ignore
  436.         }
  437.         finally
  438.         {
  439.             _logger.LogTrace("Data loop finished");
  440.         }
  441.     }
  442.  
  443.     private Task ProcessMessageAsync(Message message)
  444.     {
  445.         OnMessageReceived(new MessageReceivedEventArgs(message));
  446.         return Task.CompletedTask;
  447.     }
  448.  
  449.     #endregion
  450. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement