refactor: replace signIn callback with signIn event, adjust getUserByEmail in adapter to check provider (#1223)

* refactor: replace signIn callback with signIn event, adjust getUserByEmail in adapter to check provider

* test: adjusting tests for adapter and events

* docs: add comments for unknown auth provider

* fix: missing dayjs import
This commit is contained in:
Meier Lukas
2024-10-07 21:13:15 +02:00
committed by GitHub
parent 4d51e3b344
commit eb21628ee4
19 changed files with 521 additions and 423 deletions

View File

@@ -12,8 +12,7 @@ import type { useForm } from "@homarr/form";
import { useZodForm } from "@homarr/form";
import { showErrorNotification, showSuccessNotification } from "@homarr/notifications";
import { useScopedI18n } from "@homarr/translation/client";
import type { z } from "@homarr/validation";
import { validation } from "@homarr/validation";
import { validation, z } from "@homarr/validation";
interface LoginFormProps {
providers: string[];
@@ -22,15 +21,17 @@ interface LoginFormProps {
callbackUrl: string;
}
const extendedValidation = validation.user.signIn.extend({ provider: z.enum(["credentials", "ldap"]) });
export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, callbackUrl }: LoginFormProps) => {
const t = useScopedI18n("user");
const router = useRouter();
const [isPending, setIsPending] = useState(false);
const form = useZodForm(validation.user.signIn, {
const form = useZodForm(extendedValidation, {
initialValues: {
name: "",
password: "",
credentialType: "basic",
provider: "credentials",
},
});
@@ -95,14 +96,14 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
<Stack gap="lg">
{credentialInputsVisible && (
<>
<form onSubmit={form.onSubmit((credentials) => void signInAsync("credentials", credentials))}>
<form onSubmit={form.onSubmit((credentials) => void signInAsync(credentials.provider, credentials))}>
<Stack gap="lg">
<TextInput label={t("field.username.label")} {...form.getInputProps("name")} />
<PasswordInput label={t("field.password.label")} {...form.getInputProps("password")} />
{providers.includes("credentials") && (
<Stack gap="sm">
<SubmitButton isPending={isPending} form={form} credentialType="basic">
<SubmitButton isPending={isPending} form={form} provider="credentials">
{t("action.login.label")}
</SubmitButton>
<PasswordForgottenCollapse username={form.values.name} />
@@ -110,7 +111,7 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
)}
{providers.includes("ldap") && (
<SubmitButton isPending={isPending} form={form} credentialType="ldap">
<SubmitButton isPending={isPending} form={form} provider="ldap">
{t("action.login.labelWith", { provider: "LDAP" })}
</SubmitButton>
)}
@@ -133,18 +134,18 @@ export const LoginForm = ({ providers, oidcClientName, isOidcAutoLoginEnabled, c
interface SubmitButtonProps {
isPending: boolean;
form: ReturnType<typeof useForm<FormType, (values: FormType) => FormType>>;
credentialType: "basic" | "ldap";
provider: "credentials" | "ldap";
}
const SubmitButton = ({ isPending, form, credentialType, children }: PropsWithChildren<SubmitButtonProps>) => {
const isCurrentProviderActive = form.getValues().credentialType === credentialType;
const SubmitButton = ({ isPending, form, provider, children }: PropsWithChildren<SubmitButtonProps>) => {
const isCurrentProviderActive = form.getValues().provider === provider;
return (
<Button
type="submit"
name={credentialType}
name={provider}
fullWidth
onClick={() => form.setFieldValue("credentialType", credentialType)}
onClick={() => form.setFieldValue("provider", provider)}
loading={isPending && isCurrentProviderActive}
disabled={isPending && !isCurrentProviderActive}
>
@@ -181,4 +182,4 @@ const PasswordForgottenCollapse = ({ username }: PasswordForgottenCollapseProps)
);
};
type FormType = z.infer<typeof validation.user.signIn>;
type FormType = z.infer<typeof extendedValidation>;

View File

@@ -1,17 +1,37 @@
import { NextRequest } from "next/server";
import { createHandlers } from "@homarr/auth";
import type { SupportedAuthProvider } from "@homarr/definitions";
import { logger } from "@homarr/log";
export const GET = async (req: NextRequest) => {
return await createHandlers(isCredentialsRequest(req)).handlers.GET(reqWithTrustedOrigin(req));
return await createHandlers(extractProvider(req)).handlers.GET(reqWithTrustedOrigin(req));
};
export const POST = async (req: NextRequest) => {
return await createHandlers(isCredentialsRequest(req)).handlers.POST(reqWithTrustedOrigin(req));
return await createHandlers(extractProvider(req)).handlers.POST(reqWithTrustedOrigin(req));
};
const isCredentialsRequest = (req: NextRequest) => {
return req.url.includes("credentials") && req.method === "POST";
/**
* This method extracts the used provider from the url and allows us to override the getUserByEmail method in the adapter.
* @param req request containing the url
* @returns the provider or "unknown" if the provider could not be extracted
*/
const extractProvider = (req: NextRequest): SupportedAuthProvider | "unknown" => {
const url = new URL(req.url);
if (url.pathname.includes("oidc")) {
return "oidc";
}
if (url.pathname.includes("credentials")) {
return "credentials";
}
if (url.pathname.includes("ldap")) {
return "ldap";
}
return "unknown";
};
/**

View File

@@ -1,5 +1,45 @@
import type { Adapter } from "@auth/core/adapters";
import { DrizzleAdapter } from "@auth/drizzle-adapter";
import { db } from "@homarr/db";
import type { Database } from "@homarr/db";
import { and, eq } from "@homarr/db";
import { accounts, users } from "@homarr/db/schema/sqlite";
import type { SupportedAuthProvider } from "@homarr/definitions";
export const adapter = DrizzleAdapter(db);
export const createAdapter = (db: Database, provider: SupportedAuthProvider | "unknown"): Adapter => {
const drizzleAdapter = DrizzleAdapter(db, { usersTable: users, accountsTable: accounts });
return {
...drizzleAdapter,
// We override the default implementation as we want to have a provider
// flag in the user instead of the account to not intermingle users from different providers
// eslint-disable-next-line no-restricted-syntax
getUserByEmail: async (email) => {
if (provider === "unknown") {
throw new Error("Unable to get user by email for unknown provider");
}
const user = await db.query.users.findFirst({
where: and(eq(users.email, email), eq(users.provider, provider)),
columns: {
id: true,
name: true,
email: true,
emailVerified: true,
image: true,
},
});
if (!user) {
return null;
}
return {
...user,
// We allow null as email for credentials provider
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
email: user.email!,
};
},
};
};

View File

@@ -1,5 +1,3 @@
import { cookies } from "next/headers";
import type { Adapter } from "@auth/core/adapters";
import dayjs from "dayjs";
import type { NextAuthConfig } from "next-auth";
@@ -9,9 +7,6 @@ import { eq, inArray } from "@homarr/db";
import { groupMembers, groupPermissions, users } from "@homarr/db/schema/sqlite";
import { getPermissionsWithChildren } from "@homarr/definitions";
import { env } from "./env.mjs";
import { expireDateAfter, generateSessionToken, sessionTokenCookieName } from "./session";
export const getCurrentUserPermissionsAsync = async (db: Database, userId: string) => {
const dbGroupMembers = await db.query.groupMembers.findMany({
where: eq(groupMembers.userId, userId),
@@ -68,51 +63,6 @@ export const createSessionCallback = (db: Database): NextAuthCallbackOf<"session
};
};
export const createSignInCallback =
(adapter: Adapter, db: Database, isCredentialsRequest: boolean): NextAuthCallbackOf<"signIn"> =>
async ({ user }) => {
if (!isCredentialsRequest) return true;
// https://github.com/nextauthjs/next-auth/issues/6106
if (!adapter.createSession || !user.id) {
return false;
}
const sessionToken = generateSessionToken();
const sessionExpires = expireDateAfter(env.AUTH_SESSION_EXPIRY_TIME);
await adapter.createSession({
sessionToken,
userId: user.id,
expires: sessionExpires,
});
cookies().set(sessionTokenCookieName, sessionToken, {
path: "/",
expires: sessionExpires,
httpOnly: true,
sameSite: "lax",
secure: true,
});
const dbUser = await db.query.users.findFirst({
where: eq(users.id, user.id),
columns: {
colorScheme: true,
},
});
if (!dbUser) return false;
// We use a cookie as localStorage is not shared with server (otherwise flickering would occur)
cookies().set("homarr-color-scheme", dbUser.colorScheme, {
path: "/",
expires: dayjs().add(1, "year").toDate(),
});
return true;
};
type NextAuthCallbackRecord = Exclude<NextAuthConfig["callbacks"], undefined>;
export type NextAuthCallbackOf<TKey extends keyof NextAuthCallbackRecord> = Exclude<
NextAuthCallbackRecord[TKey],

View File

@@ -4,18 +4,21 @@ import NextAuth from "next-auth";
import Credentials from "next-auth/providers/credentials";
import { db } from "@homarr/db";
import type { SupportedAuthProvider } from "@homarr/definitions";
import { adapter } from "./adapter";
import { createSessionCallback, createSignInCallback } from "./callbacks";
import { createAdapter } from "./adapter";
import { createSessionCallback } from "./callbacks";
import { env } from "./env.mjs";
import { createCredentialsConfiguration } from "./providers/credentials/credentials-provider";
import { createSignInEventHandler } from "./events";
import { createCredentialsConfiguration, createLdapConfiguration } from "./providers/credentials/credentials-provider";
import { EmptyNextAuthProvider } from "./providers/empty/empty-provider";
import { filterProviders } from "./providers/filter-providers";
import { OidcProvider } from "./providers/oidc/oidc-provider";
import { createRedirectUri } from "./redirect";
import { sessionTokenCookieName } from "./session";
import { generateSessionToken, sessionTokenCookieName } from "./session";
export const createConfiguration = (isCredentialsRequest: boolean, headers: ReadonlyHeaders | null) =>
// See why it's unknown in the [...nextauth]/route.ts file
export const createConfiguration = (provider: SupportedAuthProvider | "unknown", headers: ReadonlyHeaders | null) =>
NextAuth({
logger: {
error: (code, ...message) => {
@@ -30,21 +33,25 @@ export const createConfiguration = (isCredentialsRequest: boolean, headers: Read
},
},
trustHost: true,
adapter,
adapter: createAdapter(db, provider),
providers: filterProviders([
Credentials(createCredentialsConfiguration(db)),
Credentials(createLdapConfiguration(db)),
EmptyNextAuthProvider(),
OidcProvider(headers),
]),
callbacks: {
session: createSessionCallback(db),
signIn: createSignInCallback(adapter, db, isCredentialsRequest),
},
events: {
signIn: createSignInEventHandler(db),
},
redirectProxyUrl: createRedirectUri(headers, "/api/auth"),
secret: "secret-is-not-defined-yet", // TODO: This should be added later
session: {
strategy: "database",
maxAge: env.AUTH_SESSION_EXPIRY_TIME,
generateSessionToken,
},
pages: {
signIn: "/auth/login",

View File

@@ -74,6 +74,7 @@ export const env = createEnv({
AUTH_OIDC_CLIENT_NAME: z.string().min(1).default("OIDC"),
AUTH_OIDC_AUTO_LOGIN: booleanSchema,
AUTH_OIDC_SCOPE_OVERWRITE: z.string().min(1).default("openid email profile groups"),
AUTH_OIDC_GROUPS_ATTRIBUTE: z.string().default("groups"), // Is used in the signIn event to assign the correct groups, key is from object of decoded id_token
}
: {}),
...(authProviders.includes("ldap")
@@ -113,6 +114,7 @@ export const env = createEnv({
AUTH_OIDC_CLIENT_SECRET: process.env.AUTH_OIDC_CLIENT_SECRET,
AUTH_OIDC_ISSUER: process.env.AUTH_OIDC_ISSUER,
AUTH_OIDC_SCOPE_OVERWRITE: process.env.AUTH_OIDC_SCOPE_OVERWRITE,
AUTH_OIDC_GROUPS_ATTRIBUTE: process.env.AUTH_OIDC_GROUPS_ATTRIBUTE,
AUTH_LDAP_USERNAME_ATTRIBUTE: process.env.AUTH_LDAP_USERNAME_ATTRIBUTE,
AUTH_LDAP_USER_MAIL_ATTRIBUTE: process.env.AUTH_LDAP_USER_MAIL_ATTRIBUTE,
AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG: process.env.AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG,

131
packages/auth/events.ts Normal file
View File

@@ -0,0 +1,131 @@
import { cookies } from "next/headers";
import dayjs from "dayjs";
import type { NextAuthConfig } from "next-auth";
import { and, eq, inArray } from "@homarr/db";
import type { Database } from "@homarr/db";
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log";
import { env } from "./env.mjs";
export const createSignInEventHandler = (db: Database): Exclude<NextAuthConfig["events"], undefined>["signIn"] => {
return async ({ user, profile }) => {
if (!user.id) throw new Error("User ID is missing");
const dbUser = await db.query.users.findFirst({
where: eq(users.id, user.id),
columns: {
name: true,
colorScheme: true,
},
});
if (!dbUser) throw new Error("User not found");
const groupsKey = env.AUTH_OIDC_GROUPS_ATTRIBUTE;
// Groups from oidc provider are provided from the profile, it's not typed.
if (profile && groupsKey in profile && Array.isArray(profile[groupsKey])) {
await synchronizeGroupsWithExternalForUserAsync(db, user.id, profile[groupsKey] as string[]);
}
// In ldap-authroization we return the groups from ldap, it's not typed.
if ("groups" in user && Array.isArray(user.groups)) {
await synchronizeGroupsWithExternalForUserAsync(db, user.id, user.groups as string[]);
}
if (dbUser.name !== user.name) {
await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
logger.info(
`Username for user of credentials provider has changed. user=${user.id} old=${dbUser.name} new=${user.name}`,
);
}
const profileUsername = profile?.preferred_username?.includes("@") ? profile.name : profile?.preferred_username;
if (profileUsername && dbUser.name !== profileUsername) {
await db.update(users).set({ name: profileUsername }).where(eq(users.id, user.id));
logger.info(
`Username for user of oidc provider has changed. user=${user.id} old='${dbUser.name}' new='${profileUsername}'`,
);
}
// We use a cookie as localStorage is not shared with server (otherwise flickering would occur)
cookies().set("homarr-color-scheme", dbUser.colorScheme, {
path: "/",
expires: dayjs().add(1, "year").toDate(),
});
};
};
const synchronizeGroupsWithExternalForUserAsync = async (db: Database, userId: string, externalGroups: string[]) => {
const dbGroupMembers = await db.query.groupMembers.findMany({
where: eq(groupMembers.userId, userId),
with: {
group: { columns: { name: true } },
},
});
/**
* The below groups are those groups the user is part of in the external system, but not in Homarr.
* So he has to be added to those groups.
*/
const missingExternalGroupsForUser = externalGroups.filter(
(externalGroup) => !dbGroupMembers.some(({ group }) => group.name === externalGroup),
);
if (missingExternalGroupsForUser.length > 0) {
logger.debug(
`Homarr does not have the user in certain groups. user=${userId} count=${missingExternalGroupsForUser.length}`,
);
const groupIds = await db.query.groups.findMany({
columns: {
id: true,
},
where: inArray(groups.name, missingExternalGroupsForUser),
});
logger.debug(`Homarr has found groups in the database user is not in. user=${userId} count=${groupIds.length}`);
if (groupIds.length > 0) {
await db.insert(groupMembers).values(
groupIds.map((group) => ({
userId,
groupId: group.id,
})),
);
logger.info(`Added user to groups successfully. user=${userId} count=${groupIds.length}`);
} else {
logger.debug(`User is already in all groups of Homarr. user=${userId}`);
}
}
/**
* The below groups are those groups the user is part of in Homarr, but not in the external system.
* So he has to be removed from those groups.
*/
const groupsUserIsNoLongerMemberOfExternally = dbGroupMembers.filter(
({ group }) => !externalGroups.includes(group.name),
);
if (groupsUserIsNoLongerMemberOfExternally.length > 0) {
logger.debug(
`Homarr has the user in certain groups that LDAP does not have. user=${userId} count=${groupsUserIsNoLongerMemberOfExternally.length}`,
);
await db.delete(groupMembers).where(
and(
eq(groupMembers.userId, userId),
inArray(
groupMembers.groupId,
groupsUserIsNoLongerMemberOfExternally.map(({ groupId }) => groupId),
),
),
);
logger.info(
`Removed user from groups successfully. user=${userId} count=${groupsUserIsNoLongerMemberOfExternally.length}`,
);
}
};

View File

@@ -1,7 +1,7 @@
import { headers } from "next/headers";
import type { DefaultSession } from "@auth/core/types";
import type { ColorScheme, GroupPermissionKey } from "@homarr/definitions";
import type { ColorScheme, GroupPermissionKey, SupportedAuthProvider } from "@homarr/definitions";
import { createConfiguration } from "./configuration";
@@ -19,6 +19,7 @@ declare module "next-auth" {
export * from "./security";
export const createHandlers = (isCredentialsRequest: boolean) => createConfiguration(isCredentialsRequest, headers());
// See why it's unknown in the [...nextauth]/route.ts file
export const createHandlers = (provider: SupportedAuthProvider | "unknown") => createConfiguration(provider, headers());
export { getSessionFromTokenAsync as getSessionFromToken, sessionTokenCookieName } from "./session";

View File

@@ -2,7 +2,7 @@ import { cache } from "react";
import { createConfiguration } from "./configuration";
const { auth: defaultAuth } = createConfiguration(false, null);
const { auth: defaultAuth } = createConfiguration("unknown", null);
/**
* This is the main way to get session data for your RSCs.

View File

@@ -1,8 +1,8 @@
import { CredentialsSignin } from "@auth/core/errors";
import type { Database, InferInsertModel } from "@homarr/db";
import { and, createId, eq, inArray } from "@homarr/db";
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
import { and, createId, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log";
import type { validation } from "@homarr/validation";
import { z } from "@homarr/validation";
@@ -99,18 +99,6 @@ export const authorizeWithLdapCredentialsAsync = async (
emailVerified: true,
provider: true,
},
with: {
groups: {
with: {
group: {
columns: {
id: true,
name: true,
},
},
},
},
},
where: and(eq(users.email, mailResult.data), eq(users.provider, "ldap")),
});
@@ -128,79 +116,16 @@ export const authorizeWithLdapCredentialsAsync = async (
await db.insert(users).values(insertUser);
user = {
...insertUser,
groups: [],
};
user = insertUser;
logger.info(`User ${credentials.name} created successfully.`);
}
if (user.name !== credentials.name) {
logger.warn(`User ${credentials.name} found in the database but with different name. Updating...`);
user.name = credentials.name;
await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
logger.info(`User ${credentials.name} updated successfully.`);
}
const ldapGroupsUserIsNotIn = userGroups.filter(
(group) => !user.groups.some((userGroup) => userGroup.group.name === group),
);
if (ldapGroupsUserIsNotIn.length > 0) {
logger.debug(
`Homarr does not have the user in certain groups. user=${user.name} count=${ldapGroupsUserIsNotIn.length}`,
);
const groupIds = await db.query.groups.findMany({
columns: {
id: true,
},
where: inArray(groups.name, ldapGroupsUserIsNotIn),
});
logger.debug(`Homarr has found groups in the database user is not in. user=${user.name} count=${groupIds.length}`);
if (groupIds.length > 0) {
await db.insert(groupMembers).values(
groupIds.map((group) => ({
userId: user.id,
groupId: group.id,
})),
);
logger.info(`Added user to groups successfully. user=${user.name} count=${groupIds.length}`);
} else {
logger.debug(`User is already in all groups of Homarr. user=${user.name}`);
}
}
const homarrGroupsUserIsNotIn = user.groups.filter((userGroup) => !userGroups.includes(userGroup.group.name));
if (homarrGroupsUserIsNotIn.length > 0) {
logger.debug(
`Homarr has the user in certain groups that LDAP does not have. user=${user.name} count=${homarrGroupsUserIsNotIn.length}`,
);
await db.delete(groupMembers).where(
and(
eq(groupMembers.userId, user.id),
inArray(
groupMembers.groupId,
homarrGroupsUserIsNotIn.map(({ groupId }) => groupId),
),
),
);
logger.info(`Removed user from groups successfully. user=${user.name} count=${homarrGroupsUserIsNotIn.length}`);
}
return {
id: user.id,
name: user.name,
name: credentials.name,
// Groups is used in events.ts to synchronize groups with external systems
groups: userGroups,
};
};

View File

@@ -10,30 +10,25 @@ type CredentialsConfiguration = Parameters<typeof Credentials>[0];
export const createCredentialsConfiguration = (db: Database) =>
({
id: "credentials",
type: "credentials",
name: "Credentials",
credentials: {
name: {
label: "Username",
type: "text",
},
password: {
label: "Password",
type: "password",
},
isLdap: {
label: "LDAP",
type: "checkbox",
},
},
// eslint-disable-next-line no-restricted-syntax
async authorize(credentials) {
const data = await validation.user.signIn.parseAsync(credentials);
if (data.credentialType === "ldap") {
return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
}
return await authorizeWithBasicCredentialsAsync(db, data);
},
}) satisfies CredentialsConfiguration;
export const createLdapConfiguration = (db: Database) =>
({
id: "ldap",
type: "credentials",
name: "Ldap",
// eslint-disable-next-line no-restricted-syntax
async authorize(credentials) {
const data = await validation.user.signIn.parseAsync(credentials);
return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
},
}) satisfies CredentialsConfiguration;

View File

@@ -25,7 +25,6 @@ describe("authorizeWithBasicCredentials", () => {
const result = await authorizeWithBasicCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "basic",
});
// Assert
@@ -47,7 +46,6 @@ describe("authorizeWithBasicCredentials", () => {
const result = await authorizeWithBasicCredentialsAsync(db, {
name: "test",
password: "wrong",
credentialType: "basic",
});
// Assert
@@ -69,7 +67,6 @@ describe("authorizeWithBasicCredentials", () => {
const result = await authorizeWithBasicCredentialsAsync(db, {
name: "wrong",
password: "test",
credentialType: "basic",
});
// Assert
@@ -88,7 +85,6 @@ describe("authorizeWithBasicCredentials", () => {
const result = await authorizeWithBasicCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "basic",
});
// Assert

View File

@@ -3,7 +3,7 @@ import { describe, expect, test, vi } from "vitest";
import type { Database } from "@homarr/db";
import { and, createId, eq } from "@homarr/db";
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
import { groups, users } from "@homarr/db/schema/sqlite";
import { createDb } from "@homarr/db/test";
import { authorizeWithLdapCredentialsAsync } from "../credentials/authorization/ldap-authorization";
@@ -34,7 +34,6 @@ describe("authorizeWithLdapCredentials", () => {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
@@ -57,7 +56,6 @@ describe("authorizeWithLdapCredentials", () => {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
@@ -87,7 +85,6 @@ describe("authorizeWithLdapCredentials", () => {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
@@ -120,7 +117,6 @@ describe("authorizeWithLdapCredentials", () => {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
@@ -152,11 +148,11 @@ describe("authorizeWithLdapCredentials", () => {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result.name).toBe("test");
expect(result.groups).toHaveLength(0); // Groups are needed in signIn events callback
const dbUser = await db.query.users.findFirst({
where: eq(users.name, "test"),
});
@@ -197,11 +193,11 @@ describe("authorizeWithLdapCredentials", () => {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result.name).toBe("test");
expect(result.groups).toHaveLength(0); // Groups are needed in signIn events callback
const dbUser = await db.query.users.findFirst({
where: and(eq(users.name, "test"), eq(users.provider, "ldap")),
});
@@ -219,7 +215,8 @@ describe("authorizeWithLdapCredentials", () => {
expect(credentialsUser?.id).not.toBe(result.id);
});
test("should authorize user with correct credentials and update name", async () => {
// The name update occurs in the signIn event callback
test("should authorize user with correct credentials and return updated name", async () => {
// Arrange
const spy = vi.spyOn(ldapClient, "LdapClient");
spy.mockImplementation(
@@ -251,11 +248,10 @@ describe("authorizeWithLdapCredentials", () => {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result).toEqual({ id: userId, name: "test" });
expect(result).toEqual({ id: userId, name: "test", groups: [] });
const dbUser = await db.query.users.findFirst({
where: eq(users.id, userId),
@@ -263,12 +259,12 @@ describe("authorizeWithLdapCredentials", () => {
expect(dbUser).toBeDefined();
expect(dbUser?.id).toBe(userId);
expect(dbUser?.name).toBe("test");
expect(dbUser?.name).toBe("test-old");
expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.provider).toBe("ldap");
});
test("should authorize user with correct credentials and add him to the groups that he is in LDAP but not in Homar", async () => {
test("should authorize user with correct credentials and return his groups", async () => {
// Arrange
const spy = vi.spyOn(ldapClient, "LdapClient");
spy.mockImplementation(
@@ -311,83 +307,9 @@ describe("authorizeWithLdapCredentials", () => {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result).toEqual({ id: userId, name: "test" });
const dbGroupMembers = await db.query.groupMembers.findMany();
expect(dbGroupMembers).toHaveLength(1);
});
test("should authorize user with correct credentials and remove him from groups he is in Homarr but not in LDAP", async () => {
// Arrange
const spy = vi.spyOn(ldapClient, "LdapClient");
spy.mockImplementation(
() =>
({
bindAsync: vi.fn(() => Promise.resolve()),
searchAsync: vi.fn((argument: { options: { filter: string } }) =>
argument.options.filter.includes("group")
? Promise.resolve([
{
cn: "homarr_example",
},
])
: Promise.resolve([
{
dn: "test55",
mail: "test@gmail.com",
},
]),
),
disconnectAsync: vi.fn(),
}) as unknown as ldapClient.LdapClient,
);
const db = createDb();
const userId = createId();
await db.insert(users).values({
id: userId,
name: "test",
email: "test@gmail.com",
provider: "ldap",
});
const groupIds = [createId(), createId()] as const;
await db.insert(groups).values([
{
id: groupIds[0],
name: "homarr_example",
},
{
id: groupIds[1],
name: "homarr_no_longer_member",
},
]);
await db.insert(groupMembers).values([
{
userId,
groupId: groupIds[0],
},
{
userId,
groupId: groupIds[1],
},
]);
// Act
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result).toEqual({ id: userId, name: "test" });
const dbGroupMembers = await db.query.groupMembers.findMany();
expect(dbGroupMembers).toHaveLength(1);
expect(dbGroupMembers[0]?.groupId).toBe(groupIds[0]);
expect(result).toEqual({ id: userId, name: "test", groups: ["homarr_example"] });
});
});

View File

@@ -5,7 +5,8 @@ import type { Database } from "@homarr/db";
import { getCurrentUserPermissionsAsync } from "./callbacks";
export const sessionTokenCookieName = "next-auth.session-token";
// Default of authjs
export const sessionTokenCookieName = "authjs.session-token";
export const expireDateAfter = (seconds: number) => {
return new Date(Date.now() + seconds * 1000);

View File

@@ -0,0 +1,67 @@
import { describe, expect, test } from "vitest";
import { users } from "@homarr/db/schema/sqlite";
import { createDb } from "@homarr/db/test";
import { createAdapter } from "../adapter";
describe("createAdapter should create drizzle adapter", () => {
test.each([["credentials" as const], ["ldap" as const], ["oidc" as const]])(
"createAdapter getUserByEmail should return user for provider %s when this provider provided",
async (provider) => {
// Arrange
const db = createDb();
const adapter = createAdapter(db, provider);
const email = "test@example.com";
await db.insert(users).values({ id: "1", name: "test", email, provider });
// Act
const user = await adapter.getUserByEmail?.(email);
// Assert
expect(user).toEqual({
id: "1",
name: "test",
email,
emailVerified: null,
image: null,
});
},
);
test.each([
["credentials", ["ldap", "oidc"]],
["ldap", ["credentials", "oidc"]],
["oidc", ["credentials", "ldap"]],
] as const)(
"createAdapter getUserByEmail should return null if only for other providers than %s exist",
async (requestedProvider, existingProviders) => {
// Arrange
const db = createDb();
const adapter = createAdapter(db, requestedProvider);
const email = "test@example.com";
for (const provider of existingProviders) {
await db.insert(users).values({ id: provider, name: `test-${provider}`, email, provider });
}
// Act
const user = await adapter.getUserByEmail?.(email);
// Assert
expect(user).toBeNull();
},
);
test("createAdapter getUserByEmail should throw error if provider is unknown", async () => {
// Arrange
const db = createDb();
const adapter = createAdapter(db, "unknown");
const email = "test@example.com";
// Act
const actAsync = async () => await adapter.getUserByEmail?.(email);
// Assert
await expect(actAsync()).rejects.toThrow("Unable to get user by email for unknown provider");
});
});

View File

@@ -1,7 +1,5 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { cookies } from "next/headers";
import type { Adapter, AdapterUser } from "@auth/core/adapters";
import type { Account } from "next-auth";
import type { AdapterUser } from "@auth/core/adapters";
import type { JWT } from "next-auth/jwt";
import { describe, expect, test, vi } from "vitest";
@@ -9,7 +7,7 @@ import { groupMembers, groupPermissions, groups, users } from "@homarr/db/schema
import { createDb } from "@homarr/db/test";
import * as definitions from "@homarr/definitions";
import { createSessionCallback, createSignInCallback, getCurrentUserPermissionsAsync } from "../callbacks";
import { createSessionCallback, getCurrentUserPermissionsAsync } from "../callbacks";
// This one is placed here because it's used in multiple tests and needs to be the same reference
const setCookies = vi.fn();
@@ -141,151 +139,3 @@ describe("session callback", () => {
expect(result.user!.name).toEqual(user.name);
});
});
type AdapterSessionInput = Parameters<Exclude<Adapter["createSession"], undefined>>[0];
const createAdapter = () => {
const result = {
createSession: (input: AdapterSessionInput) => input,
};
vi.spyOn(result, "createSession");
return result;
};
// eslint-disable-next-line @typescript-eslint/consistent-type-imports
type SessionExport = typeof import("../session");
const mockSessionToken = "e9ef3010-6981-4a81-b9d6-8495d09cf3b5";
const mockSessionExpiry = new Date("2023-07-01");
vi.mock("../env.mjs", () => {
return {
env: {
AUTH_SESSION_EXPIRY_TIME: 60 * 60 * 24 * 7,
},
};
});
vi.mock("../session", async (importOriginal) => {
const mod = await importOriginal<SessionExport>();
const generateSessionToken = (): typeof mockSessionToken => mockSessionToken;
const expireDateAfter = (_seconds: number) => mockSessionExpiry;
return {
...mod,
generateSessionToken,
expireDateAfter,
} satisfies SessionExport;
});
describe("createSignInCallback", () => {
test("should return true if not credentials request and set colorScheme & sessionToken cookie", async () => {
// Arrange
const isCredentialsRequest = false;
const db = await prepareDbForSigninAsync("1");
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
// Act
const result = await signInCallback({
user: { id: "1", emailVerified: new Date("2023-01-13") },
account: {} as Account,
});
// Assert
expect(result).toBe(true);
});
test("should return false if no adapter.createSession", async () => {
// Arrange
const isCredentialsRequest = true;
const db = await prepareDbForSigninAsync("1");
const signInCallback = createSignInCallback(
// https://github.com/nextauthjs/next-auth/issues/6106
{ createSession: undefined } as unknown as Adapter,
db,
isCredentialsRequest,
);
// Act
const result = await signInCallback({
user: { id: "1", emailVerified: new Date("2023-01-13") },
account: {} as Account,
});
// Assert
expect(result).toBe(false);
});
test("should call adapter.createSession with correct input", async () => {
// Arrange
const adapter = createAdapter();
const isCredentialsRequest = true;
const db = await prepareDbForSigninAsync("1");
const signInCallback = createSignInCallback(adapter, db, isCredentialsRequest);
const user = { id: "1", emailVerified: new Date("2023-01-13") };
const account = {} as Account;
// Act
await signInCallback({ user, account });
// Assert
expect(adapter.createSession).toHaveBeenCalledWith({
sessionToken: mockSessionToken,
userId: user.id,
expires: mockSessionExpiry,
});
expect(cookies().set).toHaveBeenCalledWith("next-auth.session-token", mockSessionToken, {
path: "/",
expires: mockSessionExpiry,
httpOnly: true,
sameSite: "lax",
secure: true,
});
});
test("should set colorScheme from db as cookie", async () => {
// Arrange
const isCredentialsRequest = true;
const db = await prepareDbForSigninAsync("1");
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
// Act
const result = await signInCallback({
user: { id: "1", emailVerified: new Date("2023-01-13") },
account: {} as Account,
});
// Assert
expect(result).toBe(true);
expect(cookies().set).toHaveBeenCalledWith(
"homarr-color-scheme",
"dark",
expect.objectContaining({
path: "/",
}),
);
});
test("should return false if user not found in db", async () => {
// Arrange
const isCredentialsRequest = true;
const db = await prepareDbForSigninAsync("other-id");
const signInCallback = createSignInCallback(createAdapter(), db, isCredentialsRequest);
// Act
const result = await signInCallback({
user: { id: "1", emailVerified: new Date("2023-01-13") },
account: {} as Account,
});
// Assert
expect(result).toBe(false);
});
});
const prepareDbForSigninAsync = async (userId: string) => {
const db = createDb();
await db.insert(users).values({
id: userId,
colorScheme: "dark",
});
return db;
};

View File

@@ -0,0 +1,190 @@
import type { ResponseCookie } from "next/dist/compiled/@edge-runtime/cookies";
import type { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { cookies } from "next/headers";
import { describe, expect, test, vi } from "vitest";
import { eq } from "@homarr/db";
import type { Database } from "@homarr/db";
import { groupMembers, groups, users } from "@homarr/db/schema/sqlite";
import { createDb } from "@homarr/db/test";
import { createSignInEventHandler } from "../events";
vi.mock("../env.mjs", () => {
return {
env: {
AUTH_OIDC_GROUPS_ATTRIBUTE: "someRandomGroupsKey",
},
};
});
// eslint-disable-next-line @typescript-eslint/consistent-type-imports
type HeadersExport = typeof import("next/headers");
vi.mock("next/headers", async (importOriginal) => {
const mod = await importOriginal<HeadersExport>();
const result = {
set: (name: string, value: string, options: Partial<ResponseCookie>) => options as ResponseCookie,
} as unknown as ReadonlyRequestCookies;
vi.spyOn(result, "set");
const cookies = () => result;
return { ...mod, cookies } satisfies HeadersExport;
});
describe("createSignInEventHandler should create signInEventHandler", () => {
describe("signInEventHandler should synchronize ldap groups", () => {
test("should add missing group membership", async () => {
// Arrange
const db = createDb();
await createUserAsync(db);
await createGroupAsync(db);
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", name: "test", groups: ["test"] } as never,
profile: undefined,
account: null,
});
// Assert
const dbGroupMembers = await db.query.groupMembers.findFirst({
where: eq(groupMembers.userId, "1"),
});
expect(dbGroupMembers?.groupId).toBe("1");
});
test("should remove group membership", async () => {
// Arrange
const db = createDb();
await createUserAsync(db);
await createGroupAsync(db);
await db.insert(groupMembers).values({
userId: "1",
groupId: "1",
});
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", name: "test", groups: [] } as never,
profile: undefined,
account: null,
});
// Assert
const dbGroupMembers = await db.query.groupMembers.findFirst({
where: eq(groupMembers.userId, "1"),
});
expect(dbGroupMembers).toBeUndefined();
});
});
describe("signInEventHandler should synchronize oidc groups", () => {
test("should add missing group membership", async () => {
// Arrange
const db = createDb();
await createUserAsync(db);
await createGroupAsync(db);
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", name: "test" },
profile: { preferred_username: "test", someRandomGroupsKey: ["test"] },
account: null,
});
// Assert
const dbGroupMembers = await db.query.groupMembers.findFirst({
where: eq(groupMembers.userId, "1"),
});
expect(dbGroupMembers?.groupId).toBe("1");
});
test("should remove group membership", async () => {
// Arrange
const db = createDb();
await createUserAsync(db);
await createGroupAsync(db);
await db.insert(groupMembers).values({
userId: "1",
groupId: "1",
});
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", name: "test" },
profile: { preferred_username: "test", someRandomGroupsKey: [] },
account: null,
});
// Assert
const dbGroupMembers = await db.query.groupMembers.findFirst({
where: eq(groupMembers.userId, "1"),
});
expect(dbGroupMembers).toBeUndefined();
});
});
test.each([
["ldap" as const, { name: "test-new" }, undefined],
["oidc" as const, { name: "test" }, { preferred_username: "test-new" }],
["oidc" as const, { name: "test" }, { preferred_username: "test@example.com", name: "test-new" }],
])("signInEventHandler should update username for %s provider", async (_provider, user, profile) => {
// Arrange
const db = createDb();
await createUserAsync(db);
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", ...user },
profile,
account: null,
});
// Assert
const dbUser = await db.query.users.findFirst({
where: eq(users.id, "1"),
columns: {
name: true,
},
});
expect(dbUser?.name).toBe("test-new");
});
test("signInEventHandler should set homarr-color-scheme cookie", async () => {
// Arrange
const db = createDb();
await createUserAsync(db);
const eventHandler = createSignInEventHandler(db);
// Act
await eventHandler?.({
user: { id: "1", name: "test" },
profile: undefined,
account: null,
});
// Assert
expect(cookies().set).toHaveBeenCalledWith(
"homarr-color-scheme",
"dark",
expect.objectContaining({
path: "/",
}),
);
});
});
const createUserAsync = async (db: Database) =>
await db.insert(users).values({
id: "1",
name: "test",
colorScheme: "dark",
});
const createGroupAsync = async (db: Database) =>
await db.insert(groups).values({
id: "1",
name: "test",
});

View File

@@ -56,7 +56,6 @@ const initUserSchema = createUserSchema;
const signInSchema = z.object({
name: z.string().min(1),
password: z.string().min(1),
credentialType: z.enum(["basic", "ldap"]),
});
const registrationSchema = z

View File

@@ -18,6 +18,7 @@
"AUTH_OIDC_CLIENT_SECRET",
"AUTH_OIDC_ISSUER",
"AUTH_OIDC_SCOPE_OVERWRITE",
"AUTH_OIDC_GROUPS_ATTRIBUTE",
"AUTH_LDAP_USERNAME_ATTRIBUTE",
"AUTH_LDAP_USER_MAIL_ATTRIBUTE",
"AUTH_LDAP_USERNAME_FILTER_EXTRA_ARG",