diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index cf4ad8b7..f504ecd7 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,4 +1,4 @@ -import { db } from "@server/db"; +import { db, ExitNode } from "@server/db"; import { MessageHandler } from "../ws"; import { clients, @@ -28,7 +28,10 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } const clientId = olm.clientId; - const { publicKey } = message.data; + const { publicKey, relay } = message.data; + + logger.debug(`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}`); + if (!publicKey) { logger.warn("Public key not provided"); return; @@ -62,6 +65,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { endpoint: exitNode.endpoint, } }); + } if (now - (client.lastHolePunch || 0) > 6) { @@ -85,7 +89,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { await db .update(clientSites) .set({ - isRelayed: false + isRelayed: relay == true }) .where(eq(clientSites.clientId, olm.clientId)); } @@ -98,8 +102,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(clientSites.clientId, client.clientId)); // Prepare an array to store site configurations - const siteConfigurations = []; - + let siteConfigurations = []; + logger.debug(`Found ${sitesData.length} sites for client ${client.clientId}`); // Process each site for (const { sites: site } of sitesData) { if (!site.exitNodeId) { @@ -115,7 +119,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { continue; } - if (site.lastHolePunch && now - site.lastHolePunch > 6) { + if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) { logger.warn( `Site ${site.siteId} last hole punch is too old, skipping` ); @@ -143,7 +147,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { await addPeer(site.siteId, { publicKey: publicKey, allowedIps: [`${client.subnet.split('/')[0]}/32`], // we want to only allow from that client - endpoint: client.endpoint + endpoint: relay ? "" : client.endpoint }); } else { logger.warn( @@ -151,10 +155,24 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ); } + let endpoint = site.endpoint; + if (relay) { + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)) + .limit(1); + if (!exitNode) { + logger.warn(`Exit node not found for site ${site.siteId}`); + continue; + } + endpoint = `${exitNode.endpoint}:21820`; + } + // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, - endpoint: site.endpoint, + endpoint: endpoint, publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort