diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index f26ce060..2773e098 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -18,7 +18,7 @@ import { users, userSiteResources } from "@server/db"; -import { and, eq, inArray } from "drizzle-orm"; +import { and, eq, inArray, ne } from "drizzle-orm"; import { addPeer as newtAddPeer, @@ -38,6 +38,7 @@ import { import { addRemoteSubnets, addTargets as addSubnetProxyTargets, + removeRemoteSubnets, removeTargets as removeSubnetProxyTargets } from "@server/routers/client/targets"; @@ -142,68 +143,36 @@ export async function rebuildClientAssociations( const { site, mergedAllClients, mergedAllClientIds } = await getClientSiteResourceAccess(siteResource, trx); - /////////// process the client-site associations /////////// - - const existingClientSites = await trx - .select({ - clientId: clientSitesAssociationsCache.clientId - }) - .from(clientSitesAssociationsCache) - .where(eq(clientSitesAssociationsCache.siteId, siteResource.siteId)); - - const existingClientSiteIds = existingClientSites.map( - (row) => row.clientId - ); - - // Get full client details for existing clients (needed for sending delete messages) - const existingClients = await trx - .select({ - clientId: clients.clientId, - pubKey: clients.pubKey, - subnet: clients.subnet - }) - .from(clients) - .where(inArray(clients.clientId, existingClientSiteIds)); - - // ------------- calculations begin below ------------- - - const clientSitesToAdd = mergedAllClientIds.filter( - (clientId) => !existingClientSiteIds.includes(clientId) - ); - - const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({ - clientId, - siteId - })); - - if (clientSitesToInsert.length > 0) { - await trx - .insert(clientSitesAssociationsCache) - .values(clientSitesToInsert) - .returning(); - } - - // Now remove any client-site associations that should no longer exist - const clientSitesToRemove = existingClientSiteIds.filter( - (clientId) => !mergedAllClientIds.includes(clientId) - ); - - if (clientSitesToRemove.length > 0) { - await trx - .delete(clientSitesAssociationsCache) - .where( - and( - eq(clientSitesAssociationsCache.siteId, siteId), - inArray( - clientSitesAssociationsCache.clientId, - clientSitesToRemove - ) - ) - ); - } - /////////// process the client-siteResource associations /////////// + // get all of the clients associated with other resources on this site + const allUpdatedClientsFromOtherResourcesOnThisSite = await trx + .select({ + clientId: clientSiteResourcesAssociationsCache.clientId + }) + .from(clientSiteResourcesAssociationsCache) + .innerJoin( + siteResources, + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResources.siteResourceId + ) + ) + .where( + and( + eq(siteResources.siteId, siteId), + ne(siteResources.siteResourceId, siteResource.siteResourceId) + ) + ); + + const allClientIdsFromOtherResourcesOnThisSite = Array.from( + new Set( + allUpdatedClientsFromOtherResourcesOnThisSite.map( + (row) => row.clientId + ) + ) + ); + const existingClientSiteResources = await trx .select({ clientId: clientSiteResourcesAssociationsCache.clientId @@ -274,6 +243,70 @@ export async function rebuildClientAssociations( ); } + /////////// process the client-site associations /////////// + + const existingClientSites = await trx + .select({ + clientId: clientSitesAssociationsCache.clientId + }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.siteId, siteResource.siteId)); + + const existingClientSiteIds = existingClientSites.map( + (row) => row.clientId + ); + + // Get full client details for existing clients (needed for sending delete messages) + const existingClients = await trx + .select({ + clientId: clients.clientId, + pubKey: clients.pubKey, + subnet: clients.subnet + }) + .from(clients) + .where(inArray(clients.clientId, existingClientSiteIds)); + + const clientSitesToAdd = mergedAllClientIds.filter( + (clientId) => + !existingClientSiteIds.includes(clientId) && + !allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource + ); + + const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({ + clientId, + siteId + })); + + if (clientSitesToInsert.length > 0) { + await trx + .insert(clientSitesAssociationsCache) + .values(clientSitesToInsert) + .returning(); + } + + // Now remove any client-site associations that should no longer exist + const clientSitesToRemove = existingClientSiteIds.filter( + (clientId) => + !mergedAllClientIds.includes(clientId) && + !allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource + ); + + if (clientSitesToRemove.length > 0) { + await trx + .delete(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.siteId, siteId), + inArray( + clientSitesAssociationsCache.clientId, + clientSitesToRemove + ) + ) + ); + } + + /////////// send the messages /////////// + // Now handle the messages to add/remove peers on both the newt and olm sides await handleMessagesForSiteClients( site, @@ -670,7 +703,11 @@ async function handleSubnetProxyTargetUpdates( for (const client of addedClients) { olmJobs.push( - addRemoteSubnets(client.clientId, siteResource.siteId, generateRemoteSubnets([siteResource])) + addRemoteSubnets( + client.clientId, + siteResource.siteId, + generateRemoteSubnets([siteResource]) + ) ); } } @@ -701,11 +738,15 @@ async function handleSubnetProxyTargetUpdates( for (const client of removedClients) { olmJobs.push( - addRemoteSubnets(client.clientId, siteResource.siteId, generateRemoteSubnets([siteResource])) + removeRemoteSubnets( + client.clientId, + siteResource.siteId, + generateRemoteSubnets([siteResource]) + ) ); } } } await Promise.all(proxyJobs); -} \ No newline at end of file +}