Fixing holepunching and other bugs

This commit is contained in:
Owen
2025-12-03 20:31:25 -05:00
parent 7efc947e26
commit 8eec122114
15 changed files with 249 additions and 140 deletions

View File

@@ -630,7 +630,7 @@ export const idpOrg = pgTable("idpOrg", {
});
export const clients = pgTable("clients", {
clientId: serial("id").primaryKey(),
clientId: serial("clientId").primaryKey(),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -684,6 +684,7 @@ export const olms = pgTable("olms", {
secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(),
version: text("version"),
agent: text("agent"),
name: varchar("name"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to

View File

@@ -336,7 +336,7 @@ export const newts = sqliteTable("newt", {
});
export const clients = sqliteTable("clients", {
clientId: integer("id").primaryKey({ autoIncrement: true }),
clientId: integer("clientId").primaryKey({ autoIncrement: true }),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -392,6 +392,7 @@ export const olms = sqliteTable("olms", {
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
version: text("version"),
agent: text("agent"),
name: text("name"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to

View File

@@ -433,12 +433,15 @@ export function generateSubnetProxyTargets(
const clientPrefix = `${clientSite.subnet.split("/")[0]}/32`;
if (siteResource.mode == "host") {
let destination = siteResource.destination;
// check if this is a valid ip
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
if (ipSchema.safeParse(siteResource.destination).success) {
if (ipSchema.safeParse(destination).success) {
destination = `${destination}/32`;
targets.push({
sourcePrefix: clientPrefix,
destPrefix: `${siteResource.destination}/32`
destPrefix: destination
});
}
@@ -447,7 +450,7 @@ export function generateSubnetProxyTargets(
targets.push({
sourcePrefix: clientPrefix,
destPrefix: `${siteResource.aliasAddress}/32`,
rewriteTo: `${siteResource.destination}/32`
rewriteTo: destination
});
}
} else if (siteResource.mode == "cidr") {
@@ -459,9 +462,9 @@ export function generateSubnetProxyTargets(
}
// print a nice representation of the targets
logger.debug(
`Generated subnet proxy targets for: ${JSON.stringify(targets, null, 2)}`
);
// logger.debug(
// `Generated subnet proxy targets for: ${JSON.stringify(targets, null, 2)}`
// );
return targets;
}

View File

@@ -15,7 +15,6 @@ import {
sites,
Transaction,
userOrgs,
users,
userSiteResources
} from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm";
@@ -26,7 +25,6 @@ import {
} from "@server/routers/newt/peers";
import {
initPeerAddHandshake as holepunchSiteAdd,
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "@server/routers/olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
@@ -35,7 +33,6 @@ import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargets,
SubnetProxyTarget
} from "@server/lib/ip";
import {
addPeerData,
@@ -95,7 +92,12 @@ export async function getClientSiteResourceAccess(
subnet: clients.subnet
})
.from(clients)
.where(inArray(clients.userId, newAllUserIds));
.where(
and(
inArray(clients.userId, newAllUserIds),
eq(clients.orgId, siteResource.orgId) // filter by org to prevent cross-org associations
)
);
const allClientSiteResources = await trx // this is for if a client is directly associated with a resource instead of implicitly via a user
.select()
@@ -107,14 +109,21 @@ export async function getClientSiteResourceAccess(
const directClientIds = allClientSiteResources.map((row) => row.clientId);
// Get full client details for directly associated clients
const directClients = await trx
const directClients = directClientIds.length > 0
? await trx
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(inArray(clients.clientId, directClientIds));
.where(
and(
inArray(clients.clientId, directClientIds),
eq(clients.orgId, siteResource.orgId) // filter by org to prevent cross-org associations
)
)
: [];
// Merge user-based clients with directly associated clients
const allClientsMap = new Map(
@@ -717,7 +726,16 @@ export async function rebuildClientAssociationsFromClient(
const directSiteResources = await trx
.select({ siteResourceId: clientSiteResources.siteResourceId })
.from(clientSiteResources)
.where(eq(clientSiteResources.clientId, client.clientId));
.innerJoin(
siteResources,
eq(siteResources.siteResourceId, clientSiteResources.siteResourceId)
)
.where(
and(
eq(clientSiteResources.clientId, client.clientId),
eq(siteResources.orgId, client.orgId) // filter by org to prevent cross-org associations
)
);
newSiteResourceIds.push(
...directSiteResources.map((r) => r.siteResourceId)
@@ -763,7 +781,16 @@ export async function rebuildClientAssociationsFromClient(
const roleSiteResourceIds = await trx
.select({ siteResourceId: roleSiteResources.siteResourceId })
.from(roleSiteResources)
.where(inArray(roleSiteResources.roleId, roleIds));
.innerJoin(
siteResources,
eq(siteResources.siteResourceId, roleSiteResources.siteResourceId)
)
.where(
and(
inArray(roleSiteResources.roleId, roleIds),
eq(siteResources.orgId, client.orgId) // filter by org to prevent cross-org associations
)
);
newSiteResourceIds.push(
...roleSiteResourceIds.map((r) => r.siteResourceId)

View File

@@ -371,6 +371,9 @@ const sendToClientLocal = async (
client.send(messageString);
}
});
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
return true;
};

View File

@@ -1,6 +1,7 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db";
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
@@ -30,6 +31,8 @@ export async function updateTargets(
await sendToClient(newtId, {
type: `newt/wg/targets/update`,
data: targets
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
@@ -47,7 +50,7 @@ export async function addPeerData(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
}
@@ -59,6 +62,8 @@ export async function addPeerData(
remoteSubnets: remoteSubnets,
aliases: aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
@@ -76,7 +81,7 @@ export async function removePeerData(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return;
}
olmId = olm.olmId;
}
@@ -88,6 +93,8 @@ export async function removePeerData(
remoteSubnets: remoteSubnets,
aliases: aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
@@ -95,13 +102,13 @@ export async function updatePeerData(
clientId: number,
siteId: number,
remoteSubnets: {
oldRemoteSubnets: string[],
newRemoteSubnets: string[]
},
oldRemoteSubnets: string[];
newRemoteSubnets: string[];
} | undefined,
aliases: {
oldAliases: Alias[],
newAliases: Alias[]
},
oldAliases: Alias[];
newAliases: Alias[];
} | undefined,
olmId?: string
) {
if (!olmId) {
@@ -111,7 +118,7 @@ export async function updatePeerData(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return;
}
olmId = olm.olmId;
}
@@ -123,5 +130,7 @@ export async function updatePeerData(
...remoteSubnets,
...aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}

View File

@@ -10,14 +10,11 @@ import {
clientSiteResourcesAssociationsCache
} from "@server/db";
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
import { eq, and, inArray } from "drizzle-orm";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import {
generateRemoteSubnets,
generateSubnetProxyTargets,
SubnetProxyTarget
} from "@server/lib/ip";
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
import config from "@server/lib/config";
const inputSchema = z.object({
publicKey: z.string(),
@@ -81,7 +78,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
logger.warn(
`Site ${existingSite.siteId} last hole punch is too old, skipping`
`handleGetConfigMessage: Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
@@ -148,84 +145,77 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
logger.warn(
`Client ${client.clients.clientId} has no public key, skipping`
);
return false;
}
if (!client.clients.subnet) {
logger.warn(
`Client ${client.clients.clientId} has no subnet, skipping`
);
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (!site.publicKey) {
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
return null;
}
let endpoint = site.endpoint;
if (client.clientSitesAssociationsCache.isRelayed) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} has no exit node, skipping`
);
return null;
}
if (!exitNode) {
logger.warn(
`Exit node not found for site ${site.siteId}`
);
logger.warn(`Exit node not found for site ${site.siteId}`);
return null;
}
endpoint = `${exitNode.endpoint}:21820`;
}
if (!endpoint) {
if (!site.endpoint) {
logger.warn(
`In Newt get config: Peer site ${site.siteId} has no endpoint, skipping`
`Site ${site.siteId} has no endpoint, skipping`
);
return null;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clients.clientId
)
)
);
// const allSiteResources = await db // only get the site resources that this client has access to
// .select()
// .from(siteResources)
// .innerJoin(
// clientSiteResourcesAssociationsCache,
// eq(
// siteResources.siteResourceId,
// clientSiteResourcesAssociationsCache.siteResourceId
// )
// )
// .where(
// and(
// eq(siteResources.siteId, site.siteId),
// eq(
// clientSiteResourcesAssociationsCache.clientId,
// client.clients.clientId
// )
// )
// );
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: endpoint,
endpoint: site.endpoint,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(
({ siteResources }) => siteResources
)
)
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to olm ${newt.newtId}: ${error}`
);
}
return {
publicKey: client.clients.pubKey!,

View File

@@ -21,7 +21,7 @@ export async function addPeer(
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
@@ -39,6 +39,8 @@ export async function addPeer(
await sendToClient(newtId, {
type: "newt/wg/peer/add",
data: peer
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to newt ${newtId}`);
@@ -75,6 +77,8 @@ export async function deletePeer(siteId: number, publicKey: string, newtId?: str
data: {
publicKey
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Deleted peer ${publicKey} from newt ${newtId}`);
@@ -120,6 +124,8 @@ export async function updatePeer(
publicKey,
...peer
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Updated peer ${publicKey} on newt ${newtId}`);

View File

@@ -10,7 +10,7 @@ import {
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq, inArray } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
@@ -22,7 +22,6 @@ import {
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
import { listExitNodes } from "#dynamic/lib/exitNodes";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
@@ -139,7 +138,9 @@ export async function getOlmToken(
const [client] = await db
.select()
.from(clients)
.where(eq(clients.orgId, orgIdToUse))
.where(
and(eq(clients.orgId, orgIdToUse), eq(clients.olmId, olmId))
) // we want to lock on to the client with this olmId otherwise it can get assigned to a random one
.limit(1);
if (!client) {

View File

@@ -48,7 +48,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
const { publicKey, relay, olmVersion, orgId, userToken } = message.data;
const { publicKey, relay, olmVersion, olmAgent, orgId, userToken } = message.data;
if (!olm.clientId) {
logger.warn("Olm client ID not found");
@@ -117,11 +117,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
if (olmVersion && olm.version !== olmVersion) {
if ((olmVersion && olm.version !== olmVersion) || (olmAgent && olm.agent !== olmAgent)) {
await db
.update(olms)
.set({
version: olmVersion
version: olmVersion,
agent: olmAgent
})
.where(eq(olms.olmId, olm.olmId));
}
@@ -274,7 +275,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
relayEndpoint: relayEndpoint, // this can be undefined now if not relayed
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,

View File

@@ -108,7 +108,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
let endpoint: string | null = null;
// TODO: should we pick only the one from the site its talking to instead of any good current session?
const currentSessionSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)

View File

@@ -3,7 +3,7 @@ import { clients, olms, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { exit } from "process";
import { Alias } from "yaml";
export async function addPeer(
clientId: number,
@@ -11,9 +11,11 @@ export async function addPeer(
siteId: number;
publicKey: string;
endpoint: string;
relayEndpoint: string;
serverIP: string | null;
serverPort: number | null;
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
aliases: Alias[];
},
olmId?: string
) {
@@ -24,7 +26,7 @@ export async function addPeer(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
}
@@ -35,10 +37,14 @@ export async function addPeer(
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
relayEndpoint: peer.relayEndpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort,
remoteSubnets: peer.remoteSubnets // optional, comma-separated list of subnets that this site can access
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
aliases: peer.aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
@@ -57,7 +63,7 @@ export async function deletePeer(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return;
}
olmId = olm.olmId;
}
@@ -68,6 +74,8 @@ export async function deletePeer(
publicKey,
siteId: siteId
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Deleted peer ${publicKey} from olm ${olmId}`);
@@ -79,9 +87,11 @@ export async function updatePeer(
siteId: number;
publicKey: string;
endpoint: string;
relayEndpoint?: string;
serverIP?: string | null;
serverPort?: number | null;
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
aliases?: Alias[] | null;
},
olmId?: string
) {
@@ -92,7 +102,7 @@ export async function updatePeer(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return
}
olmId = olm.olmId;
}
@@ -103,14 +113,17 @@ export async function updatePeer(
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
relayEndpoint: peer.serverIP,
relayEndpoint: peer.relayEndpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort,
remoteSubnets: peer.remoteSubnets
remoteSubnets: peer.remoteSubnets,
aliases: peer.aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
logger.info(`Updated peer ${peer.publicKey} on olm ${olmId}`);
}
export async function initPeerAddHandshake(
@@ -131,7 +144,7 @@ export async function initPeerAddHandshake(
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
return;
}
olmId = olm.olmId;
}
@@ -145,6 +158,8 @@ export async function initPeerAddHandshake(
endpoint: peer.exitNode.endpoint
}
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`);

View File

@@ -12,7 +12,7 @@ import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { eq, and, is } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
@@ -64,12 +64,17 @@ const createSiteResourceSchema = z
.union([z.ipv4(), z.ipv6()])
.safeParse(data.destination).success;
if (isValidIP) {
return true
}
// Check if it's a valid domain (hostname pattern, TLD not required)
const domainRegex =
/^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/;
const isValidDomain = domainRegex.test(data.destination);
const isValidAlias = data.alias && domainRegex.test(data.alias);
return isValidIP || isValidDomain;
return isValidDomain && isValidAlias; // require the alias to be set in the case of domain
}
return true;
},
@@ -193,9 +198,33 @@ export async function createSiteResource(
// }
// }
// make sure the alias is unique within the org if provided
if (alias) {
const [conflict] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.orgId, orgId),
eq(siteResources.alias, alias.trim())
)
)
.limit(1);
if (conflict) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Alias already in use by another site resource"
)
);
}
}
const niceId = await getUniqueSiteResourceName(orgId);
let aliasAddress: string | null = null;
if (mode == "host") { // we can only have an alias on a host
if (mode == "host") {
// we can only have an alias on a host
aliasAddress = await getNextAvailableAliasAddress(orgId);
}
@@ -278,7 +307,10 @@ export async function createSiteResource(
);
}
await rebuildClientAssociationsFromSiteResource(newSiteResource, trx); // we need to call this because we added to the admin role
await rebuildClientAssociationsFromSiteResource(
newSiteResource,
trx
); // we need to call this because we added to the admin role
});
if (!newSiteResource) {

View File

@@ -178,6 +178,30 @@ export async function updateSiteResource(
);
}
// make sure the alias is unique within the org if provided
if (alias) {
const [conflict] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.orgId, orgId),
eq(siteResources.alias, alias.trim()),
ne(siteResources.siteResourceId, siteResourceId) // exclude self
)
)
.limit(1);
if (conflict) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Alias already in use by another site resource"
)
);
}
}
let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
// Update the site resource
@@ -311,29 +335,23 @@ export async function updateSiteResource(
updatePeerData(
client.clientId,
updatedSiteResource.siteId,
{
destinationChanged ? {
oldRemoteSubnets: generateRemoteSubnets([
existingSiteResource
]),
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
},
{
} : undefined,
aliasChanged ? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
).catch((error) => {
// this is okay because sometimes the olm is not online to receive the update or associated with the client yet
logger.warn(
`Error updating peer data for client ${client.clientId}:`,
error
);
})
} : undefined
)
);
}

View File

@@ -95,6 +95,8 @@ const sendToClient = async (clientId: string, message: WSMessage): Promise<boole
// Try to send locally first
const localSent = await sendToClientLocal(clientId, message);
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
return localSent;
};