diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index a1cda2ec..8ab1b24c 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -88,7 +88,7 @@ export const sites = pgTable("sites", { publicKey: varchar("publicKey"), lastHolePunch: bigint("lastHolePunch", { mode: "number" }), listenPort: integer("listenPort"), - dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true), + dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true) }); export const resources = pgTable("resources", { @@ -206,7 +206,7 @@ export const siteResources = pgTable("siteResources", { mode: varchar("mode").notNull(), // "host" | "cidr" | "port" protocol: varchar("protocol"), // only for port mode proxyPort: integer("proxyPort"), // only for port mode - destinationPort: integer("destinationPort"), // only for port mode + destinationPort: integer("destinationPort"), // only for port mode destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode enabled: boolean("enabled").notNull().default(true), alias: varchar("alias") @@ -654,25 +654,25 @@ export const clients = pgTable("clients", { maxConnections: integer("maxConnections") }); -export const clientSitesAssociationsCache = pgTable("clientSitesAssociationsCache", { - clientId: integer("clientId") - .notNull() - .references(() => clients.clientId, { onDelete: "cascade" }), - siteId: integer("siteId") - .notNull() - .references(() => sites.siteId, { onDelete: "cascade" }), - isRelayed: boolean("isRelayed").notNull().default(false), - endpoint: varchar("endpoint") -}); +export const clientSitesAssociationsCache = pgTable( + "clientSitesAssociationsCache", + { + clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message + .notNull(), + siteId: integer("siteId").notNull(), + isRelayed: boolean("isRelayed").notNull().default(false), + endpoint: varchar("endpoint") + } +); -export const clientSiteResourcesAssociationsCache = pgTable("clientSiteResourcesAssociationsCache", { - clientId: integer("clientId") - .notNull() - .references(() => clients.clientId, { onDelete: "cascade" }), - siteResourceId: integer("siteResourceId") - .notNull() - .references(() => siteResources.siteResourceId, { onDelete: "cascade" }) -}); +export const clientSiteResourcesAssociationsCache = pgTable( + "clientSiteResourcesAssociationsCache", + { + clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message + .notNull(), + siteResourceId: integer("siteResourceId").notNull() + } +); export const olms = pgTable("olms", { olmId: varchar("id").primaryKey(), diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index cae15a04..cfffdba7 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -361,27 +361,27 @@ export const clients = sqliteTable("clients", { lastHolePunch: integer("lastHolePunch") }); -export const clientSitesAssociationsCache = sqliteTable("clientSitesAssociationsCache", { - clientId: integer("clientId") - .notNull() - .references(() => clients.clientId, { onDelete: "cascade" }), - siteId: integer("siteId") - .notNull() - .references(() => sites.siteId, { onDelete: "cascade" }), - isRelayed: integer("isRelayed", { mode: "boolean" }) - .notNull() - .default(false), - endpoint: text("endpoint") -}); +export const clientSitesAssociationsCache = sqliteTable( + "clientSitesAssociationsCache", + { + clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message + .notNull(), + siteId: integer("siteId").notNull(), + isRelayed: integer("isRelayed", { mode: "boolean" }) + .notNull() + .default(false), + endpoint: text("endpoint") + } +); -export const clientSiteResourcesAssociationsCache = sqliteTable("clientSiteResourcesAssociationsCache", { - clientId: integer("clientId") - .notNull() - .references(() => clients.clientId, { onDelete: "cascade" }), - siteResourceId: integer("siteResourceId") - .notNull() - .references(() => siteResources.siteResourceId, { onDelete: "cascade" }) -}); +export const clientSiteResourcesAssociationsCache = sqliteTable( + "clientSiteResourcesAssociationsCache", + { + clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message + .notNull(), + siteResourceId: integer("siteResourceId").notNull() + } +); export const olms = sqliteTable("olms", { olmId: text("id").primaryKey(), diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 8acf7c05..055820dc 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -304,7 +304,7 @@ export async function getNextAvailableOrgSubnet(): Promise { return subnet; } -export function generateRemoteSubnetsStr(allSiteResources: SiteResource[]) { +export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] { let remoteSubnets = allSiteResources .filter((sr) => { if (sr.mode === "cidr") return true; @@ -321,12 +321,11 @@ export function generateRemoteSubnetsStr(allSiteResources: SiteResource[]) { if (sr.mode === "host") { return `${sr.destination}/32`; } - }); + return ""; // This should never be reached due to filtering, but satisfies TypeScript + }) + .filter((subnet) => subnet !== ""); // Remove empty strings just to be safe // remove duplicates - remoteSubnets = Array.from(new Set(remoteSubnets)); - const remoteSubnetsStr = - remoteSubnets.length > 0 ? remoteSubnets.join(",") : null; - return remoteSubnetsStr; + return Array.from(new Set(remoteSubnets)); } export type SubnetProxyTarget = { diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index eae9529d..8a9e2de3 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -2,6 +2,7 @@ import { Client, clients, clientSiteResources, + clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, @@ -30,7 +31,7 @@ import { import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; import { - generateRemoteSubnetsStr, + generateRemoteSubnets, generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip"; @@ -204,17 +205,35 @@ export async function rebuildClientAssociations( const existingClientSiteResources = await trx .select({ - clientId: clientSiteResources.clientId + clientId: clientSiteResourcesAssociationsCache.clientId }) - .from(clientSiteResources) + .from(clientSiteResourcesAssociationsCache) .where( - eq(clientSiteResources.siteResourceId, siteResource.siteResourceId) + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResource.siteResourceId + ) ); const existingClientSiteResourceIds = existingClientSiteResources.map( (row) => row.clientId ); + // Get full client details for existing resource clients (needed for sending delete messages) + const existingResourceClients = + existingClientSiteResourceIds.length > 0 + ? await trx + .select({ + clientId: clients.clientId, + pubKey: clients.pubKey, + subnet: clients.subnet + }) + .from(clients) + .where( + inArray(clients.clientId, existingClientSiteResourceIds) + ) + : []; + const clientSiteResourcesToAdd = mergedAllClientIds.filter( (clientId) => !existingClientSiteResourceIds.includes(clientId) ); @@ -228,7 +247,7 @@ export async function rebuildClientAssociations( if (clientSiteResourcesToInsert.length > 0) { await trx - .insert(clientSiteResources) + .insert(clientSiteResourcesAssociationsCache) .values(clientSiteResourcesToInsert) .returning(); } @@ -239,15 +258,15 @@ export async function rebuildClientAssociations( if (clientSiteResourcesToRemove.length > 0) { await trx - .delete(clientSiteResources) + .delete(clientSiteResourcesAssociationsCache) .where( and( eq( - clientSiteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId, siteResource.siteResourceId ), inArray( - clientSiteResources.clientId, + clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesToRemove ) ) @@ -269,7 +288,7 @@ export async function rebuildClientAssociations( await handleSubnetProxyTargetUpdates( siteResource, mergedAllClients, - existingClients, + existingResourceClients, clientSiteResourcesToAdd, clientSiteResourcesToRemove, trx @@ -277,7 +296,7 @@ export async function rebuildClientAssociations( return { mergedAllClients - } + }; } async function handleMessagesForSiteClients( @@ -429,10 +448,25 @@ async function handleMessagesForSiteClients( ); // TODO: should we have this here? - const allSiteResources = await trx + const allSiteResources = await db // only get the site resources that this client has access to .select() .from(siteResources) - .where(eq(siteResources.siteId, site.siteId)); + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); olmJobs.push( olmAddPeer( @@ -446,8 +480,11 @@ async function handleMessagesForSiteClients( publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: - generateRemoteSubnetsStr(allSiteResources) + remoteSubnets: generateRemoteSubnets( + allSiteResources.map( + ({ siteResources }) => siteResources + ) + ) }, olm.olmId ) @@ -518,9 +555,13 @@ export async function updateClientSiteDestinations( exitNodeId: site.exitNodes?.exitNodeId || 0, type: site.exitNodes?.type || "", name: site.exitNodes?.name || "", - sourceIp: site.clientSitesAssociationsCache.endpoint.split(":")[0] || "", + sourceIp: + site.clientSitesAssociationsCache.endpoint.split(":")[0] || + "", sourcePort: - parseInt(site.clientSitesAssociationsCache.endpoint.split(":")[1]) || 0, + parseInt( + site.clientSitesAssociationsCache.endpoint.split(":")[1] + ) || 0, destinations: [ { destinationIP: site.sites.subnet.split("/")[0], diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 52a159b0..68116686 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -7,16 +7,16 @@ import { ExitNode, exitNodes, siteResources, - clientSiteResourcesAssociationsCache, + clientSiteResourcesAssociationsCache } from "@server/db"; import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db"; import { eq, and, inArray } from "drizzle-orm"; import { updatePeer } from "../olm/peers"; import { sendToExitNode } from "#dynamic/lib/exitNodes"; import { - generateRemoteSubnetsStr, + generateRemoteSubnets, generateSubnetProxyTargets, - SubnetProxyTarget, + SubnetProxyTarget } from "@server/lib/ip"; const inputSchema = z.object({ @@ -137,7 +137,10 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { const clientsRes = await db .select() .from(clients) - .innerJoin(clientSitesAssociationsCache, eq(clients.clientId, clientSitesAssociationsCache.clientId)) + .innerJoin( + clientSitesAssociationsCache, + eq(clients.clientId, clientSitesAssociationsCache.clientId) + ) .where(eq(clientSitesAssociationsCache.siteId, siteId)); // Prepare peers data for the response @@ -186,10 +189,25 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { return null; } - const allSiteResources = await db + const allSiteResources = await db // only get the site resources that this client has access to .select() .from(siteResources) - .where(eq(siteResources.siteId, site.siteId)); + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clients.clientId + ) + ) + ); await updatePeer(client.clients.clientId, { siteId: site.siteId, @@ -197,8 +215,11 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: - generateRemoteSubnetsStr(allSiteResources) + remoteSubnets: generateRemoteSubnets( + allSiteResources.map( + ({ siteResources }) => siteResources + ) + ) }); } catch (error) { logger.error( @@ -238,7 +259,10 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { .from(clients) .innerJoin( clientSiteResourcesAssociationsCache, - eq(clients.clientId, clientSiteResourcesAssociationsCache.clientId) + eq( + clients.clientId, + clientSiteResourcesAssociationsCache.clientId + ) ) .where( eq( @@ -247,7 +271,10 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { ) ); - const resourceTargets = generateSubnetProxyTargets(resource, resourceClients); + const resourceTargets = generateSubnetProxyTargets( + resource, + resourceClients + ); targetsToSend.push(...resourceTargets); } diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 734c29f3..5c438e4f 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,5 +1,6 @@ import { Client, + clientSiteResourcesAssociationsCache, db, ExitNode, orgs, @@ -12,13 +13,20 @@ import { users } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { clients, clientSitesAssociationsCache, exitNodes, Olm, olms, sites } from "@server/db"; +import { + clients, + clientSitesAssociationsCache, + exitNodes, + Olm, + olms, + sites +} from "@server/db"; import { and, eq, inArray, isNull } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; -import { generateRemoteSubnetsStr } from "@server/lib/ip"; +import { generateRemoteSubnets } from "@server/lib/ip"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -170,7 +178,10 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { const sitesData = await db .select() .from(sites) - .innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId)) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); // Prepare an array to store site configurations @@ -234,11 +245,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ) .limit(1); - const allSiteResources = await db - .select() - .from(siteResources) - .where(eq(siteResources.siteId, site.siteId)); - // Add the peer to the exit node for this site if (clientSite.endpoint) { logger.info( @@ -269,6 +275,26 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { endpoint = `${exitNode.endpoint}:21820`; } + const allSiteResources = await db // only get the site resources that this client has access to + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, @@ -276,7 +302,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnetsStr(allSiteResources) + remoteSubnets: generateRemoteSubnets(allSiteResources.map(({ siteResources }) => siteResources)) }); } diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index c712ea65..1daed53a 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -12,7 +12,7 @@ export async function addPeer( endpoint: string; serverIP: string | null; serverPort: number | null; - remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access + remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access }, olmId?: string ) { @@ -80,7 +80,7 @@ export async function updatePeer( endpoint: string; serverIP: string | null; serverPort: number | null; - remoteSubnets?: string | null; // optional, comma-separated list of subnets that + remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that }, olmId?: string ) { diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index 618256be..2c7bf0fe 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -272,9 +272,6 @@ export async function createSiteResource( ); } - // const targets = await generateSubnetProxyTargets([newSiteResource], trx); - // await addTargets(newt.newtId, targets); - await rebuildClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role }); diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index d2838a5a..75f2c3f2 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -106,10 +106,7 @@ export async function deleteSiteResource( ); } - // const targets = await generateSubnetProxyTargets([removedSiteResource], trx); - // await removeTargets(newt.newtId, targets); - - await rebuildClientAssociations(existingSiteResource, trx); + await rebuildClientAssociations(removedSiteResource, trx); }); logger.info( diff --git a/statement-breakpoint b/statement-breakpoint new file mode 100644 index 00000000..e69de29b