feat(EE): MFA implementation for enterprise edition

- Add TOTP-based two-factor authentication
- Add backup codes support
- Add MFA enforcement at workspace level
- Add MFA setup and challenge UI pages
- Support MFA for login and password reset flows
- Add MFA validation for secure pages
This commit is contained in:
Philipinho
2025-07-17 20:52:07 -07:00
parent 44e592763d
commit 8cd4b25dbc
49 changed files with 2051 additions and 54 deletions
+8
View File
@@ -1,6 +1,7 @@
import * as path from 'path';
import * as bcrypt from 'bcrypt';
import { sanitize } from 'sanitize-filename-ts';
import { FastifyRequest } from 'fastify';
export const envPath = path.resolve(process.cwd(), '..', '..', '.env');
@@ -74,3 +75,10 @@ export function sanitizeFileName(fileName: string): string {
const sanitizedFilename = sanitize(fileName).replace(/ /g, '_');
return sanitizedFilename.slice(0, 255);
}
export function extractBearerTokenFromHeader(
request: FastifyRequest,
): string | undefined {
const [type, token] = request.headers.authorization?.split(' ') ?? [];
return type === 'Bearer' ? token : undefined;
}
+58 -3
View File
@@ -6,6 +6,7 @@ import {
Post,
Res,
UseGuards,
Logger,
} from '@nestjs/common';
import { LoginDto } from './dto/login.dto';
import { AuthService } from './services/auth.service';
@@ -22,12 +23,16 @@ import { PasswordResetDto } from './dto/password-reset.dto';
import { VerifyUserTokenDto } from './dto/verify-user-token.dto';
import { FastifyReply } from 'fastify';
import { validateSsoEnforcement } from './auth.util';
import { ModuleRef } from '@nestjs/core';
@Controller('auth')
export class AuthController {
private readonly logger = new Logger(AuthController.name);
constructor(
private authService: AuthService,
private environmentService: EnvironmentService,
private moduleRef: ModuleRef,
) {}
@HttpCode(HttpStatus.OK)
@@ -39,6 +44,45 @@ export class AuthController {
) {
validateSsoEnforcement(workspace);
let MfaModule: any;
let isMfaModuleReady = false;
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
MfaModule = require('./../../ee/mfa/services/mfa.service');
isMfaModuleReady = true;
} catch (err) {
this.logger.debug(
'MFA module requested but EE module not bundled in this build',
);
isMfaModuleReady = false;
}
if (isMfaModuleReady) {
const mfaService = this.moduleRef.get(MfaModule.MfaService, {
strict: false,
});
const mfaResult = await mfaService.checkMfaRequirements(
loginInput,
workspace,
res,
);
if (mfaResult) {
// If user has MFA enabled OR workspace enforces MFA, require MFA verification
if (mfaResult.userHasMfa || mfaResult.requiresMfaSetup) {
return {
userHasMfa: mfaResult.userHasMfa,
requiresMfaSetup: mfaResult.requiresMfaSetup,
isMfaEnforced: mfaResult.isMfaEnforced,
};
} else if (mfaResult.authToken) {
// User doesn't have MFA and workspace doesn't require it
this.setAuthCookie(res, mfaResult.authToken);
return;
}
}
}
const authToken = await this.authService.login(loginInput, workspace.id);
this.setAuthCookie(res, authToken);
}
@@ -85,11 +129,22 @@ export class AuthController {
@Body() passwordResetDto: PasswordResetDto,
@AuthWorkspace() workspace: Workspace,
) {
const authToken = await this.authService.passwordReset(
const result = await this.authService.passwordReset(
passwordResetDto,
workspace.id,
workspace,
);
this.setAuthCookie(res, authToken);
if (result.requiresLogin) {
return {
requiresLogin: true,
};
}
// Set auth cookie if no MFA is required
this.setAuthCookie(res, result.authToken);
return {
requiresLogin: false,
};
}
@HttpCode(HttpStatus.OK)
@@ -3,6 +3,7 @@ export enum JwtType {
COLLAB = 'collab',
EXCHANGE = 'exchange',
ATTACHMENT = 'attachment',
MFA_TOKEN = 'mfa_token',
}
export type JwtPayload = {
sub: string;
@@ -30,3 +31,8 @@ export type JwtAttachmentPayload = {
type: 'attachment';
};
export interface JwtMfaTokenPayload {
sub: string;
workspaceId: string;
type: 'mfa_token';
}
@@ -47,7 +47,7 @@ export class AuthService {
includePassword: true,
});
const errorMessage = 'email or password does not match';
const errorMessage = 'Email or password does not match';
if (!user || user?.deletedAt) {
throw new UnauthorizedException(errorMessage);
}
@@ -156,10 +156,13 @@ export class AuthService {
});
}
async passwordReset(passwordResetDto: PasswordResetDto, workspaceId: string) {
async passwordReset(
passwordResetDto: PasswordResetDto,
workspace: Workspace,
) {
const userToken = await this.userTokenRepo.findById(
passwordResetDto.token,
workspaceId,
workspace.id,
);
if (
@@ -170,7 +173,9 @@ export class AuthService {
throw new BadRequestException('Invalid or expired token');
}
const user = await this.userRepo.findById(userToken.userId, workspaceId);
const user = await this.userRepo.findById(userToken.userId, workspace.id, {
includeUserMfa: true,
});
if (!user || user.deletedAt) {
throw new NotFoundException('User not found');
}
@@ -183,7 +188,7 @@ export class AuthService {
password: newPasswordHash,
},
user.id,
workspaceId,
workspace.id,
trx,
);
@@ -201,7 +206,18 @@ export class AuthService {
template: emailTemplate,
});
return this.tokenService.generateAccessToken(user);
// Check if user has MFA enabled or workspace enforces MFA
const userHasMfa = user?.['mfa']?.enabled || false;
const workspaceEnforcesMfa = workspace.enforceMfa || false;
if (userHasMfa || workspaceEnforcesMfa) {
return {
requiresLogin: true,
};
}
const authToken = await this.tokenService.generateAccessToken(user);
return { authToken };
}
async verifyUserToken(
@@ -9,6 +9,7 @@ import {
JwtAttachmentPayload,
JwtCollabPayload,
JwtExchangePayload,
JwtMfaTokenPayload,
JwtPayload,
JwtType,
} from '../dto/jwt-payload';
@@ -76,6 +77,22 @@ export class TokenService {
return this.jwtService.sign(payload, { expiresIn: '1h' });
}
async generateMfaToken(
user: User,
workspaceId: string,
): Promise<string> {
if (user.deactivatedAt || user.deletedAt) {
throw new ForbiddenException();
}
const payload: JwtMfaTokenPayload = {
sub: user.id,
workspaceId,
type: JwtType.MFA_TOKEN,
};
return this.jwtService.sign(payload, { expiresIn: '5m' });
}
async verifyJwt(token: string, tokenType: string) {
const payload = await this.jwtService.verifyAsync(token, {
secret: this.environmentService.getAppSecret(),
@@ -6,6 +6,7 @@ import { JwtPayload, JwtType } from '../dto/jwt-payload';
import { WorkspaceRepo } from '@docmost/db/repos/workspace/workspace.repo';
import { UserRepo } from '@docmost/db/repos/user/user.repo';
import { FastifyRequest } from 'fastify';
import { extractBearerTokenFromHeader } from '../../../common/helpers';
@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
@@ -18,7 +19,7 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
) {
super({
jwtFromRequest: (req: FastifyRequest) => {
return req.cookies?.authToken || this.extractTokenFromHeader(req);
return req.cookies?.authToken || extractBearerTokenFromHeader(req);
},
ignoreExpiration: false,
secretOrKey: environmentService.getAppSecret(),
@@ -48,9 +49,4 @@ export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
return { user, workspace };
}
private extractTokenFromHeader(request: FastifyRequest): string | undefined {
const [type, token] = request.headers.authorization?.split(' ') ?? [];
return type === 'Bearer' ? token : undefined;
}
}
@@ -29,7 +29,8 @@ import WorkspaceAbilityFactory from '../../casl/abilities/workspace-ability.fact
import {
WorkspaceCaslAction,
WorkspaceCaslSubject,
} from '../../casl/interfaces/workspace-ability.type';import { FastifyReply } from 'fastify';
} from '../../casl/interfaces/workspace-ability.type';
import { FastifyReply } from 'fastify';
import { EnvironmentService } from '../../../integrations/environment/environment.service';
import { CheckHostnameDto } from '../dto/check-hostname.dto';
import { RemoveWorkspaceUserDto } from '../dto/remove-workspace-user.dto';
@@ -257,17 +258,27 @@ export class WorkspaceController {
@AuthWorkspace() workspace: Workspace,
@Res({ passthrough: true }) res: FastifyReply,
) {
const authToken = await this.workspaceInvitationService.acceptInvitation(
const result = await this.workspaceInvitationService.acceptInvitation(
acceptInviteDto,
workspace,
);
res.setCookie('authToken', authToken, {
if (result.requiresLogin) {
return {
requiresLogin: true,
};
}
res.setCookie('authToken', result.authToken, {
httpOnly: true,
path: '/',
expires: this.environmentService.getCookieExpiresIn(),
secure: this.environmentService.isHttps(),
});
return {
requiresLogin: false,
};
}
@Public()
@@ -14,4 +14,8 @@ export class UpdateWorkspaceDto extends PartialType(CreateWorkspaceDto) {
@IsOptional()
@IsBoolean()
enforceSso: boolean;
@IsOptional()
@IsBoolean()
enforceMfa: boolean;
}
@@ -177,7 +177,14 @@ export class WorkspaceInvitationService {
}
}
async acceptInvitation(dto: AcceptInviteDto, workspace: Workspace) {
async acceptInvitation(
dto: AcceptInviteDto,
workspace: Workspace,
): Promise<{
authToken?: string;
requiresLogin?: boolean;
message?: string;
}> {
const invitation = await this.db
.selectFrom('workspaceInvitations')
.selectAll()
@@ -289,7 +296,14 @@ export class WorkspaceInvitationService {
});
}
return this.tokenService.generateAccessToken(newUser);
if (workspace.enforceMfa) {
return {
requiresLogin: true,
};
}
const authToken = await this.tokenService.generateAccessToken(newUser);
return { authToken };
}
async resendInvitation(
@@ -0,0 +1,41 @@
import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await db.schema
.createTable('user_mfa')
.addColumn('id', 'uuid', (col) =>
col.primaryKey().defaultTo(sql`gen_uuid_v7()`),
)
.addColumn('user_id', 'uuid', (col) =>
col.references('users.id').onDelete('cascade').notNull(),
)
.addColumn('method', 'varchar(50)', (col) =>
col.notNull().defaultTo('totp'),
)
.addColumn('secret', 'varchar(255)', (col) => col)
.addColumn('enabled', 'boolean', (col) => col.defaultTo(false))
.addColumn('backup_codes', sql`text[]`, (col) => col)
.addColumn('workspace_id', 'uuid', (col) =>
col.references('workspaces.id').onDelete('cascade').notNull(),
)
.addColumn('created_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addColumn('updated_at', 'timestamptz', (col) =>
col.notNull().defaultTo(sql`now()`),
)
.addUniqueConstraint('user_mfa_user_id_unique', ['user_id'])
.execute();
// Add MFA policy columns to workspaces
await db.schema
.alterTable('workspaces')
.addColumn('enforce_mfa', 'boolean', (col) => col.defaultTo(false))
.execute();
}
export async function down(db: Kysely<any>): Promise<void> {
await db.schema.alterTable('workspaces').dropColumn('enforce_mfa').execute();
await db.schema.dropTable('user_mfa').execute();
}
@@ -1,7 +1,7 @@
import { Injectable } from '@nestjs/common';
import { InjectKysely } from 'nestjs-kysely';
import { KyselyDB, KyselyTransaction } from '@docmost/db/types/kysely.types';
import { Users } from '@docmost/db/types/db';
import { DB, Users } from '@docmost/db/types/db';
import { hashPassword } from '../../../common/helpers';
import { dbOrTx } from '@docmost/db/utils';
import {
@@ -11,7 +11,8 @@ import {
} from '@docmost/db/types/entity.types';
import { PaginationOptions } from '../../pagination/pagination-options';
import { executeWithPagination } from '@docmost/db/pagination/pagination';
import { sql } from 'kysely';
import { ExpressionBuilder, sql } from 'kysely';
import { jsonObjectFrom } from 'kysely/helpers/postgres';
@Injectable()
export class UserRepo {
@@ -40,6 +41,7 @@ export class UserRepo {
workspaceId: string,
opts?: {
includePassword?: boolean;
includeUserMfa?: boolean;
trx?: KyselyTransaction;
},
): Promise<User> {
@@ -48,6 +50,7 @@ export class UserRepo {
.selectFrom('users')
.select(this.baseFields)
.$if(opts?.includePassword, (qb) => qb.select('password'))
.$if(opts?.includeUserMfa, (qb) => qb.select(this.withUserMfa))
.where('id', '=', userId)
.where('workspaceId', '=', workspaceId)
.executeTakeFirst();
@@ -58,6 +61,7 @@ export class UserRepo {
workspaceId: string,
opts?: {
includePassword?: boolean;
includeUserMfa?: boolean;
trx?: KyselyTransaction;
},
): Promise<User> {
@@ -66,6 +70,7 @@ export class UserRepo {
.selectFrom('users')
.select(this.baseFields)
.$if(opts?.includePassword, (qb) => qb.select('password'))
.$if(opts?.includeUserMfa, (qb) => qb.select(this.withUserMfa))
.where(sql`LOWER(email)`, '=', sql`LOWER(${email})`)
.where('workspaceId', '=', workspaceId)
.executeTakeFirst();
@@ -177,4 +182,18 @@ export class UserRepo {
.returning(this.baseFields)
.executeTakeFirst();
}
withUserMfa(eb: ExpressionBuilder<DB, 'users'>) {
return jsonObjectFrom(
eb
.selectFrom('userMfa')
.select([
'userMfa.id',
'userMfa.method',
'userMfa.enabled',
'userMfa.createdAt',
])
.whereRef('userMfa.userId', '=', 'users.id'),
).as('mfa');
}
}
@@ -32,6 +32,7 @@ export class WorkspaceRepo {
'trialEndAt',
'enforceSso',
'plan',
'enforceMfa',
];
constructor(@InjectKysely() private readonly db: KyselyDB) {}
+34
View File
@@ -62,6 +62,7 @@ export interface AuthProviders {
deletedAt: Timestamp | null;
id: Generated<string>;
isEnabled: Generated<boolean>;
isGroupSyncEnabled: Generated<boolean>;
name: string;
oidcClientId: string | null;
oidcClientSecret: string | null;
@@ -122,6 +123,7 @@ export interface Comments {
pageId: string;
parentCommentId: string | null;
resolvedAt: Timestamp | null;
resolvedById: string | null;
selection: string | null;
type: string | null;
workspaceId: string;
@@ -165,6 +167,23 @@ export interface GroupUsers {
userId: string;
}
export interface Notifications {
actorId: string | null;
context: Generated<Json | null>;
createdAt: Generated<Timestamp>;
deduplicationKey: string | null;
emailSentAt: Timestamp | null;
entityId: string;
entityType: string;
id: Generated<string>;
readAt: Timestamp | null;
recipientId: string;
spaceId: string | null;
type: string;
updatedAt: Generated<Timestamp>;
workspaceId: string;
}
export interface PageHistory {
content: Json | null;
coverPhoto: string | null;
@@ -247,6 +266,18 @@ export interface Spaces {
workspaceId: string;
}
export interface UserMfa {
backupCodes: string[] | null;
createdAt: Generated<Timestamp>;
enabled: Generated<boolean | null>;
id: Generated<string>;
method: Generated<string>;
secret: string | null;
updatedAt: Generated<Timestamp>;
userId: string;
workspaceId: string;
}
export interface Users {
avatarUrl: string | null;
createdAt: Generated<Timestamp>;
@@ -305,6 +336,7 @@ export interface Workspaces {
id: Generated<string>;
licenseKey: string | null;
logo: string | null;
enforceMfa: Generated<boolean | null>;
name: string | null;
plan: string | null;
settings: Json | null;
@@ -324,11 +356,13 @@ export interface DB {
fileTasks: FileTasks;
groups: Groups;
groupUsers: GroupUsers;
notifications: Notifications;
pageHistory: PageHistory;
pages: Pages;
shares: Shares;
spaceMembers: SpaceMembers;
spaces: Spaces;
userMfa: UserMfa;
users: Users;
userTokens: UserTokens;
workspaceInvitations: WorkspaceInvitations;
@@ -18,6 +18,7 @@ import {
AuthAccounts,
Shares,
FileTasks,
UserMfa as _UserMFA,
} from './db';
// Workspace
@@ -113,3 +114,8 @@ export type UpdatableShare = Updateable<Omit<Shares, 'id'>>;
export type FileTask = Selectable<FileTasks>;
export type InsertableFileTask = Insertable<FileTasks>;
export type UpdatableFileTask = Updateable<Omit<FileTasks, 'id'>>;
// UserMFA
export type UserMFA = Selectable<_UserMFA>;
export type InsertableUserMFA = Insertable<_UserMFA>;
export type UpdatableUserMFA = Updateable<Omit<_UserMFA, 'id'>>;