Skip to content
12 changes: 12 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"@nestjs/platform-express": "^11",
"@nestjs/swagger": "^11.1.1",
"@nestjs/terminus": "^11",
"@nestjs/throttler": "^6.5.0",
"@types/json-merge-patch": "^1.0.0",
"@types/jsonschema": "^1.1.1",
"@user-office-software/duo-logger": "^2.1.1",
Expand Down
12 changes: 12 additions & 0 deletions src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import { HistoryModule } from "./history/history.module";
import { MaskSensitiveDataInterceptorModule } from "./common/interceptors/mask-sensitive-data.interceptor";
import { RuntimeConfigModule } from "./config/runtime-config/runtime-config.module";
import { MetadataKeysModule } from "./metadata-keys/metadatakeys.module";
import { OidcClientModule } from "./common/openid-client/openid-client.module";
import { ThrottlerModule } from "@nestjs/throttler";

@Module({
imports: [
Expand All @@ -53,6 +55,7 @@ import { MetadataKeysModule } from "./metadata-keys/metadatakeys.module";
cache: true,
}),
AuthModule,
OidcClientModule,
CaslModule,
AttachmentsModule,
CommonModule,
Expand Down Expand Up @@ -164,6 +167,15 @@ import { MetadataKeysModule } from "./metadata-keys/metadatakeys.module";
MaskSensitiveDataInterceptorModule,
(env: NodeJS.ProcessEnv) => env.MASK_PERSONAL_INFO === "yes",
),
ThrottlerModule.forRoot({
throttlers: [
{
name: "login",
ttl: 1000,
limit: 1,
},
],
}),
],
controllers: [],
providers: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import { AccessGroupService as AccessGroupService } from "./access-group.service
import { Injectable, Logger } from "@nestjs/common";
///import fetch from "node-fetch";

import { UserPayload } from "../interfaces/userPayload.interface";
import { HttpService } from "@nestjs/axios";
import { firstValueFrom } from "rxjs";
import { UserPayload } from "src/auth/interfaces/userPayload.interface";
/**
* This service is used to fetch access groups from a GraphQL API.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Injectable } from "@nestjs/common";
import { UserPayload } from "../interfaces/userPayload.interface";
import { AccessGroupService } from "./access-group.service";
import { UserPayload } from "src/auth/interfaces/userPayload.interface";

/**
* This service is used to get the access groups from multiple providers.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//import { AccessGroupService as AccessGroupService } from "./access-group.service";
import { Injectable, Logger } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { UserPayload } from "../interfaces/userPayload.interface";
import { AccessGroupService } from "./access-group.service";
import { UserPayload } from "src/auth/interfaces/userPayload.interface";

/**
* This service is used to get the access groups from the payload of the IDP.
Expand Down
2 changes: 1 addition & 1 deletion src/auth/access-group-provider/access-group.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { UserPayload } from "../interfaces/userPayload.interface";
import { UserPayload } from "src/auth/interfaces/userPayload.interface";

export abstract class AccessGroupService {
abstract getAccessGroups(
Expand Down
12 changes: 12 additions & 0 deletions src/auth/auth.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { AuthService } from "./auth.service";
import { Response } from "express";
import { Session } from "express-session";
import { ConfigService } from "@nestjs/config";
import { ThrottlerModule } from "@nestjs/throttler";

class AuthServiceMock {
login() {
Expand All @@ -29,6 +30,17 @@ describe("AuthController", () => {

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [
ThrottlerModule.forRoot({
throttlers: [
{
name: "login",
ttl: 1000,
limit: 1,
},
],
}),
],
controllers: [AuthController],
providers: [
ConfigService,
Expand Down
19 changes: 19 additions & 0 deletions src/auth/auth.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
Res,
Req,
HttpCode,
Body,
} from "@nestjs/common";
import { LocalAuthGuard } from "./guards/local-auth.guard";
import { AuthService } from "./auth.service";
Expand All @@ -17,6 +18,7 @@ import {
ApiResponse,
ApiTags,
ApiQuery,
ApiOkResponse,
} from "@nestjs/swagger";
import { CredentialsDto } from "./dto/credentials.dto";
import { LdapAuthGuard } from "./guards/ldap.guard";
Expand All @@ -25,6 +27,8 @@ import { User } from "src/users/schemas/user.schema";
import { OidcAuthGuard } from "./guards/oidc.guard";
import { Request, Response } from "express";
import { ReturnedAuthLoginDto } from "./dto/returnedLogin.dto";
import { IdTokenDto } from "./dto/idToken.dto";
import { ThrottlerGuard } from "@nestjs/throttler";

@ApiBearerAuth()
@ApiTags("auth")
Expand Down Expand Up @@ -91,6 +95,21 @@ export class AuthController {
// this function is invoked when the oidc is set as an auth method. It's behaviour comes from the oidc strategy
}

@AllowAny()
@UseGuards(ThrottlerGuard)
@ApiBody({ type: IdTokenDto })
@ApiOkResponse({
description:
"Successfully authenticated via OIDC. Returns SciCat access token.",
type: ReturnedAuthLoginDto,
})
@Post("oidc/token")
async oidcTokenLogin(
@Body() idTokenDto: IdTokenDto,
): Promise<ReturnedAuthLoginDto> {
return this.authService.oidcTokenLogin(idTokenDto.idToken);
}

@AllowAny()
@UseGuards(OidcAuthGuard)
@Get("oidc/callback")
Expand Down
26 changes: 10 additions & 16 deletions src/auth/auth.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,31 @@ import { JwtModule } from "@nestjs/jwt";
import { JwtStrategy } from "./strategies/jwt.strategy";
import { LdapStrategy } from "./strategies/ldap.strategy";
import { ConfigService } from "@nestjs/config";
import { UsersService } from "src/users/users.service";
import { OidcConfig } from "src/config/configuration";
import { BuildOpenIdClient, OidcStrategy } from "./strategies/oidc.strategy";
import { OidcStrategy } from "./strategies/oidc.strategy";
import { accessGroupServiceFactory } from "./access-group-provider/access-group-service-factory";
import { AccessGroupService } from "./access-group-provider/access-group.service";
import { CaslModule } from "src/casl/casl.module";
import { SessionMiddleware } from "./middlewares/session.middleware";
import { OidcClientService } from "../common/openid-client/openid-client.service";
import { OidcAuthService } from "src/common/openid-client/openid-auth.service";

const OidcStrategyFactory = {
provide: "OidcStrategy",
useFactory: async (
authService: AuthService,
oidcClientService: OidcClientService,
oidcAuthService: OidcAuthService,
configService: ConfigService,
userService: UsersService,
accessGroupService: AccessGroupService,
) => {
if (!configService.get<OidcConfig>("oidc")?.issuer) {
return null;
}
const clientBuilder = new BuildOpenIdClient(configService);
const client = await clientBuilder.build();
const strategy = new OidcStrategy(
authService,
client,
configService,
userService,
accessGroupService,
);

const client = await oidcClientService.getClient();

const strategy = new OidcStrategy(client, configService, oidcAuthService);
return strategy;
},
inject: [AuthService, ConfigService, UsersService, AccessGroupService],
inject: [OidcClientService, OidcAuthService, ConfigService],
};

@Module({
Expand Down
59 changes: 57 additions & 2 deletions src/auth/auth.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@ import { JwtService } from "@nestjs/jwt";
import { Test, TestingModule } from "@nestjs/testing";
import { UsersService } from "src/users/users.service";
import { AuthService } from "./auth.service";
import { OidcClientService } from "src/common/openid-client/openid-client.service";
import { OidcAuthService } from "src/common/openid-client/openid-auth.service";

class JwtServiceMock {}
class JwtServiceMock {
sign = jest.fn();
}

class UsersServiceMock {}
class UsersServiceMock {
findByIdUserSettings = jest.fn().mockResolvedValue({});
createUserSettings = jest.fn().mockResolvedValue({});
}

class OidcClientServiceMock {
getClient = jest.fn();
}

class OidcAuthServiceMock {
validate = jest.fn();
}
describe("AuthService", () => {
let authService: AuthService;
let oidcClientService: OidcClientServiceMock;
let oidcAuthService: OidcAuthServiceMock;
let jwtService: JwtServiceMock;
let configService: ConfigService;

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
Expand All @@ -18,13 +36,50 @@ describe("AuthService", () => {
ConfigService,
{ provide: JwtService, useClass: JwtServiceMock },
{ provide: UsersService, useClass: UsersServiceMock },
{ provide: OidcClientService, useClass: OidcClientServiceMock },
{ provide: OidcAuthService, useClass: OidcAuthServiceMock },
],
}).compile();

authService = module.get<AuthService>(AuthService);
oidcClientService = module.get(OidcClientService);
oidcAuthService = module.get(OidcAuthService);
jwtService = module.get(JwtService);
configService = module.get(ConfigService);
});

it("should be defined", () => {
expect(authService).toBeDefined();
});

describe("Oidc Token Login", () => {
const mockIdToken = "valid-id-token";
const mockUser = { _id: "user_123", email: "test@example.com" };
const mockAccessToken = "signed-jwt-token";
const mockExpiresIn = 3600;

it("should successfully validate token and return auth login dto", async () => {
const postLoginSpy = jest.spyOn(authService, "postLoginTasks");

const mockClient = {
callback: jest.fn().mockResolvedValue({ id_token: mockIdToken }),
};
oidcClientService.getClient.mockResolvedValue(mockClient);

jest.spyOn(configService, "get").mockImplementation((key: string) => {
if (key === "oidc.callbackURL") return "http://localhost/callback";
if (key === "jwt.expiresIn") return mockExpiresIn;
return null;
});

oidcAuthService.validate.mockResolvedValue(mockUser);
jwtService.sign.mockReturnValue(mockAccessToken);

const result = await authService.oidcTokenLogin(mockIdToken);

expect(result.access_token).toBe(mockAccessToken);
expect(result.userId).toBe(mockUser._id);
expect(postLoginSpy).toHaveBeenCalledWith(mockUser);
});
});
});
42 changes: 40 additions & 2 deletions src/auth/auth.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { HttpException, HttpStatus, Injectable, Logger } from "@nestjs/common";
import {
HttpException,
HttpStatus,
Injectable,
Logger,
UnauthorizedException,
} from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { JwtService } from "@nestjs/jwt";
import { compare } from "bcrypt";
Expand All @@ -7,16 +13,20 @@ import { UsersService } from "../users/users.service";
import { Request } from "express";
import { OidcConfig } from "src/config/configuration";
import { flattenObject, parseBoolean } from "src/common/utils";
import { Issuer } from "openid-client";
import { Issuer, TokenSet } from "openid-client";
import { ReturnedAuthLoginDto } from "./dto/returnedLogin.dto";
import { ReturnedUserDto } from "src/users/dto/returned-user.dto";
import { CreateUserSettingsDto } from "src/users/dto/create-user-settings.dto";
import { OidcClientService } from "../common/openid-client/openid-client.service";
import { OidcAuthService } from "src/common/openid-client/openid-auth.service";

@Injectable()
export class AuthService {
constructor(
private configService: ConfigService,
private usersService: UsersService,
private oidcClientService: OidcClientService,
private oidcAuthService: OidcAuthService,
private jwtService: JwtService,
) {}

Expand Down Expand Up @@ -56,6 +66,34 @@ export class AuthService {
};
}

async oidcTokenLogin(idToken: string): Promise<ReturnedAuthLoginDto> {
let tokenSet: TokenSet;
const client = await this.oidcClientService.getClient();
const callbackUrl = this.configService.get<string>("oidc.callbackURL");

try {
tokenSet = await client.callback(callbackUrl, { id_token: idToken }, {});
} catch (error) {
throw new UnauthorizedException(
`Invalid idToken: ${(error as Error).message}`,
);
}
const user = await this.oidcAuthService.validate(tokenSet);
const expiresIn = this.configService.get<number>("jwt.expiresIn");
const accessToken = this.jwtService.sign(user, { expiresIn });
await this.postLoginTasks(user);

return {
access_token: accessToken,
id: accessToken,
expires_in: expiresIn,
ttl: expiresIn,
created: new Date().toISOString(),
userId: user._id,
user: user as ReturnedUserDto,
};
}

async logout(req: Request) {
const logoutURL = this.configService.get<string>("logoutURL") || "";
const expressSessionSecret = this.configService.get<string>(
Expand Down
16 changes: 16 additions & 0 deletions src/auth/dto/idToken.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { ApiProperty } from "@nestjs/swagger";
import { IsNotEmpty, IsString } from "class-validator";

export class IdTokenDto {
@ApiProperty({
required: true,
description:
"OpenID Connect ID Token issued by the external identity provider for SciCat user authentication. " +
"The token must contain the user's identity claims (e.g., subject, email, name) and is verified " +
"by the backend to authenticate the user and generate a SciCat JWT access token.",
example: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...",
})
@IsString()
@IsNotEmpty()
readonly idToken: string;
}
Loading
Loading