diff --git a/Server/Security/AuthHelper.cs b/Server/Security/AuthHelper.cs index 7c2a1a2..c0f8387 100644 --- a/Server/Security/AuthHelper.cs +++ b/Server/Security/AuthHelper.cs @@ -3,7 +3,10 @@ using System.Security.Claims; using System.IdentityModel.Tokens.Jwt; using Microsoft.AspNetCore.Http; using Microsoft.IdentityModel.Tokens; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; using SNote.Server.Security; +using SNote.Server.Endpoints; namespace SNote.Server.Security; @@ -47,7 +50,7 @@ public static class AuthHelper public static bool IsServerTokenValid(HttpContext context, PeerCache peerCache) { - var serverUrl = context.Request.Headers["X-Server-Url"].ToString(); + var serverUrl = context.Request.Headers["X-Server-Url"].ToString().Trim().TrimEnd('/'); var serverToken = context.Request.Headers["X-Server-Token"].ToString(); if (string.IsNullOrEmpty(serverUrl) || string.IsNullOrEmpty(serverToken)) @@ -55,7 +58,27 @@ public static class AuthHelper return false; } - return peerCache.VerifySessionToken(serverUrl, serverToken); + // 1. Verify if it's a handshaked downstream peer calling us (upstream verification) + if (peerCache.VerifySessionToken(serverUrl, serverToken)) + { + return true; + } + + // 2. Verify if it's our configured upstream calling us (downstream verification) + var configuration = context.RequestServices.GetRequiredService(); + var destUrl = (configuration["Sync:DestinationServerUrl"] ?? "").Trim().TrimEnd('/'); + + if (!string.IsNullOrEmpty(destUrl) && string.Equals(serverUrl, destUrl, StringComparison.OrdinalIgnoreCase)) + { + // The request is coming from our upstream. Verify the token matches the one we received during handshake! + if (!string.IsNullOrEmpty(SyncEndpoints.UpstreamSessionToken) && + string.Equals(SyncEndpoints.UpstreamSessionToken, serverToken, StringComparison.Ordinal)) + { + return true; + } + } + + return false; } // Helper to generate a server token for outgoing sync requests