mirror of
https://github.com/ajnart/homarr.git
synced 2026-01-29 10:49:14 +01:00
refactor: replace signIn callback with signIn event, adjust getUserByEmail in adapter to check provider (#1223)
* refactor: replace signIn callback with signIn event, adjust getUserByEmail in adapter to check provider * test: adjusting tests for adapter and events * docs: add comments for unknown auth provider * fix: missing dayjs import
This commit is contained in:
@@ -12,8 +12,7 @@ import type { useForm } from "@homarr/form";
|
||||
import { useZodForm } from "@homarr/form";
|
||||
import { showErrorNotification, showSuccessNotification } from "@homarr/notifications";
|
||||
import { useScopedI18n } from "@homarr/translation/client";
|
||||
import type { z } from "@homarr/validation";
|
||||
import { validation } from "@homarr/validation";
|
||||
import { validation, z } from "@homarr/validation";
|
||||
|
||||
interface LoginFormProps {
|
||||
providers: string[];
|
||||
@@ -22,15 +21,17 @@ interface LoginFormProps {
|
||||
callbackUrl: string;
|
||||
}
|
||||
|
||||
const extendedValidation = validation.user.signIn.extend({ provider: z.enum(["credentials", "ldap"]) });
|
||||
|
||||
export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, callbackUrl }: LoginFormProps) => {
|
||||
const t = useScopedI18n("user");
|
||||
const router = useRouter();
|
||||
const [isPending, setIsPending] = useState(false);
|
||||
const form = useZodForm(validation.user.signIn, {
|
||||
const form = useZodForm(extendedValidation, {
|
||||
initialValues: {
|
||||
name: "",
|
||||
password: "",
|
||||
credentialType: "basic",
|
||||
provider: "credentials",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -95,14 +96,14 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
|
||||
<Stack gap="lg">
|
||||
{credentialInputsVisible && (
|
||||
<>
|
||||
<form onSubmit={form.onSubmit((credentials) => void signInAsync("credentials", credentials))}>
|
||||
<form onSubmit={form.onSubmit((credentials) => void signInAsync(credentials.provider, credentials))}>
|
||||
<Stack gap="lg">
|
||||
<TextInput label={t("field.username.label")} {...form.getInputProps("name")} />
|
||||
<PasswordInput label={t("field.password.label")} {...form.getInputProps("password")} />
|
||||
|
||||
{providers.includes("credentials") && (
|
||||
<Stack gap="sm">
|
||||
<SubmitButton isPending={isPending} form={form} credentialType="basic">
|
||||
<SubmitButton isPending={isPending} form={form} provider="credentials">
|
||||
{t("action.login.label")}
|
||||
</SubmitButton>
|
||||
<PasswordForgottenCollapse username={form.values.name} />
|
||||
@@ -110,7 +111,7 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
|
||||
)}
|
||||
|
||||
{providers.includes("ldap") && (
|
||||
<SubmitButton isPending={isPending} form={form} credentialType="ldap">
|
||||
<SubmitButton isPending={isPending} form={form} provider="ldap">
|
||||
{t("action.login.labelWith", { provider: "LDAP" })}
|
||||
</SubmitButton>
|
||||
)}
|
||||
@@ -133,18 +134,18 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
|
||||
interface SubmitButtonProps {
|
||||
isPending: boolean;
|
||||
form: ReturnType<typeof useForm<FormType, (values: FormType) => FormType>>;
|
||||
credentialType: "basic" | "ldap";
|
||||
provider: "credentials" | "ldap";
|
||||
}
|
||||
|
||||
const SubmitButton = ({ isPending, form, credentialType, children }: PropsWithChildren<SubmitButtonProps>) => {
|
||||
const isCurrentProviderActive = form.getValues().credentialType === credentialType;
|
||||
const SubmitButton = ({ isPending, form, provider, children }: PropsWithChildren<SubmitButtonProps>) => {
|
||||
const isCurrentProviderActive = form.getValues().provider === provider;
|
||||
|
||||
return (
|
||||
<Button
|
||||
type="submit"
|
||||
name={credentialType}
|
||||
name={provider}
|
||||
fullWidth
|
||||
onClick={() => form.setFieldValue("credentialType", credentialType)}
|
||||
onClick={() => form.setFieldValue("provider", provider)}
|
||||
loading={isPending && isCurrentProviderActive}
|
||||
disabled={isPending && !isCurrentProviderActive}
|
||||
>
|
||||
@@ -181,4 +182,4 @@ const PasswordForgottenCollapse = ({ username }: PasswordForgottenCollapseProps)
|
||||
);
|
||||
};
|
||||
|
||||
type FormType = z.infer<typeof validation.user.signIn>;
|
||||
type FormType = z.infer<typeof extendedValidation>;
|
||||
|
||||
@@ -1,17 +1,37 @@
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
import { createHandlers } from "@homarr/auth";
|
||||
import type { SupportedAuthProvider } from "@homarr/definitions";
|
||||
import { logger } from "@homarr/log";
|
||||
|
||||
export const GET = async (req: NextRequest) => {
|
||||
return await createHandlers(isCredentialsRequest(req)).handlers.GET(reqWithTrustedOrigin(req));
|
||||
return await createHandlers(extractProvider(req)).handlers.GET(reqWithTrustedOrigin(req));
|
||||
};
|
||||
export const POST = async (req: NextRequest) => {
|
||||
return await createHandlers(isCredentialsRequest(req)).handlers.POST(reqWithTrustedOrigin(req));
|
||||
return await createHandlers(extractProvider(req)).handlers.POST(reqWithTrustedOrigin(req));
|
||||
};
|
||||
|
||||
const isCredentialsRequest = (req: NextRequest) => {
|
||||
return req.url.includes("credentials") && req.method === "POST";
|
||||
/**
|
||||
* This method extracts the used provider from the url and allows us to override the getUserByEmail method in the adapter.
|
||||
* @param req request containing the url
|
||||
* @returns the provider or "unknown" if the provider could not be extracted
|
||||
*/
|
||||
const extractProvider = (req: NextRequest): SupportedAuthProvider | "unknown" => {
|
||||
const url = new URL(req.url);
|
||||
|
||||
if (url.pathname.includes("oidc")) {
|
||||
return "oidc";
|
||||
}
|
||||
|
||||
if (url.pathname.includes("credentials")) {
|
||||
return "credentials";
|
||||
}
|
||||
|
||||
if (url.pathname.includes("ldap")) {
|
||||
return "ldap";
|
||||
}
|
||||
|
||||
return "unknown";
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,5 +1,45 @@
|
||||
import type { Adapter } from "@auth/core/adapters";
|
||||
import { DrizzleAdapter } from "@auth/drizzle-adapter";
|
||||
|
||||
import { db } from "@homarr/db";
|
||||
import type { Database } from "@homarr/db";
|
||||
import { and, eq } from "@homarr/db";
|
||||
import { accounts, users } from "@homarr/db/schema/sqlite";
|
||||
import type { SupportedAuthProvider } from "@homarr/definitions";
|
||||
|
||||
export const adapter = DrizzleAdapter(db);
|
||||
export const createAdapter = (db: Database, provider: SupportedAuthProvider | "unknown"): Adapter => {
|
||||
const drizzleAdapter = DrizzleAdapter(db, { usersTable: users, accountsTable: accounts });
|
||||
|
||||
return {
|
||||
...drizzleAdapter,
|
||||
// We override the default implementation as we want to have a provider
|
||||
// flag in the user instead of the account to not intermingle users from different providers
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
getUserByEmail: async (email) => {
|
||||
if (provider === "unknown") {
|
||||
throw new Error("Unable to get user by email for unknown provider");
|
||||
}
|
||||
|
||||
const user = await db.query.users.findFirst({
|
||||
where: and(eq(users.email, email), eq(users.provider, provider)),
|
||||
columns: {
|
||||
id: true,
|
||||
name: true,
|
||||
email: true,
|
||||
emailVerified: true,
|
||||
image: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
...user,
|
||||
// We allow null as email for credentials provider
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
email: user.email!,
|
||||
};
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import { cookies } from "next/headers";
|
||||
import type { Adapter } from "@auth/core/adapters";
|
||||
import dayjs from "dayjs";
|
||||
import type { NextAuthConfig } from "next-auth";
|
||||
|
||||
@@ -9,9 +7,6 @@ import { eq, inArray } from "@homarr/db";
|
||||
import { groupMembers, groupPermissions, users } from "@homarr/db/schema/sqlite";
|
||||
import { getPermissionsWithChildren } from "@homarr/definitions";
|
||||
|
||||
import { env } from "./env.mjs";
|
||||
import { expireDateAfter, generateSessionToken, sessionTokenCookieName } from "./session";
|
||||
|
||||
export const getCurrentUserPermissionsAsync = async (db: Database, userId: string) => {
|
||||
const dbGroupMembers = await db.query.groupMembers.findMany({
|
||||
where: eq(groupMembers.userId, userId),
|
||||
@@ -68,51 +63,6 @@ export const createSessionCallback = (db: Database): NextAuthCallbackOf<"session
|
||||
};
|
||||
};
|
||||
|
||||
export const createSignInCallback =
|
||||
(adapter: Adapter, db: Database, isCredentialsRequest: boolean): NextAuthCallbackOf<"signIn"> =>
|
||||
async ({ user }) => {
|
||||
if (!isCredentialsRequest) return true;
|
||||
|
||||
// https://github.com/nextauthjs/next-auth/issues/6106
|
||||
if (!adapter.createSession || !user.id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const sessionToken = generateSessionToken();
|
||||
const sessionExpires = expireDateAfter(env.AUTH_SESSION_EXPIRY_TIME);
|
||||
|
||||
await adapter.createSession({
|
||||
sessionToken,
|
||||
userId: user.id,
|
||||
expires: sessionExpires,
|
||||
});
|
||||
|
||||
cookies().set(sessionTokenCookieName, sessionToken, {
|
||||
path: "/",
|
||||
expires: sessionExpires,
|
||||
httpOnly: true,
|
||||
sameSite: "lax",
|
||||
secure: true,
|
||||
});
|
||||
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: eq(users.id, user.id),
|
||||
columns: {
|
||||
colorScheme: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dbUser) return false;
|
||||
|
||||
// We use a cookie as localStorage is not shared with server (otherwise flickering would occur)
|
||||
cookies().set("homarr-color-scheme", dbUser.colorScheme, {
|
||||
path: "/",
|
||||
expires: dayjs().add(1, "year").toDate(),
|
||||
});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
type NextAuthCallbackRecord = Exclude<NextAuthConfig["callbacks"], undefined>;
|
||||
export type NextAuthCallbackOf<TKey extends keyof NextAuthCallbackRecord> = Exclude<
|
||||
NextAuthCallbackRecord[TKey],
|
||||
|
||||
@@ -4,18 +4,21 @@ import NextAuth from "next-auth";
|
||||
import Credentials from "next-auth/providers/credentials";
|
||||
|
||||
import { db } from "@homarr/db";
|
||||
import type { SupportedAuthProvider } from "@homarr/definitions";
|
||||
|
||||
import { adapter } from "./adapter";
|
||||
import { createSessionCallback, createSignInCallback } from "./callbacks";
|
||||
import { createAdapter } from "./adapter";
|
||||
import { createSessionCallback } from "./callbacks";
|
||||
import { env } from "./env.mjs";
|
||||
import { createCredentialsConfiguration } from "./providers/credentials/credentials-provider";
|
||||
import { createSignInEventHandler } from "./events";
|
||||
import { createCredentialsConfiguration, createLdapConfiguration } from "./providers/credentials/credentials-provider";
|
||||
import { EmptyNextAuthProvider } from "./providers/empty/empty-provider";
|
||||
import { filterProviders } from "./providers/filter-providers";
|
||||
import { OidcProvider } from "./providers/oidc/oidc-provider";
|
||||
import { createRedirectUri } from "./redirect";
|
||||
import { sessionTokenCookieName } from "./session";
|
||||
import { generateSessionToken, sessionTokenCookieName } from "./session";
|
||||
|
||||
export const createConfiguration = (isCredentialsRequest: boolean, headers: ReadonlyHeaders | null) =>
|
||||
// See why it's unknown in the [...nextauth]/route.ts file
|
||||
export const createConfiguration = (provider: SupportedAuthProvider | "unknown", headers: ReadonlyHeaders | null) =>
|
||||
NextAuth({
|
||||
logger: {
|
||||
error: (code, ...message) => {
|
||||
@@ -30,21 +33,25 @@ export const createConfiguration = (isCredentialsRequest: boolean, headers: Read
|
||||
},
|
||||
},
|
||||
trustHost: true,
|
||||
adapter,
|
||||
adapter: createAdapter(db, provider),
|
||||
providers: filterProviders([
|
||||
Credentials(createCredentialsConfiguration(db)),
|
||||
Credentials(createLdapConfiguration(db)),
|
||||
EmptyNextAuthProvider(),
|
||||
OidcProvider(headers),
|
||||
]),
|
||||
callbacks: {
|
||||
session: createSessionCallback(db),
|
||||
signIn: createSignInCallback(adapter, db, isCredentialsRequest),
|
||||
},
|
||||
events: {
|
||||
signIn: createSignInEventHandler(db),
|
||||
},
|
||||
redirectProxyUrl: createRedirectUri(headers, "/api/auth"),
|
||||
secret: "secret-is-not-defined-yet", // TODO: This should be added later
|
||||
session: {
|
||||
strategy: "database",
|
||||
maxAge: env.AUTH_SESSION_EXPIRY_TIME,
|
||||
generateSessionToken,
|
||||
},
|
||||
pages: {
|
||||
signIn: "/auth/login",
|
||||
|
||||
@@ -74,6 +74,7 @@ export const env = createEnv({
|
||||
AUTH_OIDC_CLIENT_NAME: z.string().min(1).default("OIDC"),
|
||||
AUTH_OIDC_AUTO_LOGIN: booleanSchema,
|
||||
AUTH_OIDC_SCOPE_OVERWRITE: z.string().min(1).default("openid email profile groups"),
|
||||
AUTH_OIDC_GROUPS_ATTRIBUTE: z.string().default("groups"), // Is used in the signIn event to assign the correct groups, key is from object of decoded id_token
|
||||
}
|
||||
: {}),
|
||||
...(authProviders.includes("ldap")
|
||||
@@ -113,6 +114,7 @@ export const env = createEnv({
|
||||
AUTH_OIDC_CLIENT_SECRET: process.env.AUTH_OIDC_CLIENT_SECRET,
|
||||
AUTH_OIDC_ISSUER: process.env.AUTH_OIDC_ISSUER,
|
||||
AUTH_OIDC_SCOPE_OVERWRITE: process.env.AUTH_OIDC_SCOPE_OVERWRITE,
|
||||
AUTH_OIDC_GROUPS_ATTRIBUTE: process.env.AUTH_OIDC_GROUPS_ATTRIBUTE,
|
||||
AUTH_LDAP_USERNAME_ATTRIBUTE: process.env.AUTH_LDAP_USERNAME_ATTRIBUTE,
|
||||
AUTH_LDAP_USER_MAIL_ATTRIBUTE: process.env.AUTH_LDAP_USER_MAIL_ATTRIBUTE,
|
||||
AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG: process.env.AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG,
|
||||
|
||||
131
packages/auth/events.ts
Normal file
131
packages/auth/events.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { cookies } from "next/headers";
|
||||
import dayjs from "dayjs";
|
||||
import type { NextAuthConfig } from "next-auth";
|
||||
|
||||
import { and, eq, inArray } from "@homarr/db";
|
||||
import type { Database } from "@homarr/db";
|
||||
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
|
||||
import { logger } from "@homarr/log";
|
||||
|
||||
import { env } from "./env.mjs";
|
||||
|
||||
export const createSignInEventHandler = (db: Database): Exclude<NextAuthConfig["events"], undefined>["signIn"] => {
|
||||
return async ({ user, profile }) => {
|
||||
if (!user.id) throw new Error("User ID is missing");
|
||||
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: eq(users.id, user.id),
|
||||
columns: {
|
||||
name: true,
|
||||
colorScheme: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!dbUser) throw new Error("User not found");
|
||||
|
||||
const groupsKey = env.AUTH_OIDC_GROUPS_ATTRIBUTE;
|
||||
// Groups from oidc provider are provided from the profile, it's not typed.
|
||||
if (profile && groupsKey in profile && Array.isArray(profile[groupsKey])) {
|
||||
await synchronizeGroupsWithExternalForUserAsync(db, user.id, profile[groupsKey] as string[]);
|
||||
}
|
||||
|
||||
// In ldap-authroization we return the groups from ldap, it's not typed.
|
||||
if ("groups" in user && Array.isArray(user.groups)) {
|
||||
await synchronizeGroupsWithExternalForUserAsync(db, user.id, user.groups as string[]);
|
||||
}
|
||||
|
||||
if (dbUser.name !== user.name) {
|
||||
await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
|
||||
logger.info(
|
||||
`Username for user of credentials provider has changed. user=${user.id} old=${dbUser.name} new=${user.name}`,
|
||||
);
|
||||
}
|
||||
|
||||
const profileUsername = profile?.preferred_username?.includes("@") ? profile.name : profile?.preferred_username;
|
||||
if (profileUsername && dbUser.name !== profileUsername) {
|
||||
await db.update(users).set({ name: profileUsername }).where(eq(users.id, user.id));
|
||||
logger.info(
|
||||
`Username for user of oidc provider has changed. user=${user.id} old='${dbUser.name}' new='${profileUsername}'`,
|
||||
);
|
||||
}
|
||||
|
||||
// We use a cookie as localStorage is not shared with server (otherwise flickering would occur)
|
||||
cookies().set("homarr-color-scheme", dbUser.colorScheme, {
|
||||
path: "/",
|
||||
expires: dayjs().add(1, "year").toDate(),
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
const synchronizeGroupsWithExternalForUserAsync = async (db: Database, userId: string, externalGroups: string[]) => {
|
||||
const dbGroupMembers = await db.query.groupMembers.findMany({
|
||||
where: eq(groupMembers.userId, userId),
|
||||
with: {
|
||||
group: { columns: { name: true } },
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* The below groups are those groups the user is part of in the external system, but not in Homarr.
|
||||
* So he has to be added to those groups.
|
||||
*/
|
||||
const missingExternalGroupsForUser = externalGroups.filter(
|
||||
(externalGroup) => !dbGroupMembers.some(({ group }) => group.name === externalGroup),
|
||||
);
|
||||
|
||||
if (missingExternalGroupsForUser.length > 0) {
|
||||
logger.debug(
|
||||
`Homarr does not have the user in certain groups. user=${userId} count=${missingExternalGroupsForUser.length}`,
|
||||
);
|
||||
|
||||
const groupIds = await db.query.groups.findMany({
|
||||
columns: {
|
||||
id: true,
|
||||
},
|
||||
where: inArray(groups.name, missingExternalGroupsForUser),
|
||||
});
|
||||
|
||||
logger.debug(`Homarr has found groups in the database user is not in. user=${userId} count=${groupIds.length}`);
|
||||
|
||||
if (groupIds.length > 0) {
|
||||
await db.insert(groupMembers).values(
|
||||
groupIds.map((group) => ({
|
||||
userId,
|
||||
groupId: group.id,
|
||||
})),
|
||||
);
|
||||
|
||||
logger.info(`Added user to groups successfully. user=${userId} count=${groupIds.length}`);
|
||||
} else {
|
||||
logger.debug(`User is already in all groups of Homarr. user=${userId}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The below groups are those groups the user is part of in Homarr, but not in the external system.
|
||||
* So he has to be removed from those groups.
|
||||
*/
|
||||
const groupsUserIsNoLongerMemberOfExternally = dbGroupMembers.filter(
|
||||
({ group }) => !externalGroups.includes(group.name),
|
||||
);
|
||||
|
||||
if (groupsUserIsNoLongerMemberOfExternally.length > 0) {
|
||||
logger.debug(
|
||||
`Homarr has the user in certain groups that LDAP does not have. user=${userId} count=${groupsUserIsNoLongerMemberOfExternally.length}`,
|
||||
);
|
||||
|
||||
await db.delete(groupMembers).where(
|
||||
and(
|
||||
eq(groupMembers.userId, userId),
|
||||
inArray(
|
||||
groupMembers.groupId,
|
||||
groupsUserIsNoLongerMemberOfExternally.map(({ groupId }) => groupId),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
logger.info(
|
||||
`Removed user from groups successfully. user=${userId} count=${groupsUserIsNoLongerMemberOfExternally.length}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { headers } from "next/headers";
|
||||
import type { DefaultSession } from "@auth/core/types";
|
||||
|
||||
import type { ColorScheme, GroupPermissionKey } from "@homarr/definitions";
|
||||
import type { ColorScheme, GroupPermissionKey, SupportedAuthProvider } from "@homarr/definitions";
|
||||
|
||||
import { createConfiguration } from "./configuration";
|
||||
|
||||
@@ -19,6 +19,7 @@ declare module "next-auth" {
|
||||
|
||||
export * from "./security";
|
||||
|
||||
export const createHandlers = (isCredentialsRequest: boolean) => createConfiguration(isCredentialsRequest, headers());
|
||||
// See why it's unknown in the [...nextauth]/route.ts file
|
||||
export const createHandlers = (provider: SupportedAuthProvider | "unknown") => createConfiguration(provider, headers());
|
||||
|
||||
export { getSessionFromTokenAsync as getSessionFromToken, sessionTokenCookieName } from "./session";
|
||||
|
||||
@@ -2,7 +2,7 @@ import { cache } from "react";
|
||||
|
||||
import { createConfiguration } from "./configuration";
|
||||
|
||||
const { auth: defaultAuth } = createConfiguration(false, null);
|
||||
const { auth: defaultAuth } = createConfiguration("unknown", null);
|
||||
|
||||
/**
|
||||
* This is the main way to get session data for your RSCs.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { CredentialsSignin } from "@auth/core/errors";
|
||||
|
||||
import type { Database, InferInsertModel } from "@homarr/db";
|
||||
import { and, createId, eq, inArray } from "@homarr/db";
|
||||
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
|
||||
import { and, createId, eq } from "@homarr/db";
|
||||
import { users } from "@homarr/db/schema/sqlite";
|
||||
import { logger } from "@homarr/log";
|
||||
import type { validation } from "@homarr/validation";
|
||||
import { z } from "@homarr/validation";
|
||||
@@ -99,18 +99,6 @@ export const authorizeWithLdapCredentialsAsync = async (
|
||||
emailVerified: true,
|
||||
provider: true,
|
||||
},
|
||||
with: {
|
||||
groups: {
|
||||
with: {
|
||||
group: {
|
||||
columns: {
|
||||
id: true,
|
||||
name: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
where: and(eq(users.email, mailResult.data), eq(users.provider, "ldap")),
|
||||
});
|
||||
|
||||
@@ -128,79 +116,16 @@ export const authorizeWithLdapCredentialsAsync = async (
|
||||
|
||||
await db.insert(users).values(insertUser);
|
||||
|
||||
user = {
|
||||
...insertUser,
|
||||
groups: [],
|
||||
};
|
||||
user = insertUser;
|
||||
|
||||
logger.info(`User ${credentials.name} created successfully.`);
|
||||
}
|
||||
|
||||
if (user.name !== credentials.name) {
|
||||
logger.warn(`User ${credentials.name} found in the database but with different name. Updating...`);
|
||||
|
||||
user.name = credentials.name;
|
||||
|
||||
await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
|
||||
|
||||
logger.info(`User ${credentials.name} updated successfully.`);
|
||||
}
|
||||
|
||||
const ldapGroupsUserIsNotIn = userGroups.filter(
|
||||
(group) => !user.groups.some((userGroup) => userGroup.group.name === group),
|
||||
);
|
||||
|
||||
if (ldapGroupsUserIsNotIn.length > 0) {
|
||||
logger.debug(
|
||||
`Homarr does not have the user in certain groups. user=${user.name} count=${ldapGroupsUserIsNotIn.length}`,
|
||||
);
|
||||
|
||||
const groupIds = await db.query.groups.findMany({
|
||||
columns: {
|
||||
id: true,
|
||||
},
|
||||
where: inArray(groups.name, ldapGroupsUserIsNotIn),
|
||||
});
|
||||
|
||||
logger.debug(`Homarr has found groups in the database user is not in. user=${user.name} count=${groupIds.length}`);
|
||||
|
||||
if (groupIds.length > 0) {
|
||||
await db.insert(groupMembers).values(
|
||||
groupIds.map((group) => ({
|
||||
userId: user.id,
|
||||
groupId: group.id,
|
||||
})),
|
||||
);
|
||||
|
||||
logger.info(`Added user to groups successfully. user=${user.name} count=${groupIds.length}`);
|
||||
} else {
|
||||
logger.debug(`User is already in all groups of Homarr. user=${user.name}`);
|
||||
}
|
||||
}
|
||||
|
||||
const homarrGroupsUserIsNotIn = user.groups.filter((userGroup) => !userGroups.includes(userGroup.group.name));
|
||||
|
||||
if (homarrGroupsUserIsNotIn.length > 0) {
|
||||
logger.debug(
|
||||
`Homarr has the user in certain groups that LDAP does not have. user=${user.name} count=${homarrGroupsUserIsNotIn.length}`,
|
||||
);
|
||||
|
||||
await db.delete(groupMembers).where(
|
||||
and(
|
||||
eq(groupMembers.userId, user.id),
|
||||
inArray(
|
||||
groupMembers.groupId,
|
||||
homarrGroupsUserIsNotIn.map(({ groupId }) => groupId),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
logger.info(`Removed user from groups successfully. user=${user.name} count=${homarrGroupsUserIsNotIn.length}`);
|
||||
}
|
||||
|
||||
return {
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
name: credentials.name,
|
||||
// Groups is used in events.ts to synchronize groups with external systems
|
||||
groups: userGroups,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -10,30 +10,25 @@ type CredentialsConfiguration = Parameters<typeof Credentials>[0];
|
||||
|
||||
export const createCredentialsConfiguration = (db: Database) =>
|
||||
({
|
||||
id: "credentials",
|
||||
type: "credentials",
|
||||
name: "Credentials",
|
||||
credentials: {
|
||||
name: {
|
||||
label: "Username",
|
||||
type: "text",
|
||||
},
|
||||
password: {
|
||||
label: "Password",
|
||||
type: "password",
|
||||
},
|
||||
isLdap: {
|
||||
label: "LDAP",
|
||||
type: "checkbox",
|
||||
},
|
||||
},
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
async authorize(credentials) {
|
||||
const data = await validation.user.signIn.parseAsync(credentials);
|
||||
|
||||
if (data.credentialType === "ldap") {
|
||||
return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
|
||||
}
|
||||
|
||||
return await authorizeWithBasicCredentialsAsync(db, data);
|
||||
},
|
||||
}) satisfies CredentialsConfiguration;
|
||||
|
||||
export const createLdapConfiguration = (db: Database) =>
|
||||
({
|
||||
id: "ldap",
|
||||
type: "credentials",
|
||||
name: "Ldap",
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
async authorize(credentials) {
|
||||
const data = await validation.user.signIn.parseAsync(credentials);
|
||||
return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
|
||||
},
|
||||
}) satisfies CredentialsConfiguration;
|
||||
|
||||
@@ -25,7 +25,6 @@ describe("authorizeWithBasicCredentials", () => {
|
||||
const result = await authorizeWithBasicCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "basic",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -47,7 +46,6 @@ describe("authorizeWithBasicCredentials", () => {
|
||||
const result = await authorizeWithBasicCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "wrong",
|
||||
credentialType: "basic",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -69,7 +67,6 @@ describe("authorizeWithBasicCredentials", () => {
|
||||
const result = await authorizeWithBasicCredentialsAsync(db, {
|
||||
name: "wrong",
|
||||
password: "test",
|
||||
credentialType: "basic",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -88,7 +85,6 @@ describe("authorizeWithBasicCredentials", () => {
|
||||
const result = await authorizeWithBasicCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "basic",
|
||||
});
|
||||
|
||||
// Assert
|
||||
|
||||
@@ -3,7 +3,7 @@ import { describe, expect, test, vi } from "vitest";
|
||||
|
||||
import type { Database } from "@homarr/db";
|
||||
import { and, createId, eq } from "@homarr/db";
|
||||
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
|
||||
import { groups, users } from "@homarr/db/schema/sqlite";
|
||||
import { createDb } from "@homarr/db/test";
|
||||
|
||||
import { authorizeWithLdapCredentialsAsync } from "../credentials/authorization/ldap-authorization";
|
||||
@@ -34,7 +34,6 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -57,7 +56,6 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -87,7 +85,6 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -120,7 +117,6 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
@@ -152,11 +148,11 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
const result = await authorizeWithLdapCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result.name).toBe("test");
|
||||
expect(result.groups).toHaveLength(0); // Groups are needed in signIn events callback
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: eq(users.name, "test"),
|
||||
});
|
||||
@@ -197,11 +193,11 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
const result = await authorizeWithLdapCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result.name).toBe("test");
|
||||
expect(result.groups).toHaveLength(0); // Groups are needed in signIn events callback
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: and(eq(users.name, "test"), eq(users.provider, "ldap")),
|
||||
});
|
||||
@@ -219,7 +215,8 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
expect(credentialsUser?.id).not.toBe(result.id);
|
||||
});
|
||||
|
||||
test("should authorize user with correct credentials and update name", async () => {
|
||||
// The name update occurs in the signIn event callback
|
||||
test("should authorize user with correct credentials and return updated name", async () => {
|
||||
// Arrange
|
||||
const spy = vi.spyOn(ldapClient, "LdapClient");
|
||||
spy.mockImplementation(
|
||||
@@ -251,11 +248,10 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
const result = await authorizeWithLdapCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({ id: userId, name: "test" });
|
||||
expect(result).toEqual({ id: userId, name: "test", groups: [] });
|
||||
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: eq(users.id, userId),
|
||||
@@ -263,12 +259,12 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
|
||||
expect(dbUser).toBeDefined();
|
||||
expect(dbUser?.id).toBe(userId);
|
||||
expect(dbUser?.name).toBe("test");
|
||||
expect(dbUser?.name).toBe("test-old");
|
||||
expect(dbUser?.email).toBe("test@gmail.com");
|
||||
expect(dbUser?.provider).toBe("ldap");
|
||||
});
|
||||
|
||||
test("should authorize user with correct credentials and add him to the groups that he is in LDAP but not in Homar", async () => {
|
||||
test("should authorize user with correct credentials and return his groups", async () => {
|
||||
// Arrange
|
||||
const spy = vi.spyOn(ldapClient, "LdapClient");
|
||||
spy.mockImplementation(
|
||||
@@ -311,83 +307,9 @@ describe("authorizeWithLdapCredentials", () => {
|
||||
const result = await authorizeWithLdapCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({ id: userId, name: "test" });
|
||||
|
||||
const dbGroupMembers = await db.query.groupMembers.findMany();
|
||||
expect(dbGroupMembers).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("should authorize user with correct credentials and remove him from groups he is in Homarr but not in LDAP", async () => {
|
||||
// Arrange
|
||||
const spy = vi.spyOn(ldapClient, "LdapClient");
|
||||
spy.mockImplementation(
|
||||
() =>
|
||||
({
|
||||
bindAsync: vi.fn(() => Promise.resolve()),
|
||||
searchAsync: vi.fn((argument: { options: { filter: string } }) =>
|
||||
argument.options.filter.includes("group")
|
||||
? Promise.resolve([
|
||||
{
|
||||
cn: "homarr_example",
|
||||
},
|
||||
])
|
||||
: Promise.resolve([
|
||||
{
|
||||
dn: "test55",
|
||||
mail: "test@gmail.com",
|
||||
},
|
||||
]),
|
||||
),
|
||||
disconnectAsync: vi.fn(),
|
||||
}) as unknown as ldapClient.LdapClient,
|
||||
);
|
||||
const db = createDb();
|
||||
const userId = createId();
|
||||
await db.insert(users).values({
|
||||
id: userId,
|
||||
name: "test",
|
||||
email: "test@gmail.com",
|
||||
provider: "ldap",
|
||||
});
|
||||
|
||||
const groupIds = [createId(), createId()] as const;
|
||||
await db.insert(groups).values([
|
||||
{
|
||||
id: groupIds[0],
|
||||
name: "homarr_example",
|
||||
},
|
||||
{
|
||||
id: groupIds[1],
|
||||
name: "homarr_no_longer_member",
|
||||
},
|
||||
]);
|
||||
await db.insert(groupMembers).values([
|
||||
{
|
||||
userId,
|
||||
groupId: groupIds[0],
|
||||
},
|
||||
{
|
||||
userId,
|
||||
groupId: groupIds[1],
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
const result = await authorizeWithLdapCredentialsAsync(db, {
|
||||
name: "test",
|
||||
password: "test",
|
||||
credentialType: "ldap",
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual({ id: userId, name: "test" });
|
||||
|
||||
const dbGroupMembers = await db.query.groupMembers.findMany();
|
||||
expect(dbGroupMembers).toHaveLength(1);
|
||||
expect(dbGroupMembers[0]?.groupId).toBe(groupIds[0]);
|
||||
expect(result).toEqual({ id: userId, name: "test", groups: ["homarr_example"] });
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,7 +5,8 @@ import type { Database } from "@homarr/db";
|
||||
|
||||
import { getCurrentUserPermissionsAsync } from "./callbacks";
|
||||
|
||||
export const sessionTokenCookieName = "next-auth.session-token";
|
||||
// Default of authjs
|
||||
export const sessionTokenCookieName = "authjs.session-token";
|
||||
|
||||
export const expireDateAfter = (seconds: number) => {
|
||||
return new Date(Date.now() + seconds * 1000);
|
||||
|
||||
67
packages/auth/test/adapter.spec.ts
Normal file
67
packages/auth/test/adapter.spec.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { describe, expect, test } from "vitest";
|
||||
|
||||
import { users } from "@homarr/db/schema/sqlite";
|
||||
import { createDb } from "@homarr/db/test";
|
||||
|
||||
import { createAdapter } from "../adapter";
|
||||
|
||||
describe("createAdapter should create drizzle adapter", () => {
|
||||
test.each([["credentials" as const], ["ldap" as const], ["oidc" as const]])(
|
||||
"createAdapter getUserByEmail should return user for provider %s when this provider provided",
|
||||
async (provider) => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
const adapter = createAdapter(db, provider);
|
||||
const email = "test@example.com";
|
||||
await db.insert(users).values({ id: "1", name: "test", email, provider });
|
||||
|
||||
// Act
|
||||
const user = await adapter.getUserByEmail?.(email);
|
||||
|
||||
// Assert
|
||||
expect(user).toEqual({
|
||||
id: "1",
|
||||
name: "test",
|
||||
email,
|
||||
emailVerified: null,
|
||||
image: null,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
test.each([
|
||||
["credentials", ["ldap", "oidc"]],
|
||||
["ldap", ["credentials", "oidc"]],
|
||||
["oidc", ["credentials", "ldap"]],
|
||||
] as const)(
|
||||
"createAdapter getUserByEmail should return null if only for other providers than %s exist",
|
||||
async (requestedProvider, existingProviders) => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
const adapter = createAdapter(db, requestedProvider);
|
||||
const email = "test@example.com";
|
||||
for (const provider of existingProviders) {
|
||||
await db.insert(users).values({ id: provider, name: `test-${provider}`, email, provider });
|
||||
}
|
||||
|
||||
// Act
|
||||
const user = await adapter.getUserByEmail?.(email);
|
||||
|
||||
// Assert
|
||||
expect(user).toBeNull();
|
||||
},
|
||||
);
|
||||
|
||||
test("createAdapter getUserByEmail should throw error if provider is unknown", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
const adapter = createAdapter(db, "unknown");
|
||||
const email = "test@example.com";
|
||||
|
||||
// Act
|
||||
const actAsync = async () => await adapter.getUserByEmail?.(email);
|
||||
|
||||
// Assert
|
||||
await expect(actAsync()).rejects.toThrow("Unable to get user by email for unknown provider");
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,5 @@
|
||||
/* eslint-disable @typescript-eslint/no-non-null-assertion */
|
||||
import { cookies } from "next/headers";
|
||||
import type { Adapter, AdapterUser } from "@auth/core/adapters";
|
||||
import type { Account } from "next-auth";
|
||||
import type { AdapterUser } from "@auth/core/adapters";
|
||||
import type { JWT } from "next-auth/jwt";
|
||||
import { describe, expect, test, vi } from "vitest";
|
||||
|
||||
@@ -9,7 +7,7 @@ import { groupMembers, groupPermissions, groups, users } from "@homarr/db/schema
|
||||
import { createDb } from "@homarr/db/test";
|
||||
import * as definitions from "@homarr/definitions";
|
||||
|
||||
import { createSessionCallback, createSignInCallback, getCurrentUserPermissionsAsync } from "../callbacks";
|
||||
import { createSessionCallback, getCurrentUserPermissionsAsync } from "../callbacks";
|
||||
|
||||
// This one is placed here because it's used in multiple tests and needs to be the same reference
|
||||
const setCookies = vi.fn();
|
||||
@@ -141,151 +139,3 @@ describe("session callback", () => {
|
||||
expect(result.user!.name).toEqual(user.name);
|
||||
});
|
||||
});
|
||||
|
||||
type AdapterSessionInput = Parameters<Exclude<Adapter["createSession"], undefined>>[0];
|
||||
|
||||
const createAdapter = () => {
|
||||
const result = {
|
||||
createSession: (input: AdapterSessionInput) => input,
|
||||
};
|
||||
|
||||
vi.spyOn(result, "createSession");
|
||||
return result;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/consistent-type-imports
|
||||
type SessionExport = typeof import("../session");
|
||||
const mockSessionToken = "e9ef3010-6981-4a81-b9d6-8495d09cf3b5";
|
||||
const mockSessionExpiry = new Date("2023-07-01");
|
||||
vi.mock("../env.mjs", () => {
|
||||
return {
|
||||
env: {
|
||||
AUTH_SESSION_EXPIRY_TIME: 60 * 60 * 24 * 7,
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock("../session", async (importOriginal) => {
|
||||
const mod = await importOriginal<SessionExport>();
|
||||
|
||||
const generateSessionToken = (): typeof mockSessionToken => mockSessionToken;
|
||||
const expireDateAfter = (_seconds: number) => mockSessionExpiry;
|
||||
|
||||
return {
|
||||
...mod,
|
||||
generateSessionToken,
|
||||
expireDateAfter,
|
||||
} satisfies SessionExport;
|
||||
});
|
||||
|
||||
describe("createSignInCallback", () => {
|
||||
test("should return true if not credentials request and set colorScheme & sessionToken cookie", async () => {
|
||||
// Arrange
|
||||
const isCredentialsRequest = false;
|
||||
const db = await prepareDbForSigninAsync("1");
|
||||
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
|
||||
|
||||
// Act
|
||||
const result = await signInCallback({
|
||||
user: { id: "1", emailVerified: new Date("2023-01-13") },
|
||||
account: {} as Account,
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("should return false if no adapter.createSession", async () => {
|
||||
// Arrange
|
||||
const isCredentialsRequest = true;
|
||||
const db = await prepareDbForSigninAsync("1");
|
||||
const signInCallback = createSignInCallback(
|
||||
// https://github.com/nextauthjs/next-auth/issues/6106
|
||||
{ createSession: undefined } as unknown as Adapter,
|
||||
db,
|
||||
isCredentialsRequest,
|
||||
);
|
||||
|
||||
// Act
|
||||
const result = await signInCallback({
|
||||
user: { id: "1", emailVerified: new Date("2023-01-13") },
|
||||
account: {} as Account,
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test("should call adapter.createSession with correct input", async () => {
|
||||
// Arrange
|
||||
const adapter = createAdapter();
|
||||
const isCredentialsRequest = true;
|
||||
const db = await prepareDbForSigninAsync("1");
|
||||
const signInCallback = createSignInCallback(adapter, db, isCredentialsRequest);
|
||||
const user = { id: "1", emailVerified: new Date("2023-01-13") };
|
||||
const account = {} as Account;
|
||||
// Act
|
||||
await signInCallback({ user, account });
|
||||
|
||||
// Assert
|
||||
expect(adapter.createSession).toHaveBeenCalledWith({
|
||||
sessionToken: mockSessionToken,
|
||||
userId: user.id,
|
||||
expires: mockSessionExpiry,
|
||||
});
|
||||
expect(cookies().set).toHaveBeenCalledWith("next-auth.session-token", mockSessionToken, {
|
||||
path: "/",
|
||||
expires: mockSessionExpiry,
|
||||
httpOnly: true,
|
||||
sameSite: "lax",
|
||||
secure: true,
|
||||
});
|
||||
});
|
||||
|
||||
test("should set colorScheme from db as cookie", async () => {
|
||||
// Arrange
|
||||
const isCredentialsRequest = true;
|
||||
const db = await prepareDbForSigninAsync("1");
|
||||
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
|
||||
|
||||
// Act
|
||||
const result = await signInCallback({
|
||||
user: { id: "1", emailVerified: new Date("2023-01-13") },
|
||||
account: {} as Account,
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
expect(cookies().set).toHaveBeenCalledWith(
|
||||
"homarr-color-scheme",
|
||||
"dark",
|
||||
expect.objectContaining({
|
||||
path: "/",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test("should return false if user not found in db", async () => {
|
||||
// Arrange
|
||||
const isCredentialsRequest = true;
|
||||
const db = await prepareDbForSigninAsync("other-id");
|
||||
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
|
||||
|
||||
// Act
|
||||
const result = await signInCallback({
|
||||
user: { id: "1", emailVerified: new Date("2023-01-13") },
|
||||
account: {} as Account,
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
const prepareDbForSigninAsync = async (userId: string) => {
|
||||
const db = createDb();
|
||||
await db.insert(users).values({
|
||||
id: userId,
|
||||
colorScheme: "dark",
|
||||
});
|
||||
return db;
|
||||
};
|
||||
|
||||
190
packages/auth/test/events.spec.ts
Normal file
190
packages/auth/test/events.spec.ts
Normal file
@@ -0,0 +1,190 @@
|
||||
import type { ResponseCookie } from "next/dist/compiled/@edge-runtime/cookies";
|
||||
import type { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
|
||||
import { cookies } from "next/headers";
|
||||
import { describe, expect, test, vi } from "vitest";
|
||||
|
||||
import { eq } from "@homarr/db";
|
||||
import type { Database } from "@homarr/db";
|
||||
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
|
||||
import { createDb } from "@homarr/db/test";
|
||||
|
||||
import { createSignInEventHandler } from "../events";
|
||||
|
||||
vi.mock("../env.mjs", () => {
|
||||
return {
|
||||
env: {
|
||||
AUTH_OIDC_GROUPS_ATTRIBUTE: "someRandomGroupsKey",
|
||||
},
|
||||
};
|
||||
});
|
||||
// eslint-disable-next-line @typescript-eslint/consistent-type-imports
|
||||
type HeadersExport = typeof import("next/headers");
|
||||
vi.mock("next/headers", async (importOriginal) => {
|
||||
const mod = await importOriginal<HeadersExport>();
|
||||
|
||||
const result = {
|
||||
set: (name: string, value: string, options: Partial<ResponseCookie>) => options as ResponseCookie,
|
||||
} as unknown as ReadonlyRequestCookies;
|
||||
|
||||
vi.spyOn(result, "set");
|
||||
|
||||
const cookies = () => result;
|
||||
|
||||
return { ...mod, cookies } satisfies HeadersExport;
|
||||
});
|
||||
|
||||
describe("createSignInEventHandler should create signInEventHandler", () => {
|
||||
describe("signInEventHandler should synchronize ldap groups", () => {
|
||||
test("should add missing group membership", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
await createGroupAsync(db);
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", name: "test", groups: ["test"] } as never,
|
||||
profile: undefined,
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
const dbGroupMembers = await db.query.groupMembers.findFirst({
|
||||
where: eq(groupMembers.userId, "1"),
|
||||
});
|
||||
expect(dbGroupMembers?.groupId).toBe("1");
|
||||
});
|
||||
test("should remove group membership", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
await createGroupAsync(db);
|
||||
await db.insert(groupMembers).values({
|
||||
userId: "1",
|
||||
groupId: "1",
|
||||
});
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", name: "test", groups: [] } as never,
|
||||
profile: undefined,
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
const dbGroupMembers = await db.query.groupMembers.findFirst({
|
||||
where: eq(groupMembers.userId, "1"),
|
||||
});
|
||||
expect(dbGroupMembers).toBeUndefined();
|
||||
});
|
||||
});
|
||||
describe("signInEventHandler should synchronize oidc groups", () => {
|
||||
test("should add missing group membership", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
await createGroupAsync(db);
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", name: "test" },
|
||||
profile: { preferred_username: "test", someRandomGroupsKey: ["test"] },
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
const dbGroupMembers = await db.query.groupMembers.findFirst({
|
||||
where: eq(groupMembers.userId, "1"),
|
||||
});
|
||||
expect(dbGroupMembers?.groupId).toBe("1");
|
||||
});
|
||||
test("should remove group membership", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
await createGroupAsync(db);
|
||||
await db.insert(groupMembers).values({
|
||||
userId: "1",
|
||||
groupId: "1",
|
||||
});
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", name: "test" },
|
||||
profile: { preferred_username: "test", someRandomGroupsKey: [] },
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
const dbGroupMembers = await db.query.groupMembers.findFirst({
|
||||
where: eq(groupMembers.userId, "1"),
|
||||
});
|
||||
expect(dbGroupMembers).toBeUndefined();
|
||||
});
|
||||
});
|
||||
test.each([
|
||||
["ldap" as const, { name: "test-new" }, undefined],
|
||||
["oidc" as const, { name: "test" }, { preferred_username: "test-new" }],
|
||||
["oidc" as const, { name: "test" }, { preferred_username: "test@example.com", name: "test-new" }],
|
||||
])("signInEventHandler should update username for %s provider", async (_provider, user, profile) => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", ...user },
|
||||
profile,
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
const dbUser = await db.query.users.findFirst({
|
||||
where: eq(users.id, "1"),
|
||||
columns: {
|
||||
name: true,
|
||||
},
|
||||
});
|
||||
expect(dbUser?.name).toBe("test-new");
|
||||
});
|
||||
test("signInEventHandler should set homarr-color-scheme cookie", async () => {
|
||||
// Arrange
|
||||
const db = createDb();
|
||||
await createUserAsync(db);
|
||||
const eventHandler = createSignInEventHandler(db);
|
||||
|
||||
// Act
|
||||
await eventHandler?.({
|
||||
user: { id: "1", name: "test" },
|
||||
profile: undefined,
|
||||
account: null,
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(cookies().set).toHaveBeenCalledWith(
|
||||
"homarr-color-scheme",
|
||||
"dark",
|
||||
expect.objectContaining({
|
||||
path: "/",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
const createUserAsync = async (db: Database) =>
|
||||
await db.insert(users).values({
|
||||
id: "1",
|
||||
name: "test",
|
||||
colorScheme: "dark",
|
||||
});
|
||||
|
||||
const createGroupAsync = async (db: Database) =>
|
||||
await db.insert(groups).values({
|
||||
id: "1",
|
||||
name: "test",
|
||||
});
|
||||
@@ -56,7 +56,6 @@ const initUserSchema = createUserSchema;
|
||||
const signInSchema = z.object({
|
||||
name: z.string().min(1),
|
||||
password: z.string().min(1),
|
||||
credentialType: z.enum(["basic", "ldap"]),
|
||||
});
|
||||
|
||||
const registrationSchema = z
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
"AUTH_OIDC_CLIENT_SECRET",
|
||||
"AUTH_OIDC_ISSUER",
|
||||
"AUTH_OIDC_SCOPE_OVERWRITE",
|
||||
"AUTH_OIDC_GROUPS_ATTRIBUTE",
|
||||
"AUTH_LDAP_USERNAME_ATTRIBUTE",
|
||||
"AUTH_LDAP_USER_MAIL_ATTRIBUTE",
|
||||
"AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG",
|
||||
|
||||
Reference in New Issue
Block a user