diff --git a/server/lib/calculateUserClientsForOrgs.ts b/server/lib/calculateUserClientsForOrgs.ts index f66e3888..4cde8657 100644 --- a/server/lib/calculateUserClientsForOrgs.ts +++ b/server/lib/calculateUserClientsForOrgs.ts @@ -1,8 +1,20 @@ -import { clients, clientSitesAssociationsCache, db, olms, orgs, roleClients, roles, userClients, userOrgs, Transaction } from "@server/db"; +import { + clients, + db, + olms, + orgs, + roleClients, + roles, + userClients, + userOrgs, + Transaction +} from "@server/db"; import { eq, and, notInArray } from "drizzle-orm"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; import logger from "@server/logger"; +import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations"; +import { sendTerminateClient } from "@server/routers/client/terminate"; export async function calculateUserClientsForOrgs( userId: string, @@ -88,7 +100,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(roleClients.roleId, adminRole.roleId), - eq(roleClients.clientId, existingClient.clientId) + eq( + roleClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -110,7 +125,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(userClients.userId, userId), - eq(userClients.clientId, existingClient.clientId) + eq( + userClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -172,6 +190,11 @@ export async function calculateUserClientsForOrgs( }) .returning(); + await rebuildClientAssociationsFromClient( + newClient, + transaction + ); + // Grant admin role access to the client await transaction.insert(roleClients).values({ roleId: adminRole.roleId, @@ -225,15 +248,8 @@ async function cleanupOrphanedClients( : and(eq(clients.userId, userId)) ); - // Delete client-site associations first, then delete the clients - for (const client of clientsToDelete) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - if (clientsToDelete.length > 0) { - await trx + const deletedClients = await trx .delete(clients) .where( userOrgIds.length > 0 @@ -242,7 +258,20 @@ async function cleanupOrphanedClients( notInArray(clients.orgId, userOrgIds) ) : and(eq(clients.userId, userId)) - ); + ) + .returning(); + + // Rebuild associations for each deleted client to clean up related data + for (const deletedClient of deletedClients) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (deletedClient.olmId) { + await sendTerminateClient( + deletedClient.clientId, + deletedClient.olmId + ); + } + } if (userOrgIds.length === 0) { logger.debug( @@ -255,4 +284,3 @@ async function cleanupOrphanedClients( } } } - diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 7835ad84..4a02694a 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -398,8 +398,9 @@ export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] { } export type SubnetProxyTarget = { - sourcePrefix: string; - destPrefix: string; + sourcePrefix: string; // must be a cidr + destPrefix: string; // must be a cidr + rewriteTo?: string; // must be a cidr portRange?: { min: number; max: number; @@ -447,7 +448,8 @@ export function generateSubnetProxyTargets( // also push a match for the alias address targets.push({ sourcePrefix: clientPrefix, - destPrefix: `${siteResource.aliasAddress}/32` + destPrefix: `${siteResource.aliasAddress}/32`, + rewriteTo: `${siteResource.destination}/32` }); } } else if (siteResource.mode == "cidr") { diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index a1072196..f1cbea0c 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -129,7 +129,7 @@ export async function getClientSiteResourceAccess( }; } -export async function rebuildClientAssociations( +export async function rebuildClientAssociationsFromSiteResource( siteResource: SiteResource, trx: Transaction | typeof db = db ): Promise<{ @@ -753,3 +753,552 @@ async function handleSubnetProxyTargetUpdates( await Promise.all(proxyJobs); } + +export async function rebuildClientAssociationsFromClient( + client: Client, + trx: Transaction | typeof db = db +): Promise { + let newSiteResourceIds: number[] = []; + + // 1. Direct client associations + const directSiteResources = await trx + .select({ siteResourceId: clientSiteResources.siteResourceId }) + .from(clientSiteResources) + .where(eq(clientSiteResources.clientId, client.clientId)); + + newSiteResourceIds.push( + ...directSiteResources.map((r) => r.siteResourceId) + ); + + // 2. User-based and role-based access (if client has a userId) + if (client.userId) { + // Direct user associations + const userSiteResourceIds = await trx + .select({ siteResourceId: userSiteResources.siteResourceId }) + .from(userSiteResources) + .where(eq(userSiteResources.userId, client.userId)); + + newSiteResourceIds.push( + ...userSiteResourceIds.map((r) => r.siteResourceId) + ); + + // Role-based access + const roleIds = await trx + .select({ roleId: userOrgs.roleId }) + .from(userOrgs) + .where(eq(userOrgs.userId, client.userId)) + .then((rows) => rows.map((row) => row.roleId)); + + if (roleIds.length > 0) { + const roleSiteResourceIds = await trx + .select({ siteResourceId: roleSiteResources.siteResourceId }) + .from(roleSiteResources) + .where(inArray(roleSiteResources.roleId, roleIds)); + + newSiteResourceIds.push( + ...roleSiteResourceIds.map((r) => r.siteResourceId) + ); + } + } + + // Remove duplicates + newSiteResourceIds = Array.from(new Set(newSiteResourceIds)); + + // Get full siteResource details + const newSiteResources = + newSiteResourceIds.length > 0 + ? await trx + .select() + .from(siteResources) + .where( + inArray(siteResources.siteResourceId, newSiteResourceIds) + ) + : []; + + // Group by siteId for site-level associations + const newSiteIds = Array.from( + new Set(newSiteResources.map((sr) => sr.siteId)) + ); + + /////////// Process client-siteResource associations /////////// + + // Get existing resource associations + const existingResourceAssociations = await trx + .select({ + siteResourceId: clientSiteResourcesAssociationsCache.siteResourceId + }) + .from(clientSiteResourcesAssociationsCache) + .where( + eq(clientSiteResourcesAssociationsCache.clientId, client.clientId) + ); + + const existingSiteResourceIds = existingResourceAssociations.map( + (r) => r.siteResourceId + ); + + const resourcesToAdd = newSiteResourceIds.filter( + (id) => !existingSiteResourceIds.includes(id) + ); + + const resourcesToRemove = existingSiteResourceIds.filter( + (id) => !newSiteResourceIds.includes(id) + ); + + // Insert new associations + if (resourcesToAdd.length > 0) { + await trx.insert(clientSiteResourcesAssociationsCache).values( + resourcesToAdd.map((siteResourceId) => ({ + clientId: client.clientId, + siteResourceId + })) + ); + } + + // Remove old associations + if (resourcesToRemove.length > 0) { + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + inArray( + clientSiteResourcesAssociationsCache.siteResourceId, + resourcesToRemove + ) + ) + ); + } + + /////////// Process client-site associations /////////// + + // Get existing site associations + const existingSiteAssociations = await trx + .select({ siteId: clientSitesAssociationsCache.siteId }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + const existingSiteIds = existingSiteAssociations.map((s) => s.siteId); + + const sitesToAdd = newSiteIds.filter((id) => !existingSiteIds.includes(id)); + const sitesToRemove = existingSiteIds.filter( + (id) => !newSiteIds.includes(id) + ); + + // Insert new site associations + if (sitesToAdd.length > 0) { + await trx.insert(clientSitesAssociationsCache).values( + sitesToAdd.map((siteId) => ({ + clientId: client.clientId, + siteId + })) + ); + } + + // Remove old site associations + if (sitesToRemove.length > 0) { + await trx + .delete(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + inArray(clientSitesAssociationsCache.siteId, sitesToRemove) + ) + ); + } + + /////////// Send messages /////////// + + // Get the olm for this client + const [olm] = await trx + .select({ olmId: olms.olmId }) + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (!olm) { + logger.warn( + `Olm not found for client ${client.clientId}, skipping peer updates` + ); + return; + } + + // Handle messages for sites being added + await handleMessagesForClientSites( + client, + olm.olmId, + sitesToAdd, + sitesToRemove, + trx + ); + + // Handle subnet proxy target updates for resources + await handleMessagesForClientResources( + client, + newSiteResources, + resourcesToAdd, + resourcesToRemove, + trx + ); +} + +async function handleMessagesForClientSites( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + olmId: string, + sitesToAdd: number[], + sitesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + if (!client.subnet || !client.pubKey) { + logger.warn( + `Client ${client.clientId} missing subnet or pubKey, skipping peer updates` + ); + return; + } + + const allSiteIds = [...sitesToAdd, ...sitesToRemove]; + if (allSiteIds.length === 0) { + return; + } + + // Get site details for all affected sites + const sitesData = await trx + .select() + .from(sites) + .leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId)) + .leftJoin(newts, eq(sites.siteId, newts.siteId)) + .where(inArray(sites.siteId, allSiteIds)); + + let newtJobs: Promise[] = []; + let olmJobs: Promise[] = []; + let exitNodeJobs: Promise[] = []; + + for (const siteData of sitesData) { + const site = siteData.sites; + const exitNode = siteData.exitNodes; + const newt = siteData.newt; + + if (!site.publicKey) { + logger.warn( + `Site ${site.siteId} missing publicKey, skipping peer updates` + ); + continue; + } + + if (!newt) { + logger.warn( + `Newt not found for site ${site.siteId}, skipping peer updates` + ); + continue; + } + + const isAdd = sitesToAdd.includes(site.siteId); + const isRemove = sitesToRemove.includes(site.siteId); + + if (isRemove) { + // Remove peer from newt + newtJobs.push( + newtDeletePeer(site.siteId, client.pubKey, newt.newtId) + ); + try { + // Remove peer from olm + olmJobs.push( + olmDeletePeer( + client.clientId, + site.siteId, + site.publicKey, + olmId + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId}, skipping removal` + ); + } else { + throw error; + } + } + } + + if (isAdd) { + if (!exitNode) { + logger.warn( + `Exit node not found for site ${site.siteId}, skipping peer add` + ); + continue; + } + + // Add peer to newt + const isRelayed = true; // Default to relaying for new connections + newtJobs.push( + newtAddPeer( + site.siteId, + { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], + endpoint: isRelayed ? "" : "" + }, + newt.newtId + ) + ); + + // Get all site resources for this site that the client has access to + const accessibleResources = await trx + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + try { + // Add peer to olm + olmJobs.push( + olmAddPeer( + client.clientId, + { + siteId: site.siteId, + endpoint: + isRelayed || !site.endpoint + ? `${exitNode.endpoint}:21820` + : site.endpoint, + publicKey: site.publicKey, + serverIP: site.address || "", + serverPort: site.listenPort || 0, + remoteSubnets: generateRemoteSubnets( + accessibleResources.map( + ({ siteResources }) => siteResources + ) + ) + }, + olmId + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId}, skipping removal` + ); + } else { + throw error; + } + } + } + + // Update exit node destinations + exitNodeJobs.push( + updateClientSiteDestinations( + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + }, + trx + ) + ); + } + + await Promise.all(exitNodeJobs); + await Promise.all(newtJobs); + await Promise.all(olmJobs); +} + +async function handleMessagesForClientResources( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + allNewResources: SiteResource[], + resourcesToAdd: number[], + resourcesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + // Group resources by site + const resourcesBySite = new Map(); + + for (const resource of allNewResources) { + if (!resourcesBySite.has(resource.siteId)) { + resourcesBySite.set(resource.siteId, []); + } + resourcesBySite.get(resource.siteId)!.push(resource); + } + + let proxyJobs: Promise[] = []; + let olmJobs: Promise[] = []; + + // Handle additions + if (resourcesToAdd.length > 0) { + const addedResources = allNewResources.filter((r) => + resourcesToAdd.includes(r.siteResourceId) + ); + + // Group by site for proxy updates + const addedBySite = new Map(); + for (const resource of addedResources) { + if (!addedBySite.has(resource.siteId)) { + addedBySite.set(resource.siteId, []); + } + addedBySite.get(resource.siteId)!.push(resource); + } + + // Add subnet proxy targets for each site + for (const [siteId, resources] of addedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets)); + } + + try { + // Add peer data to olm + olmJobs.push( + addPeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + // Handle removals + if (resourcesToRemove.length > 0) { + const removedResources = await trx + .select() + .from(siteResources) + .where(inArray(siteResources.siteResourceId, resourcesToRemove)); + + // Group by site for proxy updates + const removedBySite = new Map(); + for (const resource of removedResources) { + if (!removedBySite.has(resource.siteId)) { + removedBySite.set(resource.siteId, []); + } + removedBySite.get(resource.siteId)!.push(resource); + } + + // Remove subnet proxy targets for each site + for (const [siteId, resources] of removedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push( + removeSubnetProxyTargets(newt.newtId, targets) + ); + } + + try { + // Remove peer data from olm + olmJobs.push( + removePeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + await Promise.all([...proxyJobs, ...olmJobs]); +} diff --git a/server/routers/auth/securityKey.ts b/server/routers/auth/securityKey.ts index cde2f61a..eed2328d 100644 --- a/server/routers/auth/securityKey.ts +++ b/server/routers/auth/securityKey.ts @@ -52,7 +52,7 @@ setInterval(async () => { await db .delete(webauthnChallenge) .where(lt(webauthnChallenge.expiresAt, now)); - logger.debug("Cleaned up expired security key challenges"); + // logger.debug("Cleaned up expired security key challenges"); } catch (error) { logger.error("Failed to clean up expired security key challenges", error); } diff --git a/server/routers/client/createClient.ts b/server/routers/client/createClient.ts index 908ea689..160006e1 100644 --- a/server/routers/client/createClient.ts +++ b/server/routers/client/createClient.ts @@ -24,18 +24,19 @@ import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { generateId } from "@server/auth/sessions/app"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const createClientParamsSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); const createClientSchema = z.strictObject({ - name: z.string().min(1).max(255), - olmId: z.string(), - secret: z.string(), - subnet: z.string(), - type: z.enum(["olm"]) - }); + name: z.string().min(1).max(255), + olmId: z.string(), + secret: z.string(), + subnet: z.string(), + type: z.enum(["olm"]) +}); export type CreateClientBody = z.infer; @@ -186,6 +187,7 @@ export async function createClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -204,7 +206,7 @@ export async function createClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -244,13 +246,15 @@ export async function createClient( dateCreated: moment().toISOString() }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/createUserClient.ts b/server/routers/client/createUserClient.ts index f49a0783..e5b5ea8f 100644 --- a/server/routers/client/createUserClient.ts +++ b/server/routers/client/createUserClient.ts @@ -21,6 +21,7 @@ import { isValidIP } from "@server/lib/validators"; import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const paramsSchema = z .object({ @@ -191,6 +192,7 @@ export async function createUserClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -209,7 +211,7 @@ export async function createUserClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -232,13 +234,15 @@ export async function createUserClient( clientId: newClient.clientId }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/deleteClient.ts b/server/routers/client/deleteClient.ts index 34019a53..775708ce 100644 --- a/server/routers/client/deleteClient.ts +++ b/server/routers/client/deleteClient.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, olms } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import response from "@server/lib/response"; @@ -9,10 +9,12 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "./terminate"; const deleteClientSchema = z.strictObject({ - clientId: z.string().transform(Number).pipe(z.int().positive()) - }); + clientId: z.string().transform(Number).pipe(z.int().positive()) +}); registry.registerPath({ method: "delete", @@ -68,19 +70,27 @@ export async function deleteClient( } await db.transaction(async (trx) => { - // Delete the client-site associations first - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, clientId)); - // Then delete the client itself - await trx.delete(clients).where(eq(clients.clientId, clientId)); + const [deletedClient] = await trx + .delete(clients) + .where(eq(clients.clientId, clientId)) + .returning(); - // this is a machine client + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + + // this is a machine client so we also delete the olm if (!client.userId && client.olmId) { - await trx - .delete(clients) - .where(eq(clients.olmId, client.olmId)); + await trx.delete(olms).where(eq(olms.olmId, client.olmId)); + } + + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (olm) { + await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion } }); diff --git a/server/routers/client/terminate.ts b/server/routers/client/terminate.ts new file mode 100644 index 00000000..dc49ef05 --- /dev/null +++ b/server/routers/client/terminate.ts @@ -0,0 +1,22 @@ +import { sendToClient } from "#dynamic/routers/ws"; +import { db, olms } from "@server/db"; +import { eq } from "drizzle-orm"; + +export async function sendTerminateClient(clientId: number, olmId?: string | null) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: `olm/terminate`, + data: {} + }); +} diff --git a/server/routers/olm/deleteUserOlm.ts b/server/routers/olm/deleteUserOlm.ts index 88e791db..83a3d16f 100644 --- a/server/routers/olm/deleteUserOlm.ts +++ b/server/routers/olm/deleteUserOlm.ts @@ -1,5 +1,5 @@ import { NextFunction, Request, Response } from "express"; -import { db } from "@server/db"; +import { Client, db } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -9,6 +9,8 @@ import { z } from "zod"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "../client/terminate"; const paramsSchema = z .object({ @@ -54,20 +56,30 @@ export async function deleteUserOlm( .from(clients) .where(eq(clients.olmId, olmId)); - // Delete client-site associations for each associated client - for (const client of associatedClients) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - + let deletedClient: Client | null = null; // Delete all associated clients if (associatedClients.length > 0) { - await trx.delete(clients).where(eq(clients.olmId, olmId)); + [deletedClient] = await trx + .delete(clients) + .where(eq(clients.olmId, olmId)) + .returning(); } // Finally, delete the OLM itself - await trx.delete(olms).where(eq(olms.olmId, olmId)); + const [olm] = await trx + .delete(olms) + .where(eq(olms.olmId, olmId)) + .returning(); + + if (deletedClient) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + if (olm) { + await sendTerminateClient( + deletedClient.clientId, + olm.olmId + ); // the olmId needs to be provided because it cant look it up after deletion + } + } }); return response(res, { diff --git a/server/routers/olm/getUserOlm.ts b/server/routers/olm/getUserOlm.ts index 50b32fd8..aa9b89af 100644 --- a/server/routers/olm/getUserOlm.ts +++ b/server/routers/olm/getUserOlm.ts @@ -1,6 +1,6 @@ import { NextFunction, Request, Response } from "express"; import { db } from "@server/db"; -import { olms, clients, clientSites } from "@server/db"; +import { olms } from "@server/db"; import { eq, and } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 048e6baa..2ee5c120 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -31,6 +31,7 @@ import { getNextAvailableClientSubnet } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); @@ -60,6 +61,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { olm.name || "User Device", // doNotCreateNewClient ? true : false true // for now never create a new client automatically because we create the users clients when they are added to the org + // this means that the rebuildClientAssociationsFromClient call below issue is not a problem ); client = clientRes; @@ -99,6 +101,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .from(clients) .where(eq(clients.clientId, olm.clientId)) .limit(1); + + [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, client.orgId)) + .limit(1); } if (!client) { @@ -205,13 +213,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Found ${sitesData.length} sites for client ${client.clientId}` ); - if (sitesData.length === 0) { - sendToClient(olm.olmId, { - type: "olm/register/no-sites", - data: {} - }); - } - // Process each site for (const { sites: site } of sitesData) { if (!site.exitNodeId) { @@ -462,6 +463,8 @@ async function getOrCreateOrgClient( }); } + await rebuildClientAssociationsFromClient(newClient, trx); // TODO: this will try to messages to the olm which has not connected yet - is that a problem? + client = newClient; } diff --git a/server/routers/siteResource/addClientToSiteResource.ts b/server/routers/siteResource/addClientToSiteResource.ts index 8fb6afdc..587294e5 100644 --- a/server/routers/siteResource/addClientToSiteResource.ts +++ b/server/routers/siteResource/addClientToSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addClientToSiteResourceBodySchema = z .object({ @@ -136,7 +136,7 @@ export async function addClientToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addRoleToSiteResource.ts b/server/routers/siteResource/addRoleToSiteResource.ts index 859ca5be..542ca535 100644 --- a/server/routers/siteResource/addRoleToSiteResource.ts +++ b/server/routers/siteResource/addRoleToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addRoleToSiteResourceBodySchema = z .object({ @@ -146,7 +146,7 @@ export async function addRoleToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addUserToSiteResource.ts b/server/routers/siteResource/addUserToSiteResource.ts index 411d37b4..c9d1f30a 100644 --- a/server/routers/siteResource/addUserToSiteResource.ts +++ b/server/routers/siteResource/addUserToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addUserToSiteResourceBodySchema = z .object({ @@ -115,7 +115,7 @@ export async function addUserToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index ecbb7768..1d9cd6aa 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -17,7 +17,7 @@ import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { getUniqueSiteResourceName } from "@server/db/names"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; import { getNextAvailableAliasAddress } from "@server/lib/ip"; const createSiteResourceParamsSchema = z.strictObject({ @@ -278,7 +278,7 @@ export async function createSiteResource( ); } - await rebuildClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role + await rebuildClientAssociationsFromSiteResource(newSiteResource, trx); // we need to call this because we added to the admin role }); if (!newSiteResource) { diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index 75f2c3f2..a7175608 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -9,7 +9,7 @@ import { eq, and } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const deleteSiteResourceParamsSchema = z.strictObject({ siteResourceId: z.string().transform(Number).pipe(z.int().positive()), @@ -106,7 +106,7 @@ export async function deleteSiteResource( ); } - await rebuildClientAssociations(removedSiteResource, trx); + await rebuildClientAssociationsFromSiteResource(removedSiteResource, trx); }); logger.info( diff --git a/server/routers/siteResource/removeClientFromSiteResource.ts b/server/routers/siteResource/removeClientFromSiteResource.ts index d46e5d67..c6a5dfe8 100644 --- a/server/routers/siteResource/removeClientFromSiteResource.ts +++ b/server/routers/siteResource/removeClientFromSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeClientFromSiteResourceBodySchema = z .object({ @@ -142,7 +142,7 @@ export async function removeClientFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeRoleFromSiteResource.ts b/server/routers/siteResource/removeRoleFromSiteResource.ts index c4c68e06..0041ed83 100644 --- a/server/routers/siteResource/removeRoleFromSiteResource.ts +++ b/server/routers/siteResource/removeRoleFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeRoleFromSiteResourceBodySchema = z .object({ @@ -151,7 +151,7 @@ export async function removeRoleFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeUserFromSiteResource.ts b/server/routers/siteResource/removeUserFromSiteResource.ts index 8a90b752..280a01f2 100644 --- a/server/routers/siteResource/removeUserFromSiteResource.ts +++ b/server/routers/siteResource/removeUserFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeUserFromSiteResourceBodySchema = z .object({ @@ -121,7 +121,7 @@ export async function removeUserFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceClients.ts b/server/routers/siteResource/setSiteResourceClients.ts index 974b27cc..0a25b7e9 100644 --- a/server/routers/siteResource/setSiteResourceClients.ts +++ b/server/routers/siteResource/setSiteResourceClients.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceClientsBodySchema = z .object({ @@ -124,7 +124,7 @@ export async function setSiteResourceClients( .values(clientIds.map((clientId) => ({ clientId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceRoles.ts b/server/routers/siteResource/setSiteResourceRoles.ts index df44e02b..7aa07de1 100644 --- a/server/routers/siteResource/setSiteResourceRoles.ts +++ b/server/routers/siteResource/setSiteResourceRoles.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and, ne, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceRolesBodySchema = z .object({ @@ -147,7 +147,7 @@ export async function setSiteResourceRoles( .values(roleIds.map((roleId) => ({ roleId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceUsers.ts b/server/routers/siteResource/setSiteResourceUsers.ts index 8ef9a0ab..4dae0ada 100644 --- a/server/routers/siteResource/setSiteResourceUsers.ts +++ b/server/routers/siteResource/setSiteResourceUsers.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceUsersBodySchema = z .object({ @@ -102,7 +102,7 @@ export async function setSiteResourceUsers( .values(userIds.map((userId) => ({ userId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index d66d2cb8..51a18af9 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -25,7 +25,7 @@ import { } from "@server/lib/ip"; import { getClientSiteResourceAccess, - rebuildClientAssociations + rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const updateSiteResourceParamsSchema = z.strictObject({ @@ -224,7 +224,7 @@ export async function updateSiteResource( ); } - const { mergedAllClients } = await rebuildClientAssociations( + const { mergedAllClients } = await rebuildClientAssociationsFromSiteResource( existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below trx ); diff --git a/server/routers/user/addUserRole.ts b/server/routers/user/addUserRole.ts index 915ea64a..9404d94f 100644 --- a/server/routers/user/addUserRole.ts +++ b/server/routers/user/addUserRole.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { clients, db, UserOrg } from "@server/db"; import { userOrgs, roles } from "@server/db"; import { eq, and } from "drizzle-orm"; import response from "@server/lib/response"; @@ -10,11 +10,12 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import stoi from "@server/lib/stoi"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const addUserRoleParamsSchema = z.strictObject({ - userId: z.string(), - roleId: z.string().transform(stoi).pipe(z.number()) - }); + userId: z.string(), + roleId: z.string().transform(stoi).pipe(z.number()) +}); export type AddUserRoleResponse = z.infer; @@ -72,7 +73,9 @@ export async function addUserRole( const existingUser = await db .select() .from(userOrgs) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) + .where( + and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId)) + ) .limit(1); if (existingUser.length === 0) { @@ -108,14 +111,39 @@ export async function addUserRole( ); } - const newUserRole = await db - .update(userOrgs) - .set({ roleId }) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) - .returning(); + let newUserRole: UserOrg | null = null; + await db.transaction(async (trx) => { + [newUserRole] = await trx + .update(userOrgs) + .set({ roleId }) + .where( + and( + eq(userOrgs.userId, userId), + eq(userOrgs.orgId, role.orgId) + ) + ) + .returning(); + + // get the client associated with this user in this org + const [orgClient] = await trx + .select() + .from(clients) + .where( + and( + eq(clients.userId, userId), + eq(clients.orgId, role.orgId) + ) + ) + .limit(1); + + if (orgClient) { + // we just changed the user's role, so we need to rebuild client associations and what they have access to + await rebuildClientAssociationsFromClient(orgClient, trx); + } + }); return response(res, { - data: newUserRole[0], + data: newUserRole, success: true, error: false, message: "Role added to user successfully", diff --git a/server/routers/user/adminRemoveUser.ts b/server/routers/user/adminRemoveUser.ts index 02ad56d6..ae7f9f47 100644 --- a/server/routers/user/adminRemoveUser.ts +++ b/server/routers/user/adminRemoveUser.ts @@ -8,10 +8,11 @@ import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const removeUserSchema = z.strictObject({ - userId: z.string() - }); + userId: z.string() +}); export async function adminRemoveUser( req: Request, @@ -50,7 +51,11 @@ export async function adminRemoveUser( ); } - await db.delete(users).where(eq(users.userId, userId)); + await db.transaction(async (trx) => { + await trx.delete(users).where(eq(users.userId, userId)); + + await calculateUserClientsForOrgs(userId, trx); + }); return response(res, { data: null,