Files
allstarr/allstarr/Middleware/WebSocketProxyMiddleware.cs
2026-02-02 19:53:39 -05:00

230 lines
9.0 KiB
C#

using System.Net.WebSockets;
using Microsoft.Extensions.Options;
using allstarr.Models.Settings;
namespace allstarr.Middleware;
/// <summary>
/// Middleware that proxies WebSocket connections to Jellyfin server.
/// This enables real-time features like session tracking, remote control, and live updates.
/// </summary>
public class WebSocketProxyMiddleware
{
    // Receive buffer size. Frames are forwarded as they arrive, so memory per direction is bounded by this.
    private const int ReceiveBufferSize = 4 * 1024;

    // HTTP 502 Bad Gateway: returned when the upgrade cannot be proxied (Jellyfin unreachable or misconfigured).
    private const int BadGatewayStatusCode = 502;

    // Upper bound on how long shutdown waits for the close handshake before pending I/O is cancelled.
    private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

    private readonly RequestDelegate _next;
    private readonly JellyfinSettings _settings;
    private readonly ILogger<WebSocketProxyMiddleware> _logger;

    public WebSocketProxyMiddleware(
        RequestDelegate next,
        IOptions<JellyfinSettings> settings,
        ILogger<WebSocketProxyMiddleware> logger)
    {
        _next = next;
        _settings = settings.Value;
        _logger = logger;
        _logger.LogInformation("🔧 WebSocketProxyMiddleware initialized - Jellyfin URL: {Url}", _settings.Url);
    }

    /// <summary>
    /// Proxies WebSocket requests whose path starts with the <c>/socket</c> segment to Jellyfin;
    /// every other request is passed to the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        var isWebSocket = context.WebSockets.IsWebSocketRequest;

        if (_logger.IsEnabled(LogLevel.Debug))
        {
            LogPotentialWebSocketRequest(context, isWebSocket);
        }

        if (isWebSocket && context.Request.Path.StartsWithSegments("/socket", StringComparison.OrdinalIgnoreCase))
        {
            _logger.LogInformation("🔌 WebSocket connection request received from {RemoteIp}",
                context.Connection.RemoteIpAddress);
            await HandleWebSocketProxyAsync(context);
            return;
        }

        // Not a WebSocket request to /socket, pass to next middleware
        await _next(context);
    }

    /// <summary>
    /// Debug diagnostics for requests that look WebSocket-related. The path heuristic is deliberately broad
    /// ("ws" also matches routes such as "/Shows" or "/Views"), which is why it only runs at Debug level.
    /// </summary>
    private void LogPotentialWebSocketRequest(HttpContext context, bool isWebSocket)
    {
        var path = context.Request.Path.Value ?? "";
        var looksWebSocketRelated =
            isWebSocket ||
            path.Contains("socket", StringComparison.OrdinalIgnoreCase) ||
            path.Contains("ws", StringComparison.OrdinalIgnoreCase) ||
            context.Request.Headers.ContainsKey("Upgrade");

        if (!looksWebSocketRelated)
        {
            return;
        }

        _logger.LogDebug("🔍 Potential WebSocket request: Path={Path}, IsWs={IsWs}, Method={Method}, Upgrade={Upgrade}, Connection={Connection}",
            path,
            isWebSocket,
            context.Request.Method,
            context.Request.Headers["Upgrade"].ToString(),
            context.Request.Headers["Connection"].ToString());
    }

    /// <summary>
    /// Connects to Jellyfin's <c>/socket</c> endpoint, accepts the client WebSocket, and pumps frames in both
    /// directions until either side closes or fails. Both sockets are closed and disposed before returning.
    /// </summary>
    private async Task HandleWebSocketProxyAsync(HttpContext context)
    {
        ClientWebSocket? serverWebSocket = null;
        WebSocket? clientWebSocket = null;
        Task[] pumps = Array.Empty<Task>();

        // Linked to RequestAborted. During shutdown it is also cancelled after CloseTimeout, which aborts any
        // receive/send still pending so cleanup can never hang.
        using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);

        try
        {
            var jellyfinWsUri = BuildJellyfinWebSocketUri(context.Request.QueryString.Value);
            if (jellyfinWsUri is null)
            {
                _logger.LogError("Cannot proxy WebSocket: Jellyfin URL {Url} is not an absolute http(s) URL", _settings.Url);
                context.Response.StatusCode = BadGatewayStatusCode;
                return;
            }

            // Log without the query string: it may carry credentials such as api_key.
            _logger.LogInformation("Connecting to Jellyfin WebSocket: {Url}", jellyfinWsUri.GetLeftPart(UriPartial.Path));

            serverWebSocket = new ClientWebSocket();
            ForwardAuthenticationHeaders(context, serverWebSocket);
            serverWebSocket.Options.SetRequestHeader("User-Agent", "Allstarr/1.0");

            // Connect upstream BEFORE accepting the client. An unreachable Jellyfin then surfaces as a failed
            // upgrade (502) instead of an accepted socket that immediately closes.
            try
            {
                await serverWebSocket.ConnectAsync(jellyfinWsUri, context.RequestAborted);
            }
            catch (WebSocketException connectEx)
            {
                _logger.LogWarning(connectEx, "Failed to connect to Jellyfin WebSocket: {Message}", connectEx.Message);
                if (!context.Response.HasStarted)
                {
                    context.Response.StatusCode = BadGatewayStatusCode;
                }
                return;
            }
            _logger.LogInformation("✓ Connected to Jellyfin WebSocket");

            clientWebSocket = await context.WebSockets.AcceptWebSocketAsync();
            _logger.LogInformation("✓ Client WebSocket accepted");

            // Start bidirectional proxying
            var clientToServer = ProxyMessagesAsync(clientWebSocket, serverWebSocket, "Client→Server", shutdownCts.Token);
            var serverToClient = ProxyMessagesAsync(serverWebSocket, clientWebSocket, "Server→Client", shutdownCts.Token);
            pumps = new[] { clientToServer, serverToClient };

            // Wait for either direction to complete; the finally block stops and awaits the other one.
            await Task.WhenAny(pumps);
            _logger.LogInformation("WebSocket proxy connection closed");
        }
        catch (OperationCanceledException) when (context.RequestAborted.IsCancellationRequested)
        {
            _logger.LogInformation("WebSocket proxy request aborted by client");
        }
        catch (WebSocketException wsEx)
        {
            _logger.LogWarning(wsEx, "WebSocket error: {Message}", wsEx.Message);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error in WebSocket proxy");
            if (clientWebSocket is null && !context.Response.HasStarted)
            {
                context.Response.StatusCode = BadGatewayStatusCode;
            }
        }
        finally
        {
            // Bound the whole shutdown: anything still pending after CloseTimeout is cancelled (which aborts it).
            shutdownCts.CancelAfter(CloseTimeout);

            // Start the close handshake on sides that are still fully open (e.g. the peer of a side that failed).
            await CloseOutputIfOpenAsync(clientWebSocket, "client", shutdownCts.Token);
            await CloseOutputIfOpenAsync(serverWebSocket, "server", shutdownCts.Token);

            // Let both pumps finish (they consume the peers' close replies) before disposing the sockets they use.
            // ProxyMessagesAsync never throws, so this cannot fault.
            await Task.WhenAll(pumps);

            clientWebSocket?.Dispose();
            serverWebSocket?.Dispose();
            _logger.LogInformation("WebSocket connections cleaned up");
        }
    }

    /// <summary>
    /// Builds the Jellyfin <c>/socket</c> URI from the configured base URL (http → ws, https → wss), keeping any
    /// port and path base and appending <paramref name="queryString"/> verbatim.
    /// </summary>
    /// <param name="queryString">The incoming query string including its leading '?', or null/empty.</param>
    /// <returns>The upstream URI, or <c>null</c> when the configured URL is missing or not an absolute http(s) URL.</returns>
    private Uri? BuildJellyfinWebSocketUri(string? queryString)
    {
        if (!Uri.TryCreate(_settings.Url?.Trim(), UriKind.Absolute, out var baseUri))
        {
            return null;
        }

        string wsScheme;
        if (baseUri.Scheme == Uri.UriSchemeHttps)
        {
            wsScheme = "wss://";
        }
        else if (baseUri.Scheme == Uri.UriSchemeHttp)
        {
            wsScheme = "ws://";
        }
        else
        {
            return null;
        }

        // AbsolutePath is "/" for a bare host, or e.g. "/jellyfin" when Jellyfin runs under a path base.
        var pathBase = baseUri.AbsolutePath.TrimEnd('/');
        var url = $"{wsScheme}{baseUri.Authority}{pathBase}/socket{queryString}";
        return Uri.TryCreate(url, UriKind.Absolute, out var wsUri) ? wsUri : null;
    }

    /// <summary>
    /// Copies the client's authentication header to the upstream handshake. Only one header is forwarded:
    /// <c>Authorization</c> if present, otherwise <c>X-Emby-Authorization</c>.
    /// </summary>
    private void ForwardAuthenticationHeaders(HttpContext context, ClientWebSocket serverWebSocket)
    {
        if (context.Request.Headers.TryGetValue("Authorization", out var authHeader))
        {
            serverWebSocket.Options.SetRequestHeader("Authorization", authHeader.ToString());
            _logger.LogDebug("Forwarded Authorization header");
        }
        else if (context.Request.Headers.TryGetValue("X-Emby-Authorization", out var embyAuthHeader))
        {
            serverWebSocket.Options.SetRequestHeader("X-Emby-Authorization", embyAuthHeader.ToString());
            _logger.LogDebug("Forwarded X-Emby-Authorization header");
        }
    }

    /// <summary>
    /// Sends a close frame on <paramref name="socket"/> if it is still fully open. Uses the send-only
    /// <see cref="WebSocket.CloseOutputAsync"/> so it may run while a pump has a receive pending on the same
    /// socket; that pump then consumes the peer's close reply. Failures are logged at Debug because the
    /// connection is being torn down anyway.
    /// </summary>
    private async Task CloseOutputIfOpenAsync(WebSocket? socket, string side, CancellationToken cancellationToken)
    {
        if (socket is not { State: WebSocketState.Open })
        {
            return;
        }

        try
        {
            await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Proxy closing", cancellationToken);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Error closing {Side} WebSocket", side);
        }
    }

    /// <summary>
    /// Forwards frames from <paramref name="source"/> to <paramref name="destination"/> until a close frame
    /// arrives, either side becomes unusable, an error occurs, or <paramref name="cancellationToken"/> is cancelled.
    /// Frames are forwarded as they arrive, with message boundaries preserved via <c>EndOfMessage</c>, so memory use
    /// is bounded by the receive buffer regardless of message size. A close frame from the source is forwarded with
    /// <see cref="WebSocket.CloseOutputAsync"/>, which does not wait for the reply; the opposite-direction pump
    /// receives that reply. Never throws: all failures are logged.
    /// </summary>
    private async Task ProxyMessagesAsync(
        WebSocket source,
        WebSocket destination,
        string direction,
        CancellationToken cancellationToken)
    {
        var buffer = new byte[ReceiveBufferSize];
        try
        {
            // A source in CloseSent can still deliver the peer's remaining frames and close reply. A destination in
            // CloseReceived can still be sent to. Both states are needed to complete the close handshake.
            while ((source.State is WebSocketState.Open or WebSocketState.CloseSent) &&
                   (destination.State is WebSocketState.Open or WebSocketState.CloseReceived))
            {
                var result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);

                if (result.MessageType == WebSocketMessageType.Close)
                {
                    _logger.LogInformation("{Direction}: Close message received", direction);
                    if (destination.State is WebSocketState.Open or WebSocketState.CloseReceived)
                    {
                        await destination.CloseOutputAsync(
                            result.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
                            result.CloseStatusDescription,
                            cancellationToken);
                    }
                    break;
                }

                if (_logger.IsEnabled(LogLevel.Debug))
                {
                    LogFrame(direction, result, buffer);
                }

                await destination.SendAsync(
                    new ArraySegment<byte>(buffer, 0, result.Count),
                    result.MessageType,
                    result.EndOfMessage,
                    cancellationToken);
            }
        }
        catch (OperationCanceledException)
        {
            // Expected when the request is aborted or the shutdown grace period expires.
            _logger.LogDebug("{Direction}: Operation cancelled", direction);
        }
        catch (WebSocketException wsEx) when (wsEx.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
        {
            _logger.LogInformation("{Direction}: Connection closed prematurely", direction);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "{Direction}: Error proxying messages", direction);
        }
    }

    /// <summary>
    /// Debug-logs one forwarded frame. Text frames get a preview of up to 200 characters; binary frames are not decoded.
    /// </summary>
    private void LogFrame(string direction, WebSocketReceiveResult result, byte[] buffer)
    {
        string preview;
        if (result.MessageType == WebSocketMessageType.Text)
        {
            // A frame boundary may split a multi-byte UTF-8 sequence; the decoder substitutes U+FFFD,
            // which is acceptable for a diagnostic preview.
            var text = System.Text.Encoding.UTF8.GetString(buffer, 0, result.Count);
            preview = text.Length > 200 ? text[..200] + "..." : text;
        }
        else
        {
            preview = "<binary>";
        }

        _logger.LogDebug("{Direction}: {MessageType} frame ({Size} bytes, end={EndOfMessage}): {Preview}",
            direction,
            result.MessageType,
            result.Count,
            result.EndOfMessage,
            preview);
    }
}