diff --git a/apps/server/src/core/base/controllers/base.controller.ts b/apps/server/src/core/base/controllers/base.controller.ts index b9407bac3..b77c443e7 100644 --- a/apps/server/src/core/base/controllers/base.controller.ts +++ b/apps/server/src/core/base/controllers/base.controller.ts @@ -30,6 +30,8 @@ import { } from '../../casl/interfaces/space-ability.type'; import SpaceAbilityFactory from '../../casl/abilities/space-ability.factory'; import { SpaceIdDto } from '../../space/dto/space-id.dto'; +import { PageAccessService } from '../../page/page-access/page-access.service'; +import { PagePermissionRepo } from '@docmost/db/repos/page/page-permission.repo'; @UseGuards(JwtAuthGuard) @Controller('bases') @@ -40,6 +42,8 @@ export class BaseController { private readonly basePageResolverService: BasePageResolverService, private readonly baseRepo: BaseRepo, private readonly spaceAbility: SpaceAbilityFactory, + private readonly pageAccessService: PageAccessService, + private readonly pagePermissionRepo: PagePermissionRepo, ) {} @HttpCode(HttpStatus.OK) @@ -60,28 +64,25 @@ export class BaseController { @HttpCode(HttpStatus.OK) @Post('info') async getBase(@Body() dto: BaseIdDto, @AuthUser() user: User) { - const base = await this.baseService.getBaseInfo(dto.baseId); - - const ability = await this.spaceAbility.createForUser(user, base.spaceId); - if (ability.cannot(SpaceCaslAction.Read, SpaceCaslSubject.Base)) { - throw new ForbiddenException(); + const base = await this.baseRepo.findById(dto.pageId); + if (!base) { + throw new NotFoundException('Base not found'); } - return base; + await this.pageAccessService.validateCanView(base, user); + + return this.baseService.getBaseInfo(dto.pageId); } @HttpCode(HttpStatus.OK) @Post('update') async update(@Body() dto: UpdateBaseDto, @AuthUser() user: User) { - const base = await this.baseRepo.findById(dto.baseId); + const base = await this.baseRepo.findById(dto.pageId); if (!base) { throw new NotFoundException('Base not found'); } - const ability = await this.spaceAbility.createForUser(user, base.spaceId); - if (ability.cannot(SpaceCaslAction.Edit, SpaceCaslSubject.Base)) { - throw new ForbiddenException(); - } + await this.pageAccessService.validateCanEdit(base, user); return this.baseService.update(dto); } @@ -89,17 +90,14 @@ export class BaseController { @HttpCode(HttpStatus.OK) @Post('delete') async delete(@Body() dto: BaseIdDto, @AuthUser() user: User) { - const base = await this.baseRepo.findById(dto.baseId); + const base = await this.baseRepo.findById(dto.pageId); if (!base) { throw new NotFoundException('Base not found'); } - const ability = await this.spaceAbility.createForUser(user, base.spaceId); - if (ability.cannot(SpaceCaslAction.Manage, SpaceCaslSubject.Base)) { - throw new ForbiddenException(); - } + await this.pageAccessService.validateCanEdit(base, user); - await this.baseService.delete(dto.baseId); + await this.baseService.delete(dto.pageId); } @HttpCode(HttpStatus.OK) @@ -114,7 +112,16 @@ export class BaseController { throw new ForbiddenException(); } - return this.baseService.listBySpaceId(dto.spaceId, pagination); + const result = await this.baseService.listBySpaceId(dto.spaceId, pagination); + const accessible = await this.pagePermissionRepo.filterAccessiblePageIds({ + pageIds: result.items.map((b) => b.id), + userId: user.id, + }); + const accessibleSet = new Set(accessible); + return { + ...result, + items: result.items.filter((b) => accessibleSet.has(b.id)), + }; } @HttpCode(HttpStatus.OK) @@ -125,18 +132,15 @@ export class BaseController { @AuthWorkspace() workspace: Workspace, @Res() res: FastifyReply, ) { - const base = await this.baseRepo.findById(dto.baseId); + const base = await this.baseRepo.findById(dto.pageId); if (!base) { throw new NotFoundException('Base not found'); } - const ability = await this.spaceAbility.createForUser(user, base.spaceId); - if (ability.cannot(SpaceCaslAction.Read, SpaceCaslSubject.Base)) { - throw new ForbiddenException(); - } + await this.pageAccessService.validateCanView(base, user); await this.baseCsvExportService.streamBaseAsCsv( - dto.baseId, + dto.pageId, workspace.id, res, );