diff --git a/messages/en-US.json b/messages/en-US.json index f69d11a8..1e87480c 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -2148,7 +2148,7 @@ "deviceOrganizationsAccess": "Access to all organizations your account has access to", "deviceAuthorize": "Authorize {applicationName}", "deviceConnected": "Device Connected!", - "deviceAuthorizedMessage": "Your device is authorized to access your account.", + "deviceAuthorizedMessage": "Device is authorized to access your account.", "pangolinCloud": "Pangolin Cloud", "viewDevices": "View Devices", "viewDevicesDescription": "Manage your connected devices", @@ -2210,5 +2210,5 @@ "enterIdentifier": "Enter identifier", "identifier": "Identifier", "deviceLoginUseDifferentAccount": "Not you? Use a different account.", - "deviceLoginDeviceRequestingAccessToAccount": "Your device is requesting access to this account." + "deviceLoginDeviceRequestingAccessToAccount": "A device is requesting access to this account." } diff --git a/server/auth/sessions/app.ts b/server/auth/sessions/app.ts index 0e3da100..73b220fa 100644 --- a/server/auth/sessions/app.ts +++ b/server/auth/sessions/app.ts @@ -36,13 +36,15 @@ export async function createSession( const sessionId = encodeHexLowerCase( sha256(new TextEncoder().encode(token)) ); - const session: Session = { - sessionId: sessionId, - userId, - expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), - issuedAt: new Date().getTime() - }; - await db.insert(sessions).values(session); + const [session] = await db + .insert(sessions) + .values({ + sessionId: sessionId, + userId, + expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), + issuedAt: new Date().getTime() + }) + .returning(); return session; } diff --git a/server/auth/sessions/verifySession.ts b/server/auth/sessions/verifySession.ts index 68a1f17e..01b32ef6 100644 --- a/server/auth/sessions/verifySession.ts +++ b/server/auth/sessions/verifySession.ts @@ -18,13 +18,19 @@ export async function verifySession(req: Request, forceLogin?: boolean) { user: null }; } + if (res.session.deviceAuthUsed) { + return { + session: null, + user: null + }; + } if (!res.session.issuedAt) { return { session: null, user: null }; } - const mins = 3 * 60 * 1000; + const mins = 5 * 60 * 1000; const now = new Date().getTime(); if (now - res.session.issuedAt > mins) { return { diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 8ab1b24c..32b1252f 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -11,6 +11,7 @@ import { } from "drizzle-orm/pg-core"; import { InferSelectModel } from "drizzle-orm"; import { randomUUID } from "crypto"; +import { alias } from "yargs"; export const domains = pgTable("domains", { domainId: varchar("domainId").primaryKey(), @@ -40,6 +41,7 @@ export const orgs = pgTable("orgs", { orgId: varchar("orgId").primaryKey(), name: varchar("name").notNull(), subnet: varchar("subnet"), + utilitySubnet: varchar("utilitySubnet"), // this is the subnet for utility addresses createdAt: text("createdAt"), requireTwoFactor: boolean("requireTwoFactor"), maxSessionLengthHours: integer("maxSessionLengthHours"), @@ -209,7 +211,8 @@ export const siteResources = pgTable("siteResources", { destinationPort: integer("destinationPort"), // only for port mode destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode enabled: boolean("enabled").notNull().default(true), - alias: varchar("alias") + alias: varchar("alias"), + aliasAddress: varchar("aliasAddress") }); export const clientSiteResources = pgTable("clientSiteResources", { @@ -284,7 +287,8 @@ export const sessions = pgTable("session", { .notNull() .references(() => users.userId, { onDelete: "cascade" }), expiresAt: bigint("expiresAt", { mode: "number" }).notNull(), - issuedAt: bigint("issuedAt", { mode: "number" }) + issuedAt: bigint("issuedAt", { mode: "number" }), + deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false) }); export const newtSessions = pgTable("newtSession", { @@ -661,7 +665,8 @@ export const clientSitesAssociationsCache = pgTable( .notNull(), siteId: integer("siteId").notNull(), isRelayed: boolean("isRelayed").notNull().default(false), - endpoint: varchar("endpoint") + endpoint: varchar("endpoint"), + publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index cfffdba7..8b42a461 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -1,6 +1,7 @@ import { randomUUID } from "crypto"; import { InferSelectModel } from "drizzle-orm"; import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core"; +import { no } from "zod/v4/locales"; export const domains = sqliteTable("domains", { domainId: text("domainId").primaryKey(), @@ -32,6 +33,7 @@ export const orgs = sqliteTable("orgs", { orgId: text("orgId").primaryKey(), name: text("name").notNull(), subnet: text("subnet"), + utilitySubnet: text("utilitySubnet"), // this is the subnet for utility addresses createdAt: text("createdAt"), requireTwoFactor: integer("requireTwoFactor", { mode: "boolean" }), maxSessionLengthHours: integer("maxSessionLengthHours"), // hours @@ -230,7 +232,8 @@ export const siteResources = sqliteTable("siteResources", { destinationPort: integer("destinationPort"), // only for port mode destination: text("destination").notNull(), // ip, cidr, hostname enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), - alias: text("alias") + alias: text("alias"), + aliasAddress: text("aliasAddress") }); export const clientSiteResources = sqliteTable("clientSiteResources", { @@ -370,7 +373,8 @@ export const clientSitesAssociationsCache = sqliteTable( isRelayed: integer("isRelayed", { mode: "boolean" }) .notNull() .default(false), - endpoint: text("endpoint") + endpoint: text("endpoint"), + publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } ); @@ -413,7 +417,10 @@ export const sessions = sqliteTable("session", { .notNull() .references(() => users.userId, { onDelete: "cascade" }), expiresAt: integer("expiresAt").notNull(), - issuedAt: integer("issuedAt") + issuedAt: integer("issuedAt"), + deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" }) + .notNull() + .default(false) }); export const newtSessions = sqliteTable("newtSession", { diff --git a/server/lib/calculateUserClientsForOrgs.ts b/server/lib/calculateUserClientsForOrgs.ts index f66e3888..4cde8657 100644 --- a/server/lib/calculateUserClientsForOrgs.ts +++ b/server/lib/calculateUserClientsForOrgs.ts @@ -1,8 +1,20 @@ -import { clients, clientSitesAssociationsCache, db, olms, orgs, roleClients, roles, userClients, userOrgs, Transaction } from "@server/db"; +import { + clients, + db, + olms, + orgs, + roleClients, + roles, + userClients, + userOrgs, + Transaction +} from "@server/db"; import { eq, and, notInArray } from "drizzle-orm"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { getNextAvailableClientSubnet } from "@server/lib/ip"; import logger from "@server/logger"; +import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations"; +import { sendTerminateClient } from "@server/routers/client/terminate"; export async function calculateUserClientsForOrgs( userId: string, @@ -88,7 +100,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(roleClients.roleId, adminRole.roleId), - eq(roleClients.clientId, existingClient.clientId) + eq( + roleClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -110,7 +125,10 @@ export async function calculateUserClientsForOrgs( .where( and( eq(userClients.userId, userId), - eq(userClients.clientId, existingClient.clientId) + eq( + userClients.clientId, + existingClient.clientId + ) ) ) .limit(1); @@ -172,6 +190,11 @@ export async function calculateUserClientsForOrgs( }) .returning(); + await rebuildClientAssociationsFromClient( + newClient, + transaction + ); + // Grant admin role access to the client await transaction.insert(roleClients).values({ roleId: adminRole.roleId, @@ -225,15 +248,8 @@ async function cleanupOrphanedClients( : and(eq(clients.userId, userId)) ); - // Delete client-site associations first, then delete the clients - for (const client of clientsToDelete) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - if (clientsToDelete.length > 0) { - await trx + const deletedClients = await trx .delete(clients) .where( userOrgIds.length > 0 @@ -242,7 +258,20 @@ async function cleanupOrphanedClients( notInArray(clients.orgId, userOrgIds) ) : and(eq(clients.userId, userId)) - ); + ) + .returning(); + + // Rebuild associations for each deleted client to clean up related data + for (const deletedClient of deletedClients) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (deletedClient.olmId) { + await sendTerminateClient( + deletedClient.clientId, + deletedClient.olmId + ); + } + } if (userOrgIds.length === 0) { logger.debug( @@ -255,4 +284,3 @@ async function cleanupOrphanedClients( } } } - diff --git a/server/lib/createUserAccountOrg.ts b/server/lib/createUserAccountOrg.ts index 1406b935..11f4e247 100644 --- a/server/lib/createUserAccountOrg.ts +++ b/server/lib/createUserAccountOrg.ts @@ -18,6 +18,7 @@ import { defaultRoleAllowedActions } from "@server/routers/role"; import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing"; import { createCustomer } from "#dynamic/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; +import config from "@server/lib/config"; export async function createUserAccountOrg( userId: string, @@ -76,6 +77,8 @@ export async function createUserAccountOrg( .from(domains) .where(eq(domains.configManaged, true)); + const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group; + const newOrg = await trx .insert(orgs) .values({ @@ -83,6 +86,7 @@ export async function createUserAccountOrg( name, // subnet subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs? + utilitySubnet: utilitySubnet, createdAt: new Date().toISOString() }) .returning(); diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 8f09d86e..94b88800 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -1,4 +1,10 @@ -import { clientSitesAssociationsCache, db, SiteResource, Transaction } from "@server/db"; +import { + clientSitesAssociationsCache, + db, + SiteResource, + siteResources, + Transaction +} from "@server/db"; import { clients, orgs, sites } from "@server/db"; import { and, eq, isNotNull } from "drizzle-orm"; import config from "@server/lib/config"; @@ -281,6 +287,56 @@ export async function getNextAvailableClientSubnet( return subnet; } +export async function getNextAvailableAliasAddress( + orgId: string +): Promise { + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); + + if (!org) { + throw new Error(`Organization with ID ${orgId} not found`); + } + + if (!org.subnet) { + throw new Error(`Organization with ID ${orgId} has no subnet defined`); + } + + if (!org.utilitySubnet) { + throw new Error( + `Organization with ID ${orgId} has no utility subnet defined` + ); + } + + const existingAddresses = await db + .select({ + aliasAddress: siteResources.aliasAddress + }) + .from(siteResources) + .where( + and( + isNotNull(siteResources.aliasAddress), + eq(siteResources.orgId, orgId) + ) + ); + + const addresses = [ + ...existingAddresses.map( + (site) => `${site.aliasAddress?.split("/")[0]}/32` + ), + // reserve a /29 for the dns server and other stuff + `${org.utilitySubnet.split("/")[0]}/29` + ].filter((address) => address !== null) as string[]; + + let subnet = findNextAvailableCidr(addresses, 32, org.utilitySubnet); + if (!subnet) { + throw new Error("No available subnets remaining in space"); + } + + // remove the cidr + subnet = subnet.split("/")[0]; + + return subnet; +} + export async function getNextAvailableOrgSubnet(): Promise { const existingAddresses = await db .select({ @@ -327,9 +383,22 @@ export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[ return Array.from(new Set(remoteSubnets)); } +export type Alias = { alias: string | null; aliasAddress: string | null }; + +export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] { + let aliasConfigs = allSiteResources + .filter((sr) => sr.alias && sr.aliasAddress && sr.mode == "host") + .map((sr) => ({ + alias: sr.alias, + aliasAddress: sr.aliasAddress + })); + return aliasConfigs; +} + export type SubnetProxyTarget = { - sourcePrefix: string; - destPrefix: string; + sourcePrefix: string; // must be a cidr + destPrefix: string; // must be a cidr + rewriteTo?: string; // must be a cidr portRange?: { min: number; max: number; @@ -372,6 +441,15 @@ export function generateSubnetProxyTargets( destPrefix: `${siteResource.destination}/32` }); } + + if (siteResource.alias && siteResource.aliasAddress) { + // also push a match for the alias address + targets.push({ + sourcePrefix: clientPrefix, + destPrefix: `${siteResource.aliasAddress}/32`, + rewriteTo: `${siteResource.destination}/32` + }); + } } else if (siteResource.mode == "cidr") { targets.push({ sourcePrefix: clientPrefix, @@ -386,4 +464,4 @@ export function generateSubnetProxyTargets( ); return targets; -} \ No newline at end of file +} diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index 2da8c0a7..fe0dd593 100644 --- a/server/lib/readConfigFile.ts +++ b/server/lib/readConfigFile.ts @@ -229,6 +229,11 @@ export const configSchema = z .default(51820) .transform(stoi) .pipe(portSchema), + clients_start_port: portSchema + .optional() + .default(21820) + .transform(stoi) + .pipe(portSchema), base_endpoint: z .string() .optional() @@ -249,12 +254,14 @@ export const configSchema = z orgs: z .object({ block_size: z.number().positive().gt(0).optional().default(24), - subnet_group: z.string().optional().default("100.90.128.0/24") + subnet_group: z.string().optional().default("100.90.128.0/24"), + utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well }) .optional() .default({ block_size: 24, - subnet_group: "100.90.128.0/24" + subnet_group: "100.90.128.0/24", + utility_subnet_group: "100.96.128.0/24" }), rate_limits: z .object({ diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 483d3a99..a9b5985a 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -25,20 +25,22 @@ import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; import { + initPeerAddHandshake as holepunchSiteAdd, addPeer as olmAddPeer, deletePeer as olmDeletePeer } from "@server/routers/olm/peers"; import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; import { + generateAliasConfig, generateRemoteSubnets, generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip"; import { - addRemoteSubnets, + addPeerData, addTargets as addSubnetProxyTargets, - removeRemoteSubnets, + removePeerData, removeTargets as removeSubnetProxyTargets } from "@server/routers/client/targets"; @@ -128,7 +130,7 @@ export async function getClientSiteResourceAccess( }; } -export async function rebuildClientAssociations( +export async function rebuildClientAssociationsFromSiteResource( siteResource: SiteResource, trx: Transaction | typeof db = db ): Promise<{ @@ -463,65 +465,17 @@ async function handleMessagesForSiteClients( } if (isAdd) { - // TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES - // BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS - // AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES - const isRelayed = true; - - newtJobs.push( - newtAddPeer( + await holepunchSiteAdd( + // this will kick off the add peer process for the client + client.clientId, + { siteId, - { - publicKey: client.pubKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - // endpoint: isRelayed ? "" : clientSite.endpoint - endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint - }, - newt.newtId - ) - ); - - // TODO: should we have this here? - const allSiteResources = await db // only get the site resources that this client has access to - .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - siteResources.siteResourceId, - clientSiteResourcesAssociationsCache.siteResourceId - ) - ) - .where( - and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) - ) - ); - - olmJobs.push( - olmAddPeer( - client.clientId, - { - siteId: site.siteId, - endpoint: - isRelayed || !site.endpoint - ? `${exitNode.endpoint}:21820` - : site.endpoint, - publicKey: site.publicKey, - serverIP: site.address, - serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets( - allSiteResources.map( - ({ siteResources }) => siteResources - ) - ) - }, - olm.olmId - ) + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } + }, + olm.olmId ); } @@ -703,10 +657,11 @@ async function handleSubnetProxyTargetUpdates( for (const client of addedClients) { olmJobs.push( - addRemoteSubnets( + addPeerData( client.clientId, siteResource.siteId, - generateRemoteSubnets([siteResource]) + generateRemoteSubnets([siteResource]), + generateAliasConfig([siteResource]) ) ); } @@ -738,10 +693,11 @@ async function handleSubnetProxyTargetUpdates( for (const client of removedClients) { olmJobs.push( - removeRemoteSubnets( + removePeerData( client.clientId, siteResource.siteId, - generateRemoteSubnets([siteResource]) + generateRemoteSubnets([siteResource]), + generateAliasConfig([siteResource]) ) ); } @@ -750,3 +706,511 @@ async function handleSubnetProxyTargetUpdates( await Promise.all(proxyJobs); } + +export async function rebuildClientAssociationsFromClient( + client: Client, + trx: Transaction | typeof db = db +): Promise { + let newSiteResourceIds: number[] = []; + + // 1. Direct client associations + const directSiteResources = await trx + .select({ siteResourceId: clientSiteResources.siteResourceId }) + .from(clientSiteResources) + .where(eq(clientSiteResources.clientId, client.clientId)); + + newSiteResourceIds.push( + ...directSiteResources.map((r) => r.siteResourceId) + ); + + // 2. User-based and role-based access (if client has a userId) + if (client.userId) { + // Direct user associations + const userSiteResourceIds = await trx + .select({ siteResourceId: userSiteResources.siteResourceId }) + .from(userSiteResources) + .innerJoin( + siteResources, + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ) + ) + .where( + and( + eq(userSiteResources.userId, client.userId), + eq(siteResources.orgId, client.orgId) + ) + ); // this needs to be locked onto this org or else cross-org access could happen + + newSiteResourceIds.push( + ...userSiteResourceIds.map((r) => r.siteResourceId) + ); + + // Role-based access + const roleIds = await trx + .select({ roleId: userOrgs.roleId }) + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, client.userId), + eq(userOrgs.orgId, client.orgId) + ) + ) // this needs to be locked onto this org or else cross-org access could happen + .then((rows) => rows.map((row) => row.roleId)); + + if (roleIds.length > 0) { + const roleSiteResourceIds = await trx + .select({ siteResourceId: roleSiteResources.siteResourceId }) + .from(roleSiteResources) + .where(inArray(roleSiteResources.roleId, roleIds)); + + newSiteResourceIds.push( + ...roleSiteResourceIds.map((r) => r.siteResourceId) + ); + } + } + + // Remove duplicates + newSiteResourceIds = Array.from(new Set(newSiteResourceIds)); + + // Get full siteResource details + const newSiteResources = + newSiteResourceIds.length > 0 + ? await trx + .select() + .from(siteResources) + .where( + inArray(siteResources.siteResourceId, newSiteResourceIds) + ) + : []; + + // Group by siteId for site-level associations + const newSiteIds = Array.from( + new Set(newSiteResources.map((sr) => sr.siteId)) + ); + + /////////// Process client-siteResource associations /////////// + + // Get existing resource associations + const existingResourceAssociations = await trx + .select({ + siteResourceId: clientSiteResourcesAssociationsCache.siteResourceId + }) + .from(clientSiteResourcesAssociationsCache) + .where( + eq(clientSiteResourcesAssociationsCache.clientId, client.clientId) + ); + + const existingSiteResourceIds = existingResourceAssociations.map( + (r) => r.siteResourceId + ); + + const resourcesToAdd = newSiteResourceIds.filter( + (id) => !existingSiteResourceIds.includes(id) + ); + + const resourcesToRemove = existingSiteResourceIds.filter( + (id) => !newSiteResourceIds.includes(id) + ); + + // Insert new associations + if (resourcesToAdd.length > 0) { + await trx.insert(clientSiteResourcesAssociationsCache).values( + resourcesToAdd.map((siteResourceId) => ({ + clientId: client.clientId, + siteResourceId + })) + ); + } + + // Remove old associations + if (resourcesToRemove.length > 0) { + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + inArray( + clientSiteResourcesAssociationsCache.siteResourceId, + resourcesToRemove + ) + ) + ); + } + + /////////// Process client-site associations /////////// + + // Get existing site associations + const existingSiteAssociations = await trx + .select({ siteId: clientSitesAssociationsCache.siteId }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + const existingSiteIds = existingSiteAssociations.map((s) => s.siteId); + + const sitesToAdd = newSiteIds.filter((id) => !existingSiteIds.includes(id)); + const sitesToRemove = existingSiteIds.filter( + (id) => !newSiteIds.includes(id) + ); + + // Insert new site associations + if (sitesToAdd.length > 0) { + await trx.insert(clientSitesAssociationsCache).values( + sitesToAdd.map((siteId) => ({ + clientId: client.clientId, + siteId + })) + ); + } + + // Remove old site associations + if (sitesToRemove.length > 0) { + await trx + .delete(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + inArray(clientSitesAssociationsCache.siteId, sitesToRemove) + ) + ); + } + + /////////// Send messages /////////// + + // Get the olm for this client + const [olm] = await trx + .select({ olmId: olms.olmId }) + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (!olm) { + logger.warn( + `Olm not found for client ${client.clientId}, skipping peer updates` + ); + return; + } + + // Handle messages for sites being added + await handleMessagesForClientSites( + client, + olm.olmId, + sitesToAdd, + sitesToRemove, + trx + ); + + // Handle subnet proxy target updates for resources + await handleMessagesForClientResources( + client, + newSiteResources, + resourcesToAdd, + resourcesToRemove, + trx + ); +} + +async function handleMessagesForClientSites( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + olmId: string, + sitesToAdd: number[], + sitesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + if (!client.subnet || !client.pubKey) { + logger.warn( + `Client ${client.clientId} missing subnet or pubKey, skipping peer updates` + ); + return; + } + + const allSiteIds = [...sitesToAdd, ...sitesToRemove]; + if (allSiteIds.length === 0) { + return; + } + + // Get site details for all affected sites + const sitesData = await trx + .select() + .from(sites) + .leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId)) + .leftJoin(newts, eq(sites.siteId, newts.siteId)) + .where(inArray(sites.siteId, allSiteIds)); + + let newtJobs: Promise[] = []; + let olmJobs: Promise[] = []; + let exitNodeJobs: Promise[] = []; + + for (const siteData of sitesData) { + const site = siteData.sites; + const exitNode = siteData.exitNodes; + const newt = siteData.newt; + + if (!site.publicKey) { + logger.warn( + `Site ${site.siteId} missing publicKey, skipping peer updates` + ); + continue; + } + + if (!newt) { + logger.warn( + `Newt not found for site ${site.siteId}, skipping peer updates` + ); + continue; + } + + const isAdd = sitesToAdd.includes(site.siteId); + const isRemove = sitesToRemove.includes(site.siteId); + + if (isRemove) { + // Remove peer from newt + newtJobs.push( + newtDeletePeer(site.siteId, client.pubKey, newt.newtId) + ); + try { + // Remove peer from olm + olmJobs.push( + olmDeletePeer( + client.clientId, + site.siteId, + site.publicKey, + olmId + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId}, skipping removal` + ); + } else { + throw error; + } + } + } + + if (isAdd) { + if (!exitNode) { + logger.warn( + `Exit node not found for site ${site.siteId}, skipping peer add` + ); + continue; + } + + await holepunchSiteAdd( + // this will kick off the add peer process for the client + client.clientId, + { + siteId: site.siteId, + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } + }, + olmId + ); + } + + // Update exit node destinations + exitNodeJobs.push( + updateClientSiteDestinations( + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + }, + trx + ) + ); + } + + await Promise.all(exitNodeJobs); + await Promise.all(newtJobs); + await Promise.all(olmJobs); +} + +async function handleMessagesForClientResources( + client: { + clientId: number; + pubKey: string | null; + subnet: string | null; + userId: string | null; + orgId: string; + }, + allNewResources: SiteResource[], + resourcesToAdd: number[], + resourcesToRemove: number[], + trx: Transaction | typeof db = db +): Promise { + // Group resources by site + const resourcesBySite = new Map(); + + for (const resource of allNewResources) { + if (!resourcesBySite.has(resource.siteId)) { + resourcesBySite.set(resource.siteId, []); + } + resourcesBySite.get(resource.siteId)!.push(resource); + } + + let proxyJobs: Promise[] = []; + let olmJobs: Promise[] = []; + + // Handle additions + if (resourcesToAdd.length > 0) { + const addedResources = allNewResources.filter((r) => + resourcesToAdd.includes(r.siteResourceId) + ); + + // Group by site for proxy updates + const addedBySite = new Map(); + for (const resource of addedResources) { + if (!addedBySite.has(resource.siteId)) { + addedBySite.set(resource.siteId, []); + } + addedBySite.get(resource.siteId)!.push(resource); + } + + // Add subnet proxy targets for each site + for (const [siteId, resources] of addedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets)); + } + + try { + // Add peer data to olm + olmJobs.push( + addPeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + // Handle removals + if (resourcesToRemove.length > 0) { + const removedResources = await trx + .select() + .from(siteResources) + .where(inArray(siteResources.siteResourceId, resourcesToRemove)); + + // Group by site for proxy updates + const removedBySite = new Map(); + for (const resource of removedResources) { + if (!removedBySite.has(resource.siteId)) { + removedBySite.set(resource.siteId, []); + } + removedBySite.get(resource.siteId)!.push(resource); + } + + // Remove subnet proxy targets for each site + for (const [siteId, resources] of removedBySite.entries()) { + const [newt] = await trx + .select({ newtId: newts.newtId }) + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping proxy updates` + ); + continue; + } + + for (const resource of resources) { + const targets = generateSubnetProxyTargets(resource, [ + { + clientId: client.clientId, + pubKey: client.pubKey, + subnet: client.subnet + } + ]); + + if (targets.length > 0) { + proxyJobs.push( + removeSubnetProxyTargets(newt.newtId, targets) + ); + } + + try { + // Remove peer data from olm + olmJobs.push( + removePeerData( + client.clientId, + resource.siteId, + generateRemoteSubnets([resource]), + generateAliasConfig([resource]) + ) + ); + } catch (error) { + // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send + if ( + error instanceof Error && + error.message.includes("not found") + ) { + logger.debug( + `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + ); + } else { + throw error; + } + } + } + } + } + + await Promise.all([...proxyJobs, ...olmJobs]); +} diff --git a/server/private/routers/hybrid.ts b/server/private/routers/hybrid.ts index 4fc3f97e..24416c5b 100644 --- a/server/private/routers/hybrid.ts +++ b/server/private/routers/hybrid.ts @@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({ port: z.number(), timestamp: z.number(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + publicKey: z.string() // this is the client public key }); hybridRouter.post( "/gerbil/update-hole-punch", @@ -1408,7 +1408,7 @@ hybridRouter.post( ); } - const { olmId, newtId, ip, port, timestamp, token, reachableAt } = + const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } = parsedParams.data; const destinations = await updateAndGenerateEndpointDestinations( @@ -1418,6 +1418,7 @@ hybridRouter.post( port, timestamp, token, + publicKey, exitNode, true ); diff --git a/server/routers/auth/securityKey.ts b/server/routers/auth/securityKey.ts index cde2f61a..eed2328d 100644 --- a/server/routers/auth/securityKey.ts +++ b/server/routers/auth/securityKey.ts @@ -52,7 +52,7 @@ setInterval(async () => { await db .delete(webauthnChallenge) .where(lt(webauthnChallenge.expiresAt, now)); - logger.debug("Cleaned up expired security key challenges"); + // logger.debug("Cleaned up expired security key challenges"); } catch (error) { logger.error("Failed to clean up expired security key challenges", error); } diff --git a/server/routers/auth/startDeviceWebAuth.ts b/server/routers/auth/startDeviceWebAuth.ts index 8897e73f..925df67f 100644 --- a/server/routers/auth/startDeviceWebAuth.ts +++ b/server/routers/auth/startDeviceWebAuth.ts @@ -22,7 +22,7 @@ export type StartDeviceWebAuthBody = z.infer; export type StartDeviceWebAuthResponse = { code: string; - expiresAt: number; + expiresInSeconds: number; }; // Helper function to generate device code in format A1AJ-N5JD @@ -131,10 +131,13 @@ export async function startDeviceWebAuth( createdAt: Date.now() }); + // calculate relative expiration in seconds + const expiresInSeconds = Math.floor((expiresAt - Date.now()) / 1000); + return response(res, { data: { code, - expiresAt + expiresInSeconds }, success: true, error: false, diff --git a/server/routers/auth/verifyDeviceWebAuth.ts b/server/routers/auth/verifyDeviceWebAuth.ts index 715b299a..be0e0ff2 100644 --- a/server/routers/auth/verifyDeviceWebAuth.ts +++ b/server/routers/auth/verifyDeviceWebAuth.ts @@ -5,7 +5,7 @@ import { fromError } from "zod-validation-error"; import HttpCode from "@server/types/HttpCode"; import logger from "@server/logger"; import { response } from "@server/lib/response"; -import { db, deviceWebAuthCodes } from "@server/db"; +import { db, deviceWebAuthCodes, sessions } from "@server/db"; import { eq, and, gt } from "drizzle-orm"; import { encodeHexLowerCase } from "@oslojs/encoding"; import { sha256 } from "@oslojs/crypto/sha2"; @@ -44,20 +44,36 @@ export async function verifyDeviceWebAuth( ): Promise { const { user, session } = req; if (!user || !session) { - logger.debug("Unauthorized attempt to verify device web auth code"); - return next(unauthorized()); + return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")); + } + + if (session.deviceAuthUsed) { + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Device web auth code already used for this session" + ) + ); } if (!session.issuedAt) { - logger.debug("Session missing issuedAt timestamp"); - return next(unauthorized()); + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Session issuedAt timestamp missing" + ) + ); } // make sure sessions is not older than 5 minutes const now = Date.now(); - if (now - session.issuedAt > 3 * 60 * 1000) { - logger.debug("Session is too old to verify device web auth code"); - return next(unauthorized()); + if (now - session.issuedAt > 5 * 60 * 1000) { + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "Session is too old to verify device web auth code" + ) + ); } const parsedBody = bodySchema.safeParse(req.body); @@ -134,6 +150,14 @@ export async function verifyDeviceWebAuth( }) .where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId)); + // Also update the session to mark that device auth was used + await db + .update(sessions) + .set({ + deviceAuthUsed: true + }) + .where(eq(sessions.sessionId, session.sessionId)); + return response(res, { data: { success: true, diff --git a/server/routers/client/createClient.ts b/server/routers/client/createClient.ts index 908ea689..160006e1 100644 --- a/server/routers/client/createClient.ts +++ b/server/routers/client/createClient.ts @@ -24,18 +24,19 @@ import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { generateId } from "@server/auth/sessions/app"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const createClientParamsSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); const createClientSchema = z.strictObject({ - name: z.string().min(1).max(255), - olmId: z.string(), - secret: z.string(), - subnet: z.string(), - type: z.enum(["olm"]) - }); + name: z.string().min(1).max(255), + olmId: z.string(), + secret: z.string(), + subnet: z.string(), + type: z.enum(["olm"]) +}); export type CreateClientBody = z.infer; @@ -186,6 +187,7 @@ export async function createClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -204,7 +206,7 @@ export async function createClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -244,13 +246,15 @@ export async function createClient( dateCreated: moment().toISOString() }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/createUserClient.ts b/server/routers/client/createUserClient.ts index f49a0783..e5b5ea8f 100644 --- a/server/routers/client/createUserClient.ts +++ b/server/routers/client/createUserClient.ts @@ -21,6 +21,7 @@ import { isValidIP } from "@server/lib/validators"; import { isIpInCidr } from "@server/lib/ip"; import { listExitNodes } from "#dynamic/lib/exitNodes"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const paramsSchema = z .object({ @@ -191,6 +192,7 @@ export async function createUserClient( ); } + let newClient: Client | null = null; await db.transaction(async (trx) => { // TODO: more intelligent way to pick the exit node const exitNodesList = await listExitNodes(orgId); @@ -209,7 +211,7 @@ export async function createUserClient( ); } - const [newClient] = await trx + [newClient] = await trx .insert(clients) .values({ exitNodeId: randomExitNode.exitNodeId, @@ -232,13 +234,15 @@ export async function createUserClient( clientId: newClient.clientId }); - return response(res, { - data: newClient, - success: true, - error: false, - message: "Site created successfully", - status: HttpCode.CREATED - }); + await rebuildClientAssociationsFromClient(newClient, trx); + }); + + return response(res, { + data: newClient, + success: true, + error: false, + message: "Site created successfully", + status: HttpCode.CREATED }); } catch (error) { logger.error(error); diff --git a/server/routers/client/deleteClient.ts b/server/routers/client/deleteClient.ts index 34019a53..775708ce 100644 --- a/server/routers/client/deleteClient.ts +++ b/server/routers/client/deleteClient.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, olms } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import response from "@server/lib/response"; @@ -9,10 +9,12 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "./terminate"; const deleteClientSchema = z.strictObject({ - clientId: z.string().transform(Number).pipe(z.int().positive()) - }); + clientId: z.string().transform(Number).pipe(z.int().positive()) +}); registry.registerPath({ method: "delete", @@ -68,19 +70,27 @@ export async function deleteClient( } await db.transaction(async (trx) => { - // Delete the client-site associations first - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, clientId)); - // Then delete the client itself - await trx.delete(clients).where(eq(clients.clientId, clientId)); + const [deletedClient] = await trx + .delete(clients) + .where(eq(clients.clientId, clientId)) + .returning(); - // this is a machine client + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + + // this is a machine client so we also delete the olm if (!client.userId && client.olmId) { - await trx - .delete(clients) - .where(eq(clients.olmId, client.olmId)); + await trx.delete(olms).where(eq(olms.olmId, client.olmId)); + } + + await rebuildClientAssociationsFromClient(deletedClient, trx); + + if (olm) { + await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion } }); diff --git a/server/routers/client/targets.ts b/server/routers/client/targets.ts index c94cb680..b5684436 100644 --- a/server/routers/client/targets.ts +++ b/server/routers/client/targets.ts @@ -1,6 +1,6 @@ import { sendToClient } from "#dynamic/routers/ws"; import { db, olms } from "@server/db"; -import { SubnetProxyTarget } from "@server/lib/ip"; +import { Alias, SubnetProxyTarget } from "@server/lib/ip"; import { eq } from "drizzle-orm"; export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) { @@ -33,10 +33,11 @@ export async function updateTargets( }); } -export async function addRemoteSubnets( +export async function addPeerData( clientId: number, siteId: number, remoteSubnets: string[], + aliases: Alias[], olmId?: string ) { if (!olmId) { @@ -52,18 +53,20 @@ export async function addRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/add-remote-subnets`, + type: `olm/wg/peer/data/add`, data: { siteId: siteId, - remoteSubnets: remoteSubnets + remoteSubnets: remoteSubnets, + aliases: aliases } }); } -export async function removeRemoteSubnets( +export async function removePeerData( clientId: number, siteId: number, remoteSubnets: string[], + aliases: Alias[], olmId?: string ) { if (!olmId) { @@ -79,21 +82,26 @@ export async function removeRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/remove-remote-subnets`, + type: `olm/wg/peer/data/remove`, data: { siteId: siteId, - remoteSubnets: remoteSubnets + remoteSubnets: remoteSubnets, + aliases: aliases } }); } -export async function updateRemoteSubnets( +export async function updatePeerData( clientId: number, siteId: number, remoteSubnets: { oldRemoteSubnets: string[], newRemoteSubnets: string[] }, + aliases: { + oldAliases: Alias[], + newAliases: Alias[] + }, olmId?: string ) { if (!olmId) { @@ -109,10 +117,11 @@ export async function updateRemoteSubnets( } await sendToClient(olmId, { - type: `olm/wg/peer/update-remote-subnets`, + type: `olm/wg/peer/data/update`, data: { siteId: siteId, - ...remoteSubnets + ...remoteSubnets, + ...aliases } }); } diff --git a/server/routers/client/terminate.ts b/server/routers/client/terminate.ts new file mode 100644 index 00000000..dc49ef05 --- /dev/null +++ b/server/routers/client/terminate.ts @@ -0,0 +1,22 @@ +import { sendToClient } from "#dynamic/routers/ws"; +import { db, olms } from "@server/db"; +import { eq } from "drizzle-orm"; + +export async function sendTerminateClient(clientId: number, olmId?: string | null) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: `olm/terminate`, + data: {} + }); +} diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index 031cd23e..e1fa7c4c 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -19,6 +19,8 @@ import { fromError } from "zod-validation-error"; import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateOlmSessionToken } from "@server/auth/sessions/olm"; import { checkExitNodeOrg } from "#dynamic/lib/exitNodes"; +import { updatePeer as updateOlmPeer } from "../olm/peers"; +import { updatePeer as updateNewtPeer } from "../newt/peers"; // Define Zod schema for request validation const updateHolePunchSchema = z.object({ @@ -28,8 +30,9 @@ const updateHolePunchSchema = z.object({ ip: z.string(), port: z.number(), timestamp: z.number(), + publicKey: z.string(), reachableAt: z.string().optional(), - publicKey: z.string().optional() + exitNodePublicKey: z.string().optional() }); // New response type with multi-peer destination support @@ -63,23 +66,26 @@ export async function updateHolePunch( timestamp, token, reachableAt, - publicKey + publicKey, // this is the client's current public key for this session + exitNodePublicKey } = parsedParams.data; let exitNode: ExitNode | undefined; - if (publicKey) { + if (exitNodePublicKey) { // Get the exit node by public key [exitNode] = await db .select() .from(exitNodes) - .where(eq(exitNodes.publicKey, publicKey)); + .where(eq(exitNodes.publicKey, exitNodePublicKey)); } else { // FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0 [exitNode] = await db.select().from(exitNodes).limit(1); } if (!exitNode) { - logger.warn(`Exit node not found for publicKey: ${publicKey}`); + logger.warn( + `Exit node not found for publicKey: ${exitNodePublicKey}` + ); return next( createHttpError(HttpCode.NOT_FOUND, "Exit node not found") ); @@ -92,12 +98,13 @@ export async function updateHolePunch( port, timestamp, token, + publicKey, exitNode ); - logger.debug( - `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` - ); + // logger.debug( + // `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` + // ); // Return the new multi-peer structure return res.status(HttpCode.OK).send({ @@ -121,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations( port: number, timestamp: number, token: string, + publicKey: string, exitNode: ExitNode, checkOrg = false ) { @@ -128,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations( const destinations: PeerDestination[] = []; if (olmId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` + // ); const { session, olm: olmSession } = await validateOlmSessionToken(token); @@ -150,7 +158,7 @@ export async function updateAndGenerateEndpointDestinations( throw new Error("Olm not found"); } - const [client] = await db + const [updatedClient] = await db .update(clients) .set({ lastHolePunch: timestamp @@ -158,10 +166,16 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(clients.clientId, olm.clientId)) .returning(); - if (await checkExitNodeOrg(exitNode.exitNodeId, client.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg( + exitNode.exitNodeId, + updatedClient.orgId + )) && + checkOrg + ) { // not allowed logger.warn( - `Exit node ${exitNode.exitNodeId} is not allowed for org ${client.orgId}` + `Exit node ${exitNode.exitNodeId} is not allowed for org ${updatedClient.orgId}` ); throw new Error("Exit node not allowed"); } @@ -171,10 +185,15 @@ export async function updateAndGenerateEndpointDestinations( .select({ siteId: sites.siteId, subnet: sites.subnet, - listenPort: sites.listenPort + listenPort: sites.listenPort, + publicKey: sites.publicKey, + endpoint: clientSitesAssociationsCache.endpoint }) .from(sites) - .innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId)) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) .where( and( eq(sites.exitNodeId, exitNode.exitNodeId), @@ -184,27 +203,52 @@ export async function updateAndGenerateEndpointDestinations( // Update clientSites for each site on this exit node for (const site of sitesOnExitNode) { - logger.debug( - `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` - ); + // logger.debug( + // `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` + // ); - await db + // if the public key or endpoint has changed, update it otherwise continue + if ( + site.endpoint === `${ip}:${port}` && + site.publicKey === publicKey + ) { + continue; + } + + const [updatedClientSitesAssociationsCache] = await db .update(clientSitesAssociationsCache) .set({ - endpoint: `${ip}:${port}` + endpoint: `${ip}:${port}`, + publicKey: publicKey }) .where( and( eq(clientSitesAssociationsCache.clientId, olm.clientId), eq(clientSitesAssociationsCache.siteId, site.siteId) ) + ) + .returning(); + + if ( + updatedClientSitesAssociationsCache.endpoint !== + site.endpoint && // this is the endpoint from the join table not the site + updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key which means it has registered so we dont prematurely send the update + ) { + logger.info( + `ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}` ); + // Handle any additional logic for endpoint change + handleClientEndpointChange( + olm.clientId, + updatedClientSitesAssociationsCache.endpoint! + ); + } } - logger.debug( - `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` - ); - if (!client) { + // logger.debug( + // `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` + // ); + if (!updatedClient) { logger.warn(`Client not found for olm: ${olmId}`); throw new Error("Client not found"); } @@ -219,9 +263,9 @@ export async function updateAndGenerateEndpointDestinations( } } } else if (newtId) { - logger.debug( - `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` - ); + // logger.debug( + // `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` + // ); const { session, newt: newtSession } = await validateNewtSessionToken(token); @@ -253,7 +297,10 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .limit(1); - if (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId) && checkOrg) { + if ( + (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId)) && + checkOrg + ) { // not allowed logger.warn( `Exit node ${exitNode.exitNodeId} is not allowed for org ${site.orgId}` @@ -273,6 +320,18 @@ export async function updateAndGenerateEndpointDestinations( .where(eq(sites.siteId, newt.siteId)) .returning(); + if ( + updatedSite.endpoint != site.endpoint && + updatedSite.publicKey == publicKey + ) { + // only trigger if the site's public key matches the current public key which means it has registered so we dont prematurely send the update + logger.info( + `Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}` + ); + // Handle any additional logic for endpoint change + handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!); + } + if (!updatedSite || !updatedSite.subnet) { logger.warn(`Site not found: ${newt.siteId}`); throw new Error("Site not found"); @@ -326,3 +385,143 @@ export async function updateAndGenerateEndpointDestinations( } return destinations; } + +async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { + // Alert all clients connected to this site that the endpoint has changed (only if NOT relayed) + try { + // Get site details + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site || !site.publicKey) { + logger.warn(`Site ${siteId} not found or has no public key`); + return; + } + + // Get all non-relayed clients connected to this site + const connectedClients = await db + .select({ + clientId: clients.clientId, + olmId: olms.olmId, + isRelayed: clientSitesAssociationsCache.isRelayed + }) + .from(clientSitesAssociationsCache) + .innerJoin( + clients, + eq(clientSitesAssociationsCache.clientId, clients.clientId) + ) + .innerJoin(olms, eq(olms.clientId, clients.clientId)) + .where( + and( + eq(clientSitesAssociationsCache.siteId, siteId), + eq(clientSitesAssociationsCache.isRelayed, false) + ) + ); + + // Update each non-relayed client with the new site endpoint + for (const client of connectedClients) { + try { + await updateOlmPeer( + client.clientId, + { + siteId: siteId, + publicKey: site.publicKey, + endpoint: newEndpoint + }, + client.olmId + ); + logger.debug( + `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` + ); + } catch (error) { + logger.error( + `Failed to update client ${client.clientId} with new site endpoint: ${error}` + ); + } + } + } catch (error) { + logger.error( + `Error handling site endpoint change for site ${siteId}: ${error}` + ); + } +} + +async function handleClientEndpointChange( + clientId: number, + newEndpoint: string +) { + // Alert all sites connected to this client that the endpoint has changed (only if NOT relayed) + try { + // Get client details + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + + if (!client || !client.pubKey) { + logger.warn(`Client ${clientId} not found or has no public key`); + return; + } + + // Get all non-relayed sites connected to this client + const connectedSites = await db + .select({ + siteId: sites.siteId, + newtId: newts.newtId, + isRelayed: clientSitesAssociationsCache.isRelayed, + subnet: clients.subnet + }) + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(clientSitesAssociationsCache.siteId, sites.siteId) + ) + .innerJoin(newts, eq(newts.siteId, sites.siteId)) + .innerJoin( + clients, + eq(clientSitesAssociationsCache.clientId, clients.clientId) + ) + .where( + and( + eq(clientSitesAssociationsCache.clientId, clientId), + eq(clientSitesAssociationsCache.isRelayed, false) + ) + ); + + // Update each non-relayed site with the new client endpoint + for (const siteData of connectedSites) { + try { + if (!siteData.subnet) { + logger.warn( + `Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}` + ); + continue; + } + + await updateNewtPeer( + siteData.siteId, + client.pubKey, + { + endpoint: newEndpoint + }, + siteData.newtId + ); + logger.debug( + `Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}` + ); + } catch (error) { + logger.error( + `Failed to update site ${siteData.siteId} with new client endpoint: ${error}` + ); + } + } + } catch (error) { + logger.error( + `Error handling client endpoint change for client ${clientId}: ${error}` + ); + } +} diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index 35c65716..31ab9f6f 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -79,12 +79,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { // TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly) } - // if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) { - // logger.warn( - // `Site ${existingSite.siteId} last hole punch is too old, skipping` - // ); - // return; - // } + if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) { + logger.warn( + `Site ${existingSite.siteId} last hole punch is too old, skipping` + ); + return; + } // update the endpoint and the public key const [site] = await db @@ -275,6 +275,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { resource, resourceClients ); + targetsToSend.push(...resourceTargets); } diff --git a/server/routers/olm/deleteUserOlm.ts b/server/routers/olm/deleteUserOlm.ts index 88e791db..83a3d16f 100644 --- a/server/routers/olm/deleteUserOlm.ts +++ b/server/routers/olm/deleteUserOlm.ts @@ -1,5 +1,5 @@ import { NextFunction, Request, Response } from "express"; -import { db } from "@server/db"; +import { Client, db } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { eq } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; @@ -9,6 +9,8 @@ import { z } from "zod"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { sendTerminateClient } from "../client/terminate"; const paramsSchema = z .object({ @@ -54,20 +56,30 @@ export async function deleteUserOlm( .from(clients) .where(eq(clients.olmId, olmId)); - // Delete client-site associations for each associated client - for (const client of associatedClients) { - await trx - .delete(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - } - + let deletedClient: Client | null = null; // Delete all associated clients if (associatedClients.length > 0) { - await trx.delete(clients).where(eq(clients.olmId, olmId)); + [deletedClient] = await trx + .delete(clients) + .where(eq(clients.olmId, olmId)) + .returning(); } // Finally, delete the OLM itself - await trx.delete(olms).where(eq(olms.olmId, olmId)); + const [olm] = await trx + .delete(olms) + .where(eq(olms.olmId, olmId)) + .returning(); + + if (deletedClient) { + await rebuildClientAssociationsFromClient(deletedClient, trx); + if (olm) { + await sendTerminateClient( + deletedClient.clientId, + olm.olmId + ); // the olmId needs to be provided because it cant look it up after deletion + } + } }); return response(res, { diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index c26f5936..cea8386c 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,9 +1,16 @@ import { generateSessionToken } from "@server/auth/sessions/app"; -import { db } from "@server/db"; +import { + clients, + db, + ExitNode, + exitNodes, + sites, + clientSitesAssociationsCache +} from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -15,11 +22,13 @@ import { import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; export const olmGetTokenBodySchema = z.object({ olmId: z.string(), secret: z.string(), - token: z.string().optional() + token: z.string().optional(), + orgId: z.string().optional() }); export type OlmGetTokenBody = z.infer; @@ -40,7 +49,7 @@ export async function getOlmToken( ); } - const { olmId, secret, token } = parsedBody.data; + const { olmId, secret, token, orgId } = parsedBody.data; try { if (token) { @@ -61,11 +70,12 @@ export async function getOlmToken( } } - const existingOlmRes = await db + const [existingOlm] = await db .select() .from(olms) .where(eq(olms.olmId, olmId)); - if (!existingOlmRes || !existingOlmRes.length) { + + if (!existingOlm) { return next( createHttpError( HttpCode.BAD_REQUEST, @@ -74,12 +84,11 @@ export async function getOlmToken( ); } - const existingOlm = existingOlmRes[0]; - const validSecret = await verifyPassword( secret, existingOlm.secretHash ); + if (!validSecret) { if (config.getRawConfig().app.log_failed_attempts) { logger.info( @@ -96,11 +105,113 @@ export async function getOlmToken( const resToken = generateSessionToken(); await createOlmSession(resToken, existingOlm.olmId); + let orgIdToUse = orgId; + let clientIdToUse; + if (!orgIdToUse) { + if (!existingOlm.clientId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm is not associated with a client, orgId is required" + ) + ); + } + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, existingOlm.clientId)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Olm's associated client not found, orgId is required" + ) + ); + } + + orgIdToUse = client.orgId; + clientIdToUse = client.clientId; + } else { + // we did provide the org + const [client] = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgIdToUse)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No client found for provided orgId" + ) + ); + } + + if (existingOlm.clientId !== client.clientId) { + // we only need to do this if the client is changing + + logger.debug( + `Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}` + ); + + await db + .update(olms) + .set({ + clientId: client.clientId + }) + .where(eq(olms.olmId, existingOlm.olmId)); + } + + clientIdToUse = client.clientId; + } + + // Get all exit nodes from sites where the client has peers + const clientSites = await db + .select() + .from(clientSitesAssociationsCache) + .innerJoin( + sites, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); + + // Extract unique exit node IDs + const exitNodeIds = Array.from( + new Set( + clientSites + .map(({ sites: site }) => site.exitNodeId) + .filter((id): id is number => id !== null) + ) + ); + + let allExitNodes: ExitNode[] = []; + if (exitNodeIds.length > 0) { + allExitNodes = await db + .select() + .from(exitNodes) + .where(inArray(exitNodes.exitNodeId, exitNodeIds)); + } + + const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { + return { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + }; + }); + logger.debug("Token created successfully"); - return response<{ token: string }>(res, { + return response<{ + token: string; + exitNodes: { publicKey: string; endpoint: string }[]; + }>(res, { data: { - token: resToken + token: resToken, + exitNodes: exitNodesHpData }, success: true, error: false, diff --git a/server/routers/olm/getUserOlm.ts b/server/routers/olm/getUserOlm.ts index 50b32fd8..aa9b89af 100644 --- a/server/routers/olm/getUserOlm.ts +++ b/server/routers/olm/getUserOlm.ts @@ -1,6 +1,6 @@ import { NextFunction, Request, Response } from "express"; import { db } from "@server/db"; -import { olms, clients, clientSites } from "@server/db"; +import { olms } from "@server/db"; import { eq, and } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index ab503d4c..4bcbbb8b 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -5,6 +5,8 @@ import { clients, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; import { validateSessionToken } from "@server/auth/sessions/app"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { sendTerminateClient } from "../client/terminate"; // Track if the offline checker interval is running let offlineCheckerInterval: NodeJS.Timeout | null = null; @@ -57,6 +59,9 @@ export const startOlmOfflineChecker = (): void => { // Send a disconnect message to the client if connected try { + await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first + // wait a moment to ensure the message is sent + await new Promise(resolve => setTimeout(resolve, 1000)); await disconnectClient(offlineClient.olmId); } catch (error) { logger.error( @@ -110,6 +115,36 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { logger.warn("User ID mismatch for olm ping"); return; } + + // get the client + const [client] = await db + .select() + .from(clients) + .where( + and( + eq(clients.olmId, olm.olmId), + eq(clients.userId, olm.userId) + ) + ) + .limit(1); + + if (!client) { + logger.warn("Client not found for olm ping"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: client.orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}` + ); + return; + } } if (!olm.clientId) { diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 5c438e4f..696da748 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -3,6 +3,7 @@ import { clientSiteResourcesAssociationsCache, db, ExitNode, + Org, orgs, roleClients, roles, @@ -25,77 +26,88 @@ import { and, eq, inArray, isNull } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; import { listExitNodes } from "#dynamic/lib/exitNodes"; -import { getNextAvailableClientSubnet } from "@server/lib/ip"; +import { + generateAliasConfig, + getNextAvailableClientSubnet +} from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.info("Handling register olm message!"); const { message, client: c, sendToClient } = context; const olm = c as Olm; - const now = new Date().getTime() / 1000; + const now = Math.floor(Date.now() / 1000); if (!olm) { logger.warn("Olm not found"); return; } - const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } = - message.data; - let client: Client; + const { publicKey, relay, olmVersion, orgId, userToken } = message.data; - if (orgId) { - try { - client = await getOrCreateOrgClient( - orgId, - olm.userId, - olm.olmId, - olm.name || "User Device", - // doNotCreateNewClient ? true : false - true // for now never create a new client automatically because we create the users clients when they are added to the org - ); - } catch (err) { - logger.error( - `Error switching olm client ${olm.olmId} to org ${orgId}: ${err}` - ); - return; - } - - if (!client) { - logger.warn("Client not found"); - return; - } - - logger.debug( - `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` - ); - - await db - .update(olms) - .set({ - clientId: client.clientId - }) - .where(eq(olms.olmId, olm.olmId)); - } else { - if (!olm.clientId) { - logger.warn("Olm has no client ID!"); - return; - } - - logger.debug(`Using last connected org for client ${olm.clientId}`); - - [client] = await db - .select() - .from(clients) - .where(eq(clients.clientId, olm.clientId)) - .limit(1); + if (!olm.clientId) { + logger.warn("Olm client ID not found"); + return; } + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, olm.clientId)) + .limit(1); + if (!client) { logger.warn("Client ID not found"); return; } + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, client.orgId)) + .limit(1); + + if (!org) { + logger.warn("Org not found"); + return; + } + + if (orgId) { + if (!olm.userId) { + logger.warn("Olm has no user ID"); + return; + } + + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + logger.warn("Invalid user session for olm register"); + return; // by returning here we just ignore the ping and the setInterval will force it to disconnect + } + if (user.userId !== olm.userId) { + logger.warn("User ID mismatch for olm register"); + return; + } + + const policyCheck = await checkOrgAccessPolicy({ + orgId: orgId, + userId: olm.userId, + session: userToken // this is the user token passed in the message + }); + + if (!policyCheck.allowed) { + logger.warn( + `Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}` + ); + return; + } + } + logger.debug( `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` ); @@ -105,41 +117,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { return; } - if (client.exitNodeId) { - // TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER - - // Get the exit node - const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES - - const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { - return { - publicKey: exitNode.publicKey, - endpoint: exitNode.endpoint - }; - }); - - // Send holepunch message - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch/all", - data: { - exitNodes: exitNodesHpData - } - }); - - if (!olmVersion) { - // THIS IS FOR BACKWARDS COMPATIBILITY - // THE OLDER CLIENTS DID NOT SEND THE VERSION - await sendToClient(olm.olmId, { - type: "olm/wg/holepunch", - data: { - serverPubKey: allExitNodes[0].publicKey, - endpoint: allExitNodes[0].endpoint - } - }); - } - } - - if (olmVersion) { + if (olmVersion && olm.version !== olmVersion) { await db .update(olms) .set({ @@ -148,11 +126,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(olms.olmId, olm.olmId)); } - // if (now - (client.lastHolePunch || 0) > 6) { - // logger.warn("Client last hole punch is too old, skipping all sites"); - // return; - // } - if (client.pubKey !== publicKey) { logger.info( "Public key mismatch. Updating public key and clearing session info..." @@ -190,15 +163,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { `Found ${sitesData.length} sites for client ${client.clientId}` ); - if (sitesData.length === 0) { - sendToClient(olm.olmId, { - type: "olm/register/no-sites", - data: {} - }); + // this prevents us from accepting a register from an olm that has not hole punched yet. + // the olm will pump the register so we can keep checking + // TODO: I still think there is a better way to do this rather than locking it out here but ??? + if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) { + logger.warn( + "Client last hole punch is too old and we have sites to send; skipping this register" + ); + return; } // Process each site - for (const { sites: site } of sitesData) { + for (const { sites: site, clientSitesAssociationsCache: association } of sitesData) { if (!site.exitNodeId) { logger.warn( `Site ${site.siteId} does not have exit node, skipping` @@ -261,7 +237,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ); } - let endpoint = site.endpoint; + let relayEndpoint: string | undefined = undefined; if (relay) { const [exitNode] = await db .select() @@ -272,7 +248,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.warn(`Exit node not found for site ${site.siteId}`); continue; } - endpoint = `${exitNode.endpoint}:21820`; + relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; } const allSiteResources = await db // only get the site resources that this client has access to @@ -298,11 +274,17 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, - endpoint: endpoint, + relayEndpoint: relayEndpoint, // this can be undefined now if not relayed + endpoint: site.endpoint, publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets(allSiteResources.map(({ siteResources }) => siteResources)) + remoteSubnets: generateRemoteSubnets( + allSiteResources.map(({ siteResources }) => siteResources) + ), + aliases: generateAliasConfig( + allSiteResources.map(({ siteResources }) => siteResources) + ) }); } @@ -318,128 +300,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { type: "olm/wg/connect", data: { sites: siteConfigurations, - tunnelIP: client.subnet + tunnelIP: client.subnet, + utilitySubnet: org.utilitySubnet } }, broadcast: false, excludeSender: false }; }; - -async function getOrCreateOrgClient( - orgId: string, - userId: string | null, - olmId: string, - name: string, - doNotCreateNewClient: boolean, - trx: Transaction | typeof db = db -): Promise { - // get the org - const [org] = await trx - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)) - .limit(1); - - if (!org) { - throw new Error("Org not found"); - } - - if (!org.subnet) { - throw new Error("Org has no subnet defined"); - } - - // check if the user has a client in the org and if not then create a client for them - const [existingClient] = await trx - .select() - .from(clients) - .where( - and( - eq(clients.orgId, orgId), - userId ? eq(clients.userId, userId) : isNull(clients.userId), // we dont check the user id if it is null because the olm is not tied to a user? - eq(clients.olmId, olmId) - ) - ) // checking the olmid here because we want to create a new client PER OLM PER ORG - .limit(1); - - let client = existingClient; - if (!client && !doNotCreateNewClient) { - logger.debug( - `Client does not exist in org ${orgId}, creating new client for user ${userId}` - ); - - if (!userId) { - throw new Error("User ID is required to create client in org"); - } - - // Verify that the user belongs to the org - const [userOrg] = await trx - .select() - .from(userOrgs) - .where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))) - .limit(1); - - if (!userOrg) { - throw new Error("User does not belong to org"); - } - - // TODO: more intelligent way to pick the exit node - const exitNodesList = await listExitNodes(orgId); - const randomExitNode = - exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; - - const [adminRole] = await trx - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); - - if (!adminRole) { - throw new Error("Admin role not found"); - } - - const newSubnet = await getNextAvailableClientSubnet(orgId); - if (!newSubnet) { - throw new Error("No available subnet found"); - } - - const subnet = newSubnet.split("/")[0]; - const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org - - const [newClient] = await trx - .insert(clients) - .values({ - exitNodeId: randomExitNode.exitNodeId, - orgId, - name, - subnet: updatedSubnet, - type: "olm", - userId: userId, - olmId: olmId // to lock this client to the olm even as the olm moves between clients in different orgs - }) - .returning(); - - await trx.insert(roleClients).values({ - roleId: adminRole.roleId, - clientId: newClient.clientId - }); - - await trx.insert(userClients).values({ - // we also want to make sure that the user can see their own client if they are not an admin - userId, - clientId: newClient.clientId - }); - - if (userOrg.roleId != adminRole.roleId) { - // make sure the user can access the client - trx.insert(userClients).values({ - userId, - clientId: newClient.clientId - }); - } - - client = newClient; - } - - return client; -} diff --git a/server/routers/olm/handleOlmRelayMessage.ts b/server/routers/olm/handleOlmRelayMessage.ts index 153c4e7c..595b35ba 100644 --- a/server/routers/olm/handleOlmRelayMessage.ts +++ b/server/routers/olm/handleOlmRelayMessage.ts @@ -2,7 +2,7 @@ import { db, exitNodes, sites } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; import { and, eq } from "drizzle-orm"; -import { updatePeer } from "../newt/peers"; +import { updatePeer as newtUpdatePeer } from "../newt/peers"; import logger from "@server/logger"; export const handleOlmRelayMessage: MessageHandler = async (context) => { @@ -79,18 +79,19 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => { ); // update the peer on the exit node - await updatePeer(siteId, client.pubKey, { - endpoint: "" // this removes the endpoint + await newtUpdatePeer(siteId, client.pubKey, { + endpoint: "" // this removes the endpoint so the exit node knows to relay }); - sendToClient(olm.olmId, { - type: "olm/wg/peer/relay", - data: { - siteId: siteId, - endpoint: exitNode.endpoint, - publicKey: exitNode.publicKey - } - }); - - return; + return { + message: { + type: "olm/wg/peer/relay", + data: { + siteId: siteId, + relayEndpoint: exitNode.endpoint + } + }, + broadcast: false, + excludeSender: false + }; }; diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts new file mode 100644 index 00000000..2e5009eb --- /dev/null +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -0,0 +1,187 @@ +import { + Client, + clientSiteResourcesAssociationsCache, + db, + ExitNode, + Org, + orgs, + roleClients, + roles, + siteResources, + Transaction, + userClients, + userOrgs, + users +} from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { + clients, + clientSitesAssociationsCache, + exitNodes, + Olm, + olms, + sites +} from "@server/db"; +import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; +import { addPeer, deletePeer } from "../newt/peers"; +import logger from "@server/logger"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; +import { + generateAliasConfig, + getNextAvailableClientSubnet +} from "@server/lib/ip"; +import { generateRemoteSubnets } from "@server/lib/ip"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; +import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; +import { validateSessionToken } from "@server/auth/sessions/app"; +import config from "@server/lib/config"; +import { + addPeer as newtAddPeer, + deletePeer as newtDeletePeer +} from "@server/routers/newt/peers"; + +export const handleOlmServerPeerAddMessage: MessageHandler = async ( + context +) => { + logger.info("Handling register olm message!"); + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + const now = Math.floor(Date.now() / 1000); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + const { siteId } = message.data; + + // get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} not found` + ); + return; + } + + if (!site.endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint` + ); + return; + } + + // get the client + + if (!olm.clientId) { + logger.error( + `handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId` + ); + return; + } + + const [client] = await db + .select() + .from(clients) + .where(and(eq(clients.clientId, olm.clientId))) + .limit(1); + + if (!client) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found` + ); + return; + } + + if (!client.pubKey) { + logger.error( + `handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key` + ); + return; + } + + let endpoint: string | null = null; + + + const currentSessionSiteAssociationCaches = await db + .select() + .from(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + isNotNull(clientSitesAssociationsCache.endpoint), + eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale + ) + ); + + // pick an endpoint + for (const assoc of currentSessionSiteAssociationCaches) { + if (assoc.endpoint) { + endpoint = assoc.endpoint; + break; + } + } + + if (!endpoint) { + logger.error( + `handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}` + ); + return; + } + + // NOTE: here we are always starting direct to the peer and will relay later + + await newtAddPeer(siteId, { + publicKey: client.pubKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client + endpoint: endpoint // this is the client's endpoint with reference to the site's exit node + }); + + const allSiteResources = await db // only get the site resources that this client has access to + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + siteResources.siteResourceId, + clientSiteResourcesAssociationsCache.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, site.siteId), + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ) + ); + + // Return connect message with all site configurations + return { + message: { + type: "olm/wg/peer/add", + data: { + siteId: site.siteId, + endpoint: site.endpoint, + publicKey: site.publicKey, + serverIP: site.address, + serverPort: site.listenPort, + remoteSubnets: generateRemoteSubnets( + allSiteResources.map(({ siteResources }) => siteResources) + ), + aliases: generateAliasConfig( + allSiteResources.map(({ siteResources }) => siteResources) + ) + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/olm/handleOlmUnRelayMessage.ts b/server/routers/olm/handleOlmUnRelayMessage.ts new file mode 100644 index 00000000..5f47a095 --- /dev/null +++ b/server/routers/olm/handleOlmUnRelayMessage.ts @@ -0,0 +1,96 @@ +import { db, exitNodes, sites } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import { updatePeer as newtUpdatePeer } from "../newt/peers"; +import logger from "@server/logger"; + +export const handleOlmUnRelayMessage: MessageHandler = async (context) => { + const { message, client: c, sendToClient } = context; + const olm = c as Olm; + + logger.info("Handling unrelay olm message!"); + + if (!olm) { + logger.warn("Olm not found"); + return; + } + + if (!olm.clientId) { + logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? + return; + } + + const clientId = olm.clientId; + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + + if (!client) { + logger.warn("Client not found"); + return; + } + + // make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old + if (!client.pubKey) { + logger.warn("Client has no endpoint or listen port"); + return; + } + + const { siteId } = message.data; + + // Get the site + const [site] = await db + .select() + .from(sites) + .where(eq(sites.siteId, siteId)) + .limit(1); + + if (!site) { + logger.warn("Site not found or has no exit node"); + return; + } + + const [clientSiteAssociation] = await db + .update(clientSitesAssociationsCache) + .set({ + isRelayed: false + }) + .where( + and( + eq(clientSitesAssociationsCache.clientId, olm.clientId), + eq(clientSitesAssociationsCache.siteId, siteId) + ) + ) + .returning(); + + if (!clientSiteAssociation) { + logger.warn("Client-Site association not found"); + return; + } + + if (!clientSiteAssociation.endpoint) { + logger.warn("Client-Site association has no endpoint, cannot unrelay"); + return; + } + + // update the peer on the exit node + await newtUpdatePeer(siteId, client.pubKey, { + endpoint: clientSiteAssociation.endpoint // this is the endpoint of the client to connect directly to the exit node + }); + + return { + message: { + type: "olm/wg/peer/unrelay", + data: { + siteId: siteId, + endpoint: site.endpoint + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/olm/index.ts b/server/routers/olm/index.ts index 7adbf859..e671dd42 100644 --- a/server/routers/olm/index.ts +++ b/server/routers/olm/index.ts @@ -7,3 +7,5 @@ export * from "./deleteUserOlm"; export * from "./listUserOlms"; export * from "./deleteUserOlm"; export * from "./getUserOlm"; +export * from "./handleOlmServerPeerAddMessage"; +export * from "./handleOlmUnRelayMessage"; \ No newline at end of file diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 1daed53a..87c634cc 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; import { sendToClient } from "#dynamic/routers/ws"; import logger from "@server/logger"; +import { exit } from "process"; export async function addPeer( clientId: number, @@ -78,8 +79,8 @@ export async function updatePeer( siteId: number; publicKey: string; endpoint: string; - serverIP: string | null; - serverPort: number | null; + serverIP?: string | null; + serverPort?: number | null; remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that }, olmId?: string @@ -102,6 +103,7 @@ export async function updatePeer( siteId: peer.siteId, publicKey: peer.publicKey, endpoint: peer.endpoint, + relayEndpoint: peer.serverIP, serverIP: peer.serverIP, serverPort: peer.serverPort, remoteSubnets: peer.remoteSubnets @@ -110,3 +112,40 @@ export async function updatePeer( logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); } + +export async function initPeerAddHandshake( + clientId: number, + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }, + olmId?: string +) { + if (!olmId) { + const [olm] = await db + .select() + .from(olms) + .where(eq(olms.clientId, clientId)) + .limit(1); + if (!olm) { + throw new Error(`Olm with ID ${clientId} not found`); + } + olmId = olm.olmId; + } + + await sendToClient(olmId, { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: peer.siteId, + exitNode: { + publicKey: peer.exitNode.publicKey, + endpoint: peer.exitNode.endpoint + } + } + }); + + logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`); +} diff --git a/server/routers/org/createOrg.ts b/server/routers/org/createOrg.ts index e44bf021..8276da9a 100644 --- a/server/routers/org/createOrg.ts +++ b/server/routers/org/createOrg.ts @@ -28,10 +28,10 @@ import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; const createOrgSchema = z.strictObject({ - orgId: z.string(), - name: z.string().min(1).max(255), - subnet: z.string() - }); + orgId: z.string(), + name: z.string().min(1).max(255), + subnet: z.string() +}); registry.registerPath({ method: "put", @@ -131,12 +131,16 @@ export async function createOrg( .from(domains) .where(eq(domains.configManaged, true)); + const utilitySubnet = + config.getRawConfig().orgs.utility_subnet_group; + const newOrg = await trx .insert(orgs) .values({ orgId, name, subnet, + utilitySubnet, createdAt: new Date().toISOString() }) .returning(); diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 0e21a8c0..35dc7503 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -1,6 +1,15 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, domains, orgDomains, resources } from "@server/db"; +import { + clients, + clientSiteResourcesAssociationsCache, + clientSitesAssociationsCache, + db, + domains, + olms, + orgDomains, + resources +} from "@server/db"; import { newts, newtSessions, orgs, sites, userActions } from "@server/db"; import { eq, and, inArray, sql } from "drizzle-orm"; import response from "@server/lib/response"; @@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers"; import { OpenAPITags, registry } from "@server/openApi"; const deleteOrgSchema = z.strictObject({ - orgId: z.string() - }); + orgId: z.string() +}); export type DeleteOrgResponse = {}; @@ -69,41 +78,75 @@ export async function deleteOrg( .where(eq(sites.orgId, orgId)) .limit(1); + const orgClients = await db + .select() + .from(clients) + .where(eq(clients.orgId, orgId)); + const deletedNewtIds: string[] = []; + const olmsToTerminate: string[] = []; await db.transaction(async (trx) => { - if (sites) { - for (const site of orgSites) { - if (site.pubKey) { - if (site.type == "wireguard") { - await deletePeer(site.exitNodeId!, site.pubKey); - } else if (site.type == "newt") { - // get the newt on the site by querying the newt table for siteId - const [deletedNewt] = await trx - .delete(newts) - .where(eq(newts.siteId, site.siteId)) - .returning(); - if (deletedNewt) { - deletedNewtIds.push(deletedNewt.newtId); + for (const site of orgSites) { + if (site.pubKey) { + if (site.type == "wireguard") { + await deletePeer(site.exitNodeId!, site.pubKey); + } else if (site.type == "newt") { + // get the newt on the site by querying the newt table for siteId + const [deletedNewt] = await trx + .delete(newts) + .where(eq(newts.siteId, site.siteId)) + .returning(); + if (deletedNewt) { + deletedNewtIds.push(deletedNewt.newtId); - // delete all of the sessions for the newt - await trx - .delete(newtSessions) - .where( - eq( - newtSessions.newtId, - deletedNewt.newtId - ) - ); - } + // delete all of the sessions for the newt + await trx + .delete(newtSessions) + .where( + eq(newtSessions.newtId, deletedNewt.newtId) + ); } } - - logger.info(`Deleting site ${site.siteId}`); - await trx - .delete(sites) - .where(eq(sites.siteId, site.siteId)); } + + logger.info(`Deleting site ${site.siteId}`); + await trx.delete(sites).where(eq(sites.siteId, site.siteId)); + } + for (const client of orgClients) { + const [olm] = await trx + .select() + .from(olms) + .where(eq(olms.clientId, client.clientId)) + .limit(1); + + if (olm) { + olmsToTerminate.push(olm.olmId); + } + + logger.info(`Deleting client ${client.clientId}`); + await trx + .delete(clients) + .where(eq(clients.clientId, client.clientId)); + + // also delete the associations + await trx + .delete(clientSiteResourcesAssociationsCache) + .where( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ) + ); + + await trx + .delete(clientSitesAssociationsCache) + .where( + eq( + clientSitesAssociationsCache.clientId, + client.clientId + ) + ); } const allOrgDomains = await trx @@ -150,7 +193,7 @@ export async function deleteOrg( // Send termination messages outside of transaction to prevent blocking for (const newtId of deletedNewtIds) { const payload = { - type: `newt/terminate`, + type: `newt/wg/terminate`, data: {} }; // Don't await this to prevent blocking the response @@ -162,6 +205,18 @@ export async function deleteOrg( }); } + for (const olmId of olmsToTerminate) { + sendToClient(olmId, { + type: "olm/terminate", + data: {} + }).catch((error) => { + logger.error( + "Failed to send termination message to olm:", + error + ); + }); + } + return response(res, { data: null, success: true, diff --git a/server/routers/siteResource/addClientToSiteResource.ts b/server/routers/siteResource/addClientToSiteResource.ts index 8fb6afdc..587294e5 100644 --- a/server/routers/siteResource/addClientToSiteResource.ts +++ b/server/routers/siteResource/addClientToSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addClientToSiteResourceBodySchema = z .object({ @@ -136,7 +136,7 @@ export async function addClientToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addRoleToSiteResource.ts b/server/routers/siteResource/addRoleToSiteResource.ts index 859ca5be..542ca535 100644 --- a/server/routers/siteResource/addRoleToSiteResource.ts +++ b/server/routers/siteResource/addRoleToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addRoleToSiteResourceBodySchema = z .object({ @@ -146,7 +146,7 @@ export async function addRoleToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/addUserToSiteResource.ts b/server/routers/siteResource/addUserToSiteResource.ts index 411d37b4..c9d1f30a 100644 --- a/server/routers/siteResource/addUserToSiteResource.ts +++ b/server/routers/siteResource/addUserToSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const addUserToSiteResourceBodySchema = z .object({ @@ -115,7 +115,7 @@ export async function addUserToSiteResource( siteResourceId }); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index 2c7bf0fe..1d9cd6aa 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -17,7 +17,8 @@ import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import { getUniqueSiteResourceName } from "@server/db/names"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; +import { getNextAvailableAliasAddress } from "@server/lib/ip"; const createSiteResourceParamsSchema = z.strictObject({ siteId: z.string().transform(Number).pipe(z.int().positive()), @@ -193,6 +194,10 @@ export async function createSiteResource( // } const niceId = await getUniqueSiteResourceName(orgId); + let aliasAddress: string | null = null; + if (mode == "host") { // we can only have an alias on a host + aliasAddress = await getNextAvailableAliasAddress(orgId); + } let newSiteResource: SiteResource | undefined; await db.transaction(async (trx) => { @@ -210,7 +215,8 @@ export async function createSiteResource( // destinationPort: mode === "port" ? destinationPort : null, destination, enabled, - alias: alias || null + alias, + aliasAddress }) .returning(); @@ -272,7 +278,7 @@ export async function createSiteResource( ); } - await rebuildClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role + await rebuildClientAssociationsFromSiteResource(newSiteResource, trx); // we need to call this because we added to the admin role }); if (!newSiteResource) { diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index 75f2c3f2..a7175608 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -9,7 +9,7 @@ import { eq, and } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const deleteSiteResourceParamsSchema = z.strictObject({ siteResourceId: z.string().transform(Number).pipe(z.int().positive()), @@ -106,7 +106,7 @@ export async function deleteSiteResource( ); } - await rebuildClientAssociations(removedSiteResource, trx); + await rebuildClientAssociationsFromSiteResource(removedSiteResource, trx); }); logger.info( diff --git a/server/routers/siteResource/removeClientFromSiteResource.ts b/server/routers/siteResource/removeClientFromSiteResource.ts index d46e5d67..c6a5dfe8 100644 --- a/server/routers/siteResource/removeClientFromSiteResource.ts +++ b/server/routers/siteResource/removeClientFromSiteResource.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeClientFromSiteResourceBodySchema = z .object({ @@ -142,7 +142,7 @@ export async function removeClientFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeRoleFromSiteResource.ts b/server/routers/siteResource/removeRoleFromSiteResource.ts index c4c68e06..0041ed83 100644 --- a/server/routers/siteResource/removeRoleFromSiteResource.ts +++ b/server/routers/siteResource/removeRoleFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeRoleFromSiteResourceBodySchema = z .object({ @@ -151,7 +151,7 @@ export async function removeRoleFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/removeUserFromSiteResource.ts b/server/routers/siteResource/removeUserFromSiteResource.ts index 8a90b752..280a01f2 100644 --- a/server/routers/siteResource/removeUserFromSiteResource.ts +++ b/server/routers/siteResource/removeUserFromSiteResource.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const removeUserFromSiteResourceBodySchema = z .object({ @@ -121,7 +121,7 @@ export async function removeUserFromSiteResource( ) ); - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceClients.ts b/server/routers/siteResource/setSiteResourceClients.ts index 974b27cc..0a25b7e9 100644 --- a/server/routers/siteResource/setSiteResourceClients.ts +++ b/server/routers/siteResource/setSiteResourceClients.ts @@ -8,7 +8,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceClientsBodySchema = z .object({ @@ -124,7 +124,7 @@ export async function setSiteResourceClients( .values(clientIds.map((clientId) => ({ clientId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceRoles.ts b/server/routers/siteResource/setSiteResourceRoles.ts index df44e02b..7aa07de1 100644 --- a/server/routers/siteResource/setSiteResourceRoles.ts +++ b/server/routers/siteResource/setSiteResourceRoles.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, and, ne, inArray } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceRolesBodySchema = z .object({ @@ -147,7 +147,7 @@ export async function setSiteResourceRoles( .values(roleIds.map((roleId) => ({ roleId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/setSiteResourceUsers.ts b/server/routers/siteResource/setSiteResourceUsers.ts index 8ef9a0ab..4dae0ada 100644 --- a/server/routers/siteResource/setSiteResourceUsers.ts +++ b/server/routers/siteResource/setSiteResourceUsers.ts @@ -9,7 +9,7 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; -import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations"; +import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const setSiteResourceUsersBodySchema = z .object({ @@ -102,7 +102,7 @@ export async function setSiteResourceUsers( .values(userIds.map((userId) => ({ userId, siteResourceId }))); } - await rebuildClientAssociations(siteResource, trx); + await rebuildClientAssociationsFromSiteResource(siteResource, trx); }); return response(res, { diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index 470c24f6..233391fb 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -17,17 +17,15 @@ import { eq, and, ne } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; +import { updatePeerData, updateTargets } from "@server/routers/client/targets"; import { - updateRemoteSubnets, - updateTargets -} from "@server/routers/client/targets"; -import { + generateAliasConfig, generateRemoteSubnets, generateSubnetProxyTargets } from "@server/lib/ip"; import { getClientSiteResourceAccess, - rebuildClientAssociations + rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; const updateSiteResourceParamsSchema = z.strictObject({ @@ -51,7 +49,44 @@ const updateSiteResourceSchema = z roleIds: z.array(z.int()), clientIds: z.array(z.int()) }) - .strict(); + .strict() + .refine( + (data) => { + if (data.mode === "host" && data.destination) { + // Check if it's a valid IP address using zod (v4 or v6) + const isValidIP = z + .union([z.ipv4(), z.ipv6()]) + .safeParse(data.destination).success; + + // Check if it's a valid domain (hostname pattern, TLD not required) + const domainRegex = + /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/; + const isValidDomain = domainRegex.test(data.destination); + + return isValidIP || isValidDomain; + } + return true; + }, + { + message: + "Destination must be a valid IP address or domain name for host mode" + } + ) + .refine( + (data) => { + if (data.mode === "cidr" && data.destination) { + // Check if it's a valid CIDR (v4 or v6) + const isValidCIDR = z + .union([z.cidrv4(), z.cidrv6()]) + .safeParse(data.destination).success; + return isValidCIDR; + } + return true; + }, + { + message: "Destination must be a valid CIDR notation for cidr mode" + } + ); export type UpdateSiteResourceBody = z.infer; export type UpdateSiteResourceResponse = SiteResource; @@ -226,16 +261,20 @@ export async function updateSiteResource( ); } - const { mergedAllClients } = await rebuildClientAssociations( - existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below - trx - ); + const { mergedAllClients } = + await rebuildClientAssociationsFromSiteResource( + existingSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below + trx + ); // after everything is rebuilt above we still need to update the targets and remote subnets if the destination changed - if ( + const destinationChanged = existingSiteResource.destination !== - updatedSiteResource.destination - ) { + updatedSiteResource.destination; + const aliasChanged = + existingSiteResource.alias !== updatedSiteResource.alias; + + if (destinationChanged || aliasChanged) { const [newt] = await trx .select() .from(newts) @@ -248,25 +287,28 @@ export async function updateSiteResource( ); } - const oldTargets = generateSubnetProxyTargets( - existingSiteResource, - mergedAllClients - ); - const newTargets = generateSubnetProxyTargets( - updatedSiteResource, - mergedAllClients - ); + // Only update targets on newt if destination changed + if (destinationChanged) { + const oldTargets = generateSubnetProxyTargets( + existingSiteResource, + mergedAllClients + ); + const newTargets = generateSubnetProxyTargets( + updatedSiteResource, + mergedAllClients + ); - await updateTargets(newt.newtId, { - oldTargets: oldTargets, - newTargets: newTargets - }); + await updateTargets(newt.newtId, { + oldTargets: oldTargets, + newTargets: newTargets + }); + } const olmJobs: Promise[] = []; for (const client of mergedAllClients) { // we also need to update the remote subnets on the olms for each client that has access to this site olmJobs.push( - updateRemoteSubnets( + updatePeerData( client.clientId, updatedSiteResource.siteId, { @@ -276,8 +318,22 @@ export async function updateSiteResource( newRemoteSubnets: generateRemoteSubnets([ updatedSiteResource ]) + }, + { + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) } - ) + ).catch((error) => { + // this is okay because sometimes the olm is not online to receive the update or associated with the client yet + logger.warn( + `Error updating peer data for client ${client.clientId}:`, + error + ); + }) ); } diff --git a/server/routers/user/addUserRole.ts b/server/routers/user/addUserRole.ts index 915ea64a..32eaa19d 100644 --- a/server/routers/user/addUserRole.ts +++ b/server/routers/user/addUserRole.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { clients, db, UserOrg } from "@server/db"; import { userOrgs, roles } from "@server/db"; import { eq, and } from "drizzle-orm"; import response from "@server/lib/response"; @@ -10,11 +10,12 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import stoi from "@server/lib/stoi"; import { OpenAPITags, registry } from "@server/openApi"; +import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; const addUserRoleParamsSchema = z.strictObject({ - userId: z.string(), - roleId: z.string().transform(stoi).pipe(z.number()) - }); + userId: z.string(), + roleId: z.string().transform(stoi).pipe(z.number()) +}); export type AddUserRoleResponse = z.infer; @@ -72,7 +73,9 @@ export async function addUserRole( const existingUser = await db .select() .from(userOrgs) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) + .where( + and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId)) + ) .limit(1); if (existingUser.length === 0) { @@ -108,14 +111,39 @@ export async function addUserRole( ); } - const newUserRole = await db - .update(userOrgs) - .set({ roleId }) - .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, role.orgId))) - .returning(); + let newUserRole: UserOrg | null = null; + await db.transaction(async (trx) => { + [newUserRole] = await trx + .update(userOrgs) + .set({ roleId }) + .where( + and( + eq(userOrgs.userId, userId), + eq(userOrgs.orgId, role.orgId) + ) + ) + .returning(); + + // get the client associated with this user in this org + const orgClients = await trx + .select() + .from(clients) + .where( + and( + eq(clients.userId, userId), + eq(clients.orgId, role.orgId) + ) + ) + .limit(1); + + for (const orgClient of orgClients) { + // we just changed the user's role, so we need to rebuild client associations and what they have access to + await rebuildClientAssociationsFromClient(orgClient, trx); + } + }); return response(res, { - data: newUserRole[0], + data: newUserRole, success: true, error: false, message: "Role added to user successfully", diff --git a/server/routers/user/adminRemoveUser.ts b/server/routers/user/adminRemoveUser.ts index 02ad56d6..ae7f9f47 100644 --- a/server/routers/user/adminRemoveUser.ts +++ b/server/routers/user/adminRemoveUser.ts @@ -8,10 +8,11 @@ import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; +import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; const removeUserSchema = z.strictObject({ - userId: z.string() - }); + userId: z.string() +}); export async function adminRemoveUser( req: Request, @@ -50,7 +51,11 @@ export async function adminRemoveUser( ); } - await db.delete(users).where(eq(users.userId, userId)); + await db.transaction(async (trx) => { + await trx.delete(users).where(eq(users.userId, userId)); + + await calculateUserClientsForOrgs(userId, trx); + }); return response(res, { data: null, diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index cbb023b3..acd1aef0 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -11,23 +11,27 @@ import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage, - startOlmOfflineChecker + startOlmOfflineChecker, + handleOlmServerPeerAddMessage, + handleOlmUnRelayMessage } from "../olm"; import { handleHealthcheckStatusMessage } from "../target"; import { MessageHandler } from "./types"; export const messageHandlers: Record = { - "newt/wg/register": handleNewtRegisterMessage, + "olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/register": handleOlmRegisterMessage, + "olm/wg/relay": handleOlmRelayMessage, + "olm/wg/unrelay": handleOlmUnRelayMessage, + "olm/ping": handleOlmPingMessage, + "newt/wg/register": handleNewtRegisterMessage, "newt/wg/get-config": handleGetConfigMessage, "newt/receive-bandwidth": handleReceiveBandwidthMessage, - "olm/wg/relay": handleOlmRelayMessage, - "olm/ping": handleOlmPingMessage, "newt/socket/status": handleDockerStatusMessage, "newt/socket/containers": handleDockerContainersMessage, "newt/ping/request": handleNewtPingRequestMessage, "newt/blueprint/apply": handleApplyBlueprintMessage, - "newt/healthcheck/status": handleHealthcheckStatusMessage, + "newt/healthcheck/status": handleHealthcheckStatusMessage }; -startOlmOfflineChecker(); // this is to handle the offline check for olms \ No newline at end of file +startOlmOfflineChecker(); // this is to handle the offline check for olms diff --git a/src/app/auth/login/device/page.tsx b/src/app/auth/login/device/page.tsx index a19174d0..07c804fb 100644 --- a/src/app/auth/login/device/page.tsx +++ b/src/app/auth/login/device/page.tsx @@ -15,8 +15,6 @@ export default async function DeviceLoginPage({ searchParams }: Props) { const params = await searchParams; const code = params.code || ""; - console.log("user", user); - if (!user) { const redirectDestination = code ? `/auth/login/device?code=${encodeURIComponent(code)}` diff --git a/src/components/DeviceLoginForm.tsx b/src/components/DeviceLoginForm.tsx index 8b6d460c..1eeeb5ae 100644 --- a/src/components/DeviceLoginForm.tsx +++ b/src/components/DeviceLoginForm.tsx @@ -84,6 +84,9 @@ export default function DeviceLoginForm({ if (!data.code.includes("-") && data.code.length === 8) { data.code = data.code.slice(0, 4) + "-" + data.code.slice(4); } + + await new Promise((resolve) => setTimeout(resolve, 300)); + // First check - get metadata const res = await api.post( "/device-web-auth/verify?forceLogin=true", @@ -93,8 +96,6 @@ export default function DeviceLoginForm({ } ); - await new Promise((resolve) => setTimeout(resolve, 500)); // artificial delay for better UX - if (res.data.success && res.data.data.metadata) { setMetadata(res.data.data.metadata); setCode(data.code.toUpperCase()); @@ -116,14 +117,14 @@ export default function DeviceLoginForm({ setLoading(true); try { + await new Promise((resolve) => setTimeout(resolve, 300)); + // Final verify await api.post("/device-web-auth/verify", { code: code, verify: true }); - await new Promise((resolve) => setTimeout(resolve, 500)); // artificial delay for better UX - // Redirect to success page router.push("/auth/login/device/success"); } catch (e: any) {