Pass through transaction

This commit is contained in:
Owen
2025-10-17 14:04:49 -07:00
parent e5a436593f
commit c07abf8ff9
4 changed files with 29 additions and 19 deletions

View File

@@ -612,7 +612,8 @@ export class UsageService {
public async getUsage(
orgId: string,
featureId: FeatureId
featureId: FeatureId,
trx: Transaction | typeof db = db
): Promise<Usage | null> {
if (noop()) {
return null;
@@ -621,7 +622,7 @@ export class UsageService {
const usageId = `${orgId}-${featureId}`;
try {
const [result] = await db
const [result] = await trx
.select()
.from(usage)
.where(eq(usage.usageId, usageId))
@@ -635,7 +636,7 @@ export class UsageService {
const meterId = getFeatureMeterId(featureId);
try {
const [newUsage] = await db
const [newUsage] = await trx
.insert(usage)
.values({
usageId,
@@ -652,7 +653,7 @@ export class UsageService {
return newUsage;
} else {
// Record was created by another process, fetch it
const [existingUsage] = await db
const [existingUsage] = await trx
.select()
.from(usage)
.where(eq(usage.usageId, usageId))
@@ -665,7 +666,7 @@ export class UsageService {
`Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`,
insertError
);
const [existingUsage] = await db
const [existingUsage] = await trx
.select()
.from(usage)
.where(eq(usage.usageId, usageId))
@@ -812,7 +813,8 @@ export class UsageService {
orgId: string,
kickSites = false,
featureId?: FeatureId,
usage?: Usage
usage?: Usage,
trx: Transaction | typeof db = db
): Promise<boolean> {
if (noop()) {
return false;
@@ -825,7 +827,7 @@ export class UsageService {
let orgLimits: Limit[] = [];
if (featureId) {
// Get all limits set for this organization
orgLimits = await db
orgLimits = await trx
.select()
.from(limits)
.where(
@@ -836,7 +838,7 @@ export class UsageService {
);
} else {
// Get all limits set for this organization
orgLimits = await db
orgLimits = await trx
.select()
.from(limits)
.where(eq(limits.orgId, orgId));
@@ -855,7 +857,8 @@ export class UsageService {
} else {
currentUsage = await this.getUsage(
orgId,
limit.featureId as FeatureId
limit.featureId as FeatureId,
trx
);
}
@@ -890,7 +893,7 @@ export class UsageService {
);
// Get all sites for this organization
const orgSites = await db
const orgSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
@@ -902,7 +905,7 @@ export class UsageService {
// Send termination messages to newt sites
for (const site of orgSites) {
if (site.type === "newt") {
const [newt] = await db
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
@@ -917,7 +920,7 @@ export class UsageService {
};
// Don't await to prevent blocking
sendToClient(newt.newtId, payload).catch(
await sendToClient(newt.newtId, payload).catch(
(error: any) => {
logger.error(
`Failed to send termination message to newt ${newt.newtId}:`,

View File

@@ -1,4 +1,4 @@
import { db, exitNodes } from "@server/db";
import { db, exitNodes, Transaction } from "@server/db";
import logger from "@server/logger";
import { ExitNodePingResult } from "@server/routers/newt";
import { eq } from "drizzle-orm";
@@ -59,7 +59,11 @@ export function selectBestExitNode(
return pingResults[0];
}
export async function checkExitNodeOrg(exitNodeId: number, orgId: string) {
export async function checkExitNodeOrg(
exitNodeId: number,
orgId: string,
trx?: Transaction | typeof db
): Promise<boolean> {
return false;
}

View File

@@ -18,7 +18,8 @@ import {
resources,
targets,
sites,
targetHealthCheck
targetHealthCheck,
Transaction
} from "@server/db";
import logger from "@server/logger";
import { ExitNodePingResult } from "@server/routers/newt";
@@ -333,8 +334,8 @@ export function selectBestExitNode(
return fallbackNode;
}
export async function checkExitNodeOrg(exitNodeId: number, orgId: string) {
const [exitNodeOrg] = await db
export async function checkExitNodeOrg(exitNodeId: number, orgId: string, trx: Transaction | typeof db = db) {
const [exitNodeOrg] = await trx
.select()
.from(exitNodeOrgs)
.where(

View File

@@ -98,7 +98,8 @@ export async function updateSiteBandwidth(
if (
await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
updatedSite.orgId,
trx
)
) {
// not allowed
@@ -242,7 +243,8 @@ export async function updateSiteBandwidth(
if (
await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
updatedSite.orgId,
trx
)
) {
// not allowed