diff --git a/messages/en-US.json b/messages/en-US.json index 03586cc6..24448e49 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -454,6 +454,8 @@ "accessRoleErrorAddDescription": "An error occurred while adding user to the role.", "userSaved": "User saved", "userSavedDescription": "The user has been updated.", + "autoProvisioned": "Auto Provisioned", + "autoProvisionedDescription": "Allow this user to be automatically managed by identity provider", "accessControlsDescription": "Manage what this user can access and do in the organization", "accessControlsSubmit": "Save Access Controls", "roles": "Roles", @@ -911,6 +913,8 @@ "idpConnectingToFinished": "Connected", "idpErrorConnectingTo": "There was a problem connecting to {name}. Please contact your administrator.", "idpErrorNotFound": "IdP not found", + "idpGoogleAlt": "Google", + "idpAzureAlt": "Azure", "inviteInvalid": "Invalid Invite", "inviteInvalidDescription": "The invite link is invalid.", "inviteErrorWrongUser": "Invite is not for this user", @@ -982,6 +986,8 @@ "licenseTierProfessionalRequired": "Professional Edition Required", "licenseTierProfessionalRequiredDescription": "This feature is only available in the Professional Edition.", "actionGetOrg": "Get Organization", + "updateOrgUser": "Update Org User", + "createOrgUser": "Create Org User", "actionUpdateOrg": "Update Organization", "actionUpdateUser": "Update User", "actionGetUser": "Get User", @@ -1496,5 +1502,7 @@ "convertButton": "Convert This Node to Managed Self-Hosted" }, "internationaldomaindetected": "International Domain Detected", - "willbestoredas": "Will be stored as:" + "willbestoredas": "Will be stored as:", + "idpGoogleDescription": "Google OAuth2/OIDC provider", + "idpAzureDescription": "Microsoft Azure OAuth2/OIDC provider" } diff --git a/public/idp/azure.png b/public/idp/azure.png new file mode 100644 index 00000000..d6ec5baf Binary files /dev/null and b/public/idp/azure.png differ diff --git a/public/idp/google.png b/public/idp/google.png new file mode 100644 index 00000000..da097687 Binary files /dev/null and b/public/idp/google.png differ diff --git a/server/auth/actions.ts b/server/auth/actions.ts index a3ad60ab..ecbbd058 100644 --- a/server/auth/actions.ts +++ b/server/auth/actions.ts @@ -100,7 +100,8 @@ export enum ActionsEnum { getApiKey = "getApiKey", createOrgDomain = "createOrgDomain", deleteOrgDomain = "deleteOrgDomain", - restartOrgDomain = "restartOrgDomain" + restartOrgDomain = "restartOrgDomain", + updateOrgUser = "updateOrgUser" } export async function checkUserActionPermission( diff --git a/server/db/pg/schema.ts b/server/db/pg/schema.ts index c0b81146..fc1f6ec3 100644 --- a/server/db/pg/schema.ts +++ b/server/db/pg/schema.ts @@ -213,7 +213,8 @@ export const userOrgs = pgTable("userOrgs", { roleId: integer("roleId") .notNull() .references(() => roles.roleId), - isOwner: boolean("isOwner").notNull().default(false) + isOwner: boolean("isOwner").notNull().default(false), + autoProvisioned: boolean("autoProvisioned").default(false) }); export const emailVerificationCodes = pgTable("emailVerificationCodes", { diff --git a/server/db/sqlite/schema.ts b/server/db/sqlite/schema.ts index 8e6cfb59..500e2605 100644 --- a/server/db/sqlite/schema.ts +++ b/server/db/sqlite/schema.ts @@ -107,7 +107,7 @@ export const resources = sqliteTable("resources", { enableProxy: integer("enableProxy", { mode: "boolean" }).default(true), skipToIdpId: integer("skipToIdpId").references(() => idp.idpId, { onDelete: "cascade" - }), + }) }); export const targets = sqliteTable("targets", { @@ -143,8 +143,11 @@ export const exitNodes = sqliteTable("exitNodes", { type: text("type").default("gerbil") // gerbil, remoteExitNode }); -export const siteResources = sqliteTable("siteResources", { // this is for the clients - siteResourceId: integer("siteResourceId").primaryKey({ autoIncrement: true }), +export const siteResources = sqliteTable("siteResources", { + // this is for the clients + siteResourceId: integer("siteResourceId").primaryKey({ + autoIncrement: true + }), siteId: integer("siteId") .notNull() .references(() => sites.siteId, { onDelete: "cascade" }), @@ -156,7 +159,7 @@ export const siteResources = sqliteTable("siteResources", { // this is for the c proxyPort: integer("proxyPort").notNull(), destinationPort: integer("destinationPort").notNull(), destinationIp: text("destinationIp").notNull(), - enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), + enabled: integer("enabled", { mode: "boolean" }).notNull().default(true) }); export const users = sqliteTable("user", { @@ -260,7 +263,9 @@ export const clientSites = sqliteTable("clientSites", { siteId: integer("siteId") .notNull() .references(() => sites.siteId, { onDelete: "cascade" }), - isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false), + isRelayed: integer("isRelayed", { mode: "boolean" }) + .notNull() + .default(false), endpoint: text("endpoint") }); @@ -318,7 +323,10 @@ export const userOrgs = sqliteTable("userOrgs", { roleId: integer("roleId") .notNull() .references(() => roles.roleId), - isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false) + isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false), + autoProvisioned: integer("autoProvisioned", { + mode: "boolean" + }).default(false) }); export const emailVerificationCodes = sqliteTable("emailVerificationCodes", { diff --git a/server/routers/external.ts b/server/routers/external.ts index e421a3e2..44629db7 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -582,6 +582,14 @@ authenticated.put( user.createOrgUser ); +authenticated.post( + "/org/:orgId/user/:userId", + verifyOrgAccess, + verifyUserAccess, + verifyUserHasAction(ActionsEnum.updateOrgUser), + user.updateOrgUser +); + authenticated.get("/org/:orgId/user/:userId", verifyOrgAccess, user.getOrgUser); authenticated.post( diff --git a/server/routers/idp/listIdps.ts b/server/routers/idp/listIdps.ts index 2a0e5809..c9e2c271 100644 --- a/server/routers/idp/listIdps.ts +++ b/server/routers/idp/listIdps.ts @@ -1,11 +1,11 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, idpOidcConfig } from "@server/db"; import { domains, idp, orgDomains, users, idpOrg } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; -import { sql } from "drizzle-orm"; +import { eq, sql } from "drizzle-orm"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; @@ -33,10 +33,13 @@ async function query(limit: number, offset: number) { idpId: idp.idpId, name: idp.name, type: idp.type, - orgCount: sql`count(${idpOrg.orgId})` + variant: idpOidcConfig.variant, + orgCount: sql`count(${idpOrg.orgId})`, + autoProvision: idp.autoProvision }) .from(idp) .leftJoin(idpOrg, sql`${idp.idpId} = ${idpOrg.idpId}`) + .leftJoin(idpOidcConfig, eq(idp.idpId, idpOidcConfig.idpId)) .groupBy(idp.idpId) .limit(limit) .offset(offset); @@ -44,12 +47,7 @@ async function query(limit: number, offset: number) { } export type ListIdpsResponse = { - idps: Array<{ - idpId: number; - name: string; - type: string; - orgCount: number; - }>; + idps: Awaited>; pagination: { total: number; limit: number; diff --git a/server/routers/idp/validateOidcCallback.ts b/server/routers/idp/validateOidcCallback.ts index 67e2baad..46baa517 100644 --- a/server/routers/idp/validateOidcCallback.ts +++ b/server/routers/idp/validateOidcCallback.ts @@ -354,8 +354,13 @@ export async function validateOidcCallback( .from(userOrgs) .where(eq(userOrgs.userId, userId!)); - // Delete orgs that are no longer valid - const orgsToDelete = currentUserOrgs.filter( + // Filter to only auto-provisioned orgs for CRUD operations + const autoProvisionedOrgs = currentUserOrgs.filter( + (org) => org.autoProvisioned === true + ); + + // Delete auto-provisioned orgs that are no longer valid + const orgsToDelete = autoProvisionedOrgs.filter( (currentOrg) => !userOrgInfo.some( (newOrg) => newOrg.orgId === currentOrg.orgId @@ -374,8 +379,8 @@ export async function validateOidcCallback( ); } - // Update roles for existing orgs where the role has changed - const orgsToUpdate = currentUserOrgs.filter((currentOrg) => { + // Update roles for existing auto-provisioned orgs where the role has changed + const orgsToUpdate = autoProvisionedOrgs.filter((currentOrg) => { const newOrg = userOrgInfo.find( (newOrg) => newOrg.orgId === currentOrg.orgId ); @@ -401,7 +406,7 @@ export async function validateOidcCallback( } } - // Add new orgs that don't exist yet + // Add new orgs that don't exist yet (these will be auto-provisioned) const orgsToAdd = userOrgInfo.filter( (newOrg) => !currentUserOrgs.some( @@ -415,12 +420,14 @@ export async function validateOidcCallback( userId: userId!, orgId: org.orgId, roleId: org.roleId, + autoProvisioned: true, dateCreated: new Date().toISOString() })) ); } // Loop through all the orgs and get the total number of users from the userOrgs table + // Use all current user orgs (both auto-provisioned and manually added) for counting for (const org of currentUserOrgs) { const userCount = await trx .select() diff --git a/server/routers/integration.ts b/server/routers/integration.ts index 79453732..69bdbb42 100644 --- a/server/routers/integration.ts +++ b/server/routers/integration.ts @@ -24,7 +24,8 @@ import { verifyApiKeyIsRoot, verifyApiKeyClientAccess, verifyClientsEnabled, - verifyApiKeySiteResourceAccess + verifyApiKeySiteResourceAccess, + verifyOrgAccess } from "@server/middlewares"; import HttpCode from "@server/types/HttpCode"; import { Router } from "express"; @@ -469,6 +470,21 @@ authenticated.get( user.listUsers ); +authenticated.put( + "/org/:orgId/user", + verifyApiKeyOrgAccess, + verifyApiKeyHasAction(ActionsEnum.createOrgUser), + user.createOrgUser +); + +authenticated.post( + "/org/:orgId/user/:userId", + verifyApiKeyOrgAccess, + verifyApiKeyUserAccess, + verifyApiKeyHasAction(ActionsEnum.updateOrgUser), + user.updateOrgUser +); + authenticated.delete( "/org/:orgId/user/:userId", verifyApiKeyOrgAccess, diff --git a/server/routers/user/createOrgUser.ts b/server/routers/user/createOrgUser.ts index 5193e8fa..5b11c923 100644 --- a/server/routers/user/createOrgUser.ts +++ b/server/routers/user/createOrgUser.ts @@ -84,7 +84,14 @@ export async function createOrgUser( } const { orgId } = parsedParams.data; - const { username, email, name, type, idpId, roleId } = parsedBody.data; + const { + username, + email, + name, + type, + idpId, + roleId + } = parsedBody.data; const [role] = await db .select() @@ -173,7 +180,8 @@ export async function createOrgUser( .values({ orgId, userId: existingUser.userId, - roleId: role.roleId + roleId: role.roleId, + autoProvisioned: false }) .returning(); } else { @@ -189,7 +197,7 @@ export async function createOrgUser( type: "oidc", idpId, dateCreated: new Date().toISOString(), - emailVerified: true + emailVerified: true, }) .returning(); @@ -198,7 +206,8 @@ export async function createOrgUser( .values({ orgId, userId: newUser.userId, - roleId: role.roleId + roleId: role.roleId, + autoProvisioned: false }) .returning(); } @@ -209,7 +218,6 @@ export async function createOrgUser( .from(userOrgs) .where(eq(userOrgs.orgId, orgId)); }); - } else { return next( createHttpError(HttpCode.BAD_REQUEST, "User type is required") diff --git a/server/routers/user/getOrgUser.ts b/server/routers/user/getOrgUser.ts index 05e231c9..02ffd92c 100644 --- a/server/routers/user/getOrgUser.ts +++ b/server/routers/user/getOrgUser.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, idp, idpOidcConfig } from "@server/db"; import { roles, userOrgs, users } from "@server/db"; import { and, eq } from "drizzle-orm"; import response from "@server/lib/response"; @@ -25,10 +25,18 @@ async function queryUser(orgId: string, userId: string) { isOwner: userOrgs.isOwner, isAdmin: roles.isAdmin, twoFactorEnabled: users.twoFactorEnabled, + autoProvisioned: userOrgs.autoProvisioned, + idpId: users.idpId, + idpName: idp.name, + idpType: idp.type, + idpVariant: idpOidcConfig.variant, + idpAutoProvision: idp.autoProvision }) .from(userOrgs) .leftJoin(roles, eq(userOrgs.roleId, roles.roleId)) .leftJoin(users, eq(userOrgs.userId, users.userId)) + .leftJoin(idp, eq(users.idpId, idp.idpId)) + .leftJoin(idpOidcConfig, eq(idp.idpId, idpOidcConfig.idpId)) .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))) .limit(1); return user; diff --git a/server/routers/user/index.ts b/server/routers/user/index.ts index 6d342ad3..7148eb87 100644 --- a/server/routers/user/index.ts +++ b/server/routers/user/index.ts @@ -13,3 +13,4 @@ export * from "./removeInvitation"; export * from "./createOrgUser"; export * from "./adminUpdateUser2FA"; export * from "./adminGetUser"; +export * from "./updateOrgUser"; diff --git a/server/routers/user/listUsers.ts b/server/routers/user/listUsers.ts index 83c1e492..a35da862 100644 --- a/server/routers/user/listUsers.ts +++ b/server/routers/user/listUsers.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, idpOidcConfig } from "@server/db"; import { idp, roles, userOrgs, users } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; @@ -50,12 +50,15 @@ async function queryUsers(orgId: string, limit: number, offset: number) { isOwner: userOrgs.isOwner, idpName: idp.name, idpId: users.idpId, + idpType: idp.type, + idpVariant: idpOidcConfig.variant, twoFactorEnabled: users.twoFactorEnabled, }) .from(users) .leftJoin(userOrgs, eq(users.userId, userOrgs.userId)) .leftJoin(roles, eq(userOrgs.roleId, roles.roleId)) .leftJoin(idp, eq(users.idpId, idp.idpId)) + .leftJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId)) .where(eq(userOrgs.orgId, orgId)) .limit(limit) .offset(offset); diff --git a/server/routers/user/updateOrgUser.ts b/server/routers/user/updateOrgUser.ts new file mode 100644 index 00000000..fb00b59f --- /dev/null +++ b/server/routers/user/updateOrgUser.ts @@ -0,0 +1,112 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db, userOrgs } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import { OpenAPITags, registry } from "@server/openApi"; + +const paramsSchema = z + .object({ + userId: z.string(), + orgId: z.string() + }) + .strict(); + +const bodySchema = z + .object({ + autoProvisioned: z.boolean().optional() + }) + .strict() + .refine((data) => Object.keys(data).length > 0, { + message: "At least one field must be provided for update" + }); + +registry.registerPath({ + method: "post", + path: "/org/{orgId}/user/{userId}", + description: "Update a user in an org.", + tags: [OpenAPITags.Org, OpenAPITags.User], + request: { + params: paramsSchema, + body: { + content: { + "application/json": { + schema: bodySchema + } + } + } + }, + responses: {} +}); + +export async function updateOrgUser( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const parsedBody = bodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { userId, orgId } = parsedParams.data; + + const [existingUser] = await db + .select() + .from(userOrgs) + .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))) + .limit(1); + + if (!existingUser) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "User not found in this organization" + ) + ); + } + + const updateData = parsedBody.data; + + const [updatedUser] = await db + .update(userOrgs) + .set({ + ...updateData + }) + .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))) + .returning(); + + return response(res, { + data: updatedUser, + success: true, + error: false, + message: "Org user updated successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/src/app/[orgId]/settings/access/users/[userId]/access-controls/page.tsx b/src/app/[orgId]/settings/access/users/[userId]/access-controls/page.tsx index 82999ad2..a13d08ca 100644 --- a/src/app/[orgId]/settings/access/users/[userId]/access-controls/page.tsx +++ b/src/app/[orgId]/settings/access/users/[userId]/access-controls/page.tsx @@ -16,6 +16,7 @@ import { SelectTrigger, SelectValue } from "@app/components/ui/select"; +import { Checkbox } from "@app/components/ui/checkbox"; import { toast } from "@app/hooks/useToast"; import { zodResolver } from "@hookform/resolvers/zod"; import { InviteUserResponse } from "@server/routers/user"; @@ -41,6 +42,8 @@ import { formatAxiosError } from "@app/lib/api"; import { createApiClient } from "@app/lib/api"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { useTranslations } from "next-intl"; +import IdpTypeBadge from "@app/components/IdpTypeBadge"; +import { UserType } from "@server/types/UserTypes"; export default function AccessControlsPage() { const { orgUser: user } = userOrgUserContext(); @@ -56,14 +59,16 @@ export default function AccessControlsPage() { const formSchema = z.object({ username: z.string(), - roleId: z.string().min(1, { message: t('accessRoleSelectPlease') }) + roleId: z.string().min(1, { message: t('accessRoleSelectPlease') }), + autoProvisioned: z.boolean() }); const form = useForm>({ resolver: zodResolver(formSchema), defaultValues: { username: user.username!, - roleId: user.roleId?.toString() + roleId: user.roleId?.toString(), + autoProvisioned: user.autoProvisioned || false } }); @@ -91,31 +96,36 @@ export default function AccessControlsPage() { fetchRoles(); form.setValue("roleId", user.roleId.toString()); + form.setValue("autoProvisioned", user.autoProvisioned || false); }, []); async function onSubmit(values: z.infer) { setLoading(true); - const res = await api - .post< - AxiosResponse - >(`/role/${values.roleId}/add/${user.userId}`) - .catch((e) => { - toast({ - variant: "destructive", - title: t('accessRoleErrorAdd'), - description: formatAxiosError( - e, - t('accessRoleErrorAddDescription') - ) - }); - }); + try { + // Execute both API calls simultaneously + const [roleRes, userRes] = await Promise.all([ + api.post>(`/role/${values.roleId}/add/${user.userId}`), + api.post(`/org/${orgId}/user/${user.userId}`, { + autoProvisioned: values.autoProvisioned + }) + ]); - if (res && res.status === 200) { + if (roleRes.status === 200 && userRes.status === 200) { + toast({ + variant: "default", + title: t('userSaved'), + description: t('userSavedDescription') + }); + } + } catch (e) { toast({ - variant: "default", - title: t('userSaved'), - description: t('userSavedDescription') + variant: "destructive", + title: t('accessRoleErrorAdd'), + description: formatAxiosError( + e, + t('accessRoleErrorAddDescription') + ) }); } @@ -140,6 +150,20 @@ export default function AccessControlsPage() { className="space-y-4" id="access-controls-form" > + {/* IDP Type Display */} + {user.type !== UserType.Internal && user.idpType && ( +
+ + {t("idp")}: + + +
+ )} + {t('role')} - -

- {t( - "usernameUniq" - )} -

- - + {/* Google/Azure Form */} + {(() => { + const selectedUserOption = userOptions.find(opt => opt.id === selectedOption); + return selectedUserOption?.variant === "google" || selectedUserOption?.variant === "azure"; + })() && ( +
+ - - ( - - - {t( - "emailOptional" - )} - - - - - - - )} - /> - - ( - - - {t( - "nameOptional" - )} - - - - - - - )} - /> - - ( - - - {t("role")} - - - - {roles.map( - ( - role - ) => ( + + + )} + /> + + ( + + + {t("nameOptional")} + + + + + + + )} + /> + + ( + + + {t("role")} + + - - + ))} + + + + + )} + /> + + + )} + + {/* Generic OIDC Form */} + {(() => { + const selectedUserOption = userOptions.find(opt => opt.id === selectedOption); + return selectedUserOption?.variant !== "google" && selectedUserOption?.variant !== "azure"; + })() && ( +
+ -
- + className="space-y-4" + id="create-user-form" + > + ( + + + {t("username")} + + + + +

+ {t("usernameUniq")} +

+ +
+ )} + /> + + ( + + + {t("emailOptional")} + + + + + + + )} + /> + + ( + + + {t("nameOptional")} + + + + + + + )} + /> + + ( + + + {t("role")} + + + + + )} + /> + + + )} - )} - )}
- {userType && dataLoaded && ( + {selectedOption && dataLoaded && (
- {idps.map((idp) => ( - - ))} + {idps.map((idp) => { + const effectiveType = idp.variant || idp.name.toLowerCase(); + + return ( + + ); + })} )} diff --git a/src/components/PermissionsSelectBox.tsx b/src/components/PermissionsSelectBox.tsx index d8f9b59f..3334cca5 100644 --- a/src/components/PermissionsSelectBox.tsx +++ b/src/components/PermissionsSelectBox.tsx @@ -27,7 +27,9 @@ function getActionsCategories(root: boolean) { [t('actionListInvitations')]: "listInvitations", [t('actionRemoveUser')]: "removeUser", [t('actionListUsers')]: "listUsers", - [t('actionListOrgDomains')]: "listOrgDomains" + [t('actionListOrgDomains')]: "listOrgDomains", + [t('updateOrgUser')]: "updateOrgUser", + [t('createOrgUser')]: "createOrgUser" }, Site: { diff --git a/src/components/UsersTable.tsx b/src/components/UsersTable.tsx index 2e9f8d67..2d4c122f 100644 --- a/src/components/UsersTable.tsx +++ b/src/components/UsersTable.tsx @@ -21,6 +21,7 @@ import { createApiClient } from "@app/lib/api"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { useUserContext } from "@app/hooks/useUserContext"; import { useTranslations } from "next-intl"; +import IdpTypeBadge from "./IdpTypeBadge"; export type UserRow = { id: string; @@ -31,6 +32,7 @@ export type UserRow = { idpId: number | null; idpName: string; type: string; + idpVariant: string | null; status: string; role: string; isOwner: boolean; @@ -81,6 +83,16 @@ export default function UsersTable({ users: u }: UsersTableProps) { ); + }, + cell: ({ row }) => { + const userRow = row.original; + return ( + + ); } }, {