Skip to content

Commit 2f5c28d

Browse files
committed
feat(aws): allow bedrock Application Inference Profile
Implements #7809 and ports #7822 into the @langchain/aws library
1 parent 26a1ce6 commit 2f5c28d

File tree

4 files changed

+221
-4
lines changed

4 files changed

+221
-4
lines changed

.changeset/wet-taxis-heal.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@langchain/aws": minor
3+
---
4+
5+
feat(aws): allow bedrock Application Inference Profile

libs/langchain-aws/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,4 @@
9191
"index.d.ts",
9292
"index.d.cts"
9393
]
94-
}
94+
}

libs/langchain-aws/src/chat_models.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ export interface ChatBedrockConverseInput
9696
*/
9797
model?: string;
9898

99+
/**
100+
* Application Inference Profile ARN to use for the model.
101+
* For example, "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx", will override this.model in final /invoke URL call.
102+
 * You must still provide `model` as the normal model ID to benefit from all the metadata.
103+
 * See the link below for more details on creating and using application inference profiles.
104+
* @link https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
105+
*/
106+
applicationInferenceProfile?: string;
107+
99108
/**
100109
* The AWS region e.g. `us-west-2`.
101110
* Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -662,6 +671,8 @@ export class ChatBedrockConverse
662671

663672
model = "anthropic.claude-3-haiku-20240307-v1:0";
664673

674+
applicationInferenceProfile?: string | undefined = undefined;
675+
665676
streaming = false;
666677

667678
region: string;
@@ -743,6 +754,7 @@ export class ChatBedrockConverse
743754

744755
this.region = region;
745756
this.model = rest?.model ?? this.model;
757+
this.applicationInferenceProfile = rest?.applicationInferenceProfile;
746758
this.streaming = rest?.streaming ?? this.streaming;
747759
this.temperature = rest?.temperature;
748760
this.maxTokens = rest?.maxTokens;
@@ -864,7 +876,7 @@ export class ChatBedrockConverse
864876
const params = this.invocationParams(options);
865877

866878
const command = new ConverseCommand({
867-
modelId: this.model,
879+
modelId: this.applicationInferenceProfile ?? this.model,
868880
messages: converseMessages,
869881
system: converseSystem,
870882
requestMetadata: options.requestMetadata,
@@ -905,7 +917,7 @@ export class ChatBedrockConverse
905917
streamUsage = options.streamUsage;
906918
}
907919
const command = new ConverseStreamCommand({
908-
modelId: this.model,
920+
modelId: this.applicationInferenceProfile ?? this.model,
909921
messages: converseMessages,
910922
system: converseSystem,
911923
requestMetadata: options.requestMetadata,

libs/langchain-aws/src/tests/chat_models.test.ts

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import {
1313
type SystemContentBlock as BedrockSystemContentBlock,
1414
} from "@aws-sdk/client-bedrock-runtime";
1515
import { z } from "zod";
16-
import { describe, expect, test } from "@jest/globals";
16+
import { describe, expect, test, jest } from "@jest/globals";
1717
import {
1818
convertToConverseMessages,
1919
handleConverseStreamContentBlockDelta,
@@ -450,6 +450,206 @@ test("Streaming supports empty string chunks", async () => {
450450
expect(finalChunk.content).toBe("Hello world!");
451451
});
452452

453+
describe("applicationInferenceProfile parameter", () => {
454+
const baseConstructorArgs = {
455+
region: "us-east-1",
456+
credentials: {
457+
secretAccessKey: "test-secret-key",
458+
accessKeyId: "test-access-key",
459+
},
460+
};
461+
462+
it("should initialize applicationInferenceProfile from constructor", () => {
463+
const testArn =
464+
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
465+
const model = new ChatBedrockConverse({
466+
...baseConstructorArgs,
467+
model: "anthropic.claude-3-haiku-20240307-v1:0",
468+
applicationInferenceProfile: testArn,
469+
});
470+
expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
471+
expect(model.applicationInferenceProfile).toBe(testArn);
472+
});
473+
474+
it("should be undefined when not provided in constructor", () => {
475+
const model = new ChatBedrockConverse({
476+
...baseConstructorArgs,
477+
model: "anthropic.claude-3-haiku-20240307-v1:0",
478+
});
479+
480+
expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
481+
expect(model.applicationInferenceProfile).toBeUndefined();
482+
});
483+
484+
it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
485+
const testArn =
486+
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
487+
const mockSend = jest.fn().mockResolvedValue({
488+
output: {
489+
message: {
490+
role: "assistant",
491+
content: [{ text: "Test response" }],
492+
},
493+
},
494+
stopReason: "end_turn",
495+
usage: {
496+
inputTokens: 10,
497+
outputTokens: 5,
498+
totalTokens: 15,
499+
},
500+
});
501+
502+
const mockClient = {
503+
send: mockSend,
504+
} as any;
505+
506+
const model = new ChatBedrockConverse({
507+
...baseConstructorArgs,
508+
model: "anthropic.claude-3-haiku-20240307-v1:0",
509+
applicationInferenceProfile: testArn,
510+
client: mockClient,
511+
});
512+
513+
await model.invoke([new HumanMessage("Hello")]);
514+
515+
// Verify that send was called
516+
expect(mockSend).toHaveBeenCalledTimes(1);
517+
518+
// Verify that the command was created with applicationInferenceProfile as modelId
519+
const commandArg = mockSend.mock.calls[0][0];
520+
expect(commandArg.input.modelId).toBe(testArn);
521+
expect(commandArg.input.modelId).not.toBe(
522+
"anthropic.claude-3-haiku-20240307-v1:0"
523+
);
524+
});
525+
526+
it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
527+
const mockSend = jest.fn().mockResolvedValue({
528+
output: {
529+
message: {
530+
role: "assistant",
531+
content: [{ text: "Test response" }],
532+
},
533+
},
534+
stopReason: "end_turn",
535+
usage: {
536+
inputTokens: 10,
537+
outputTokens: 5,
538+
totalTokens: 15,
539+
},
540+
});
541+
542+
const mockClient = {
543+
send: mockSend,
544+
} as any;
545+
546+
const model = new ChatBedrockConverse({
547+
...baseConstructorArgs,
548+
model: "anthropic.claude-3-haiku-20240307-v1:0",
549+
client: mockClient,
550+
});
551+
552+
await model.invoke([new HumanMessage("Hello")]);
553+
554+
// Verify that send was called
555+
expect(mockSend).toHaveBeenCalledTimes(1);
556+
557+
// Verify that the command was created with model as modelId
558+
const commandArg = mockSend.mock.calls[0][0];
559+
expect(commandArg.input.modelId).toBe(
560+
"anthropic.claude-3-haiku-20240307-v1:0"
561+
);
562+
});
563+
564+
it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
565+
const testArn =
566+
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
567+
const mockSend = jest.fn().mockResolvedValue({
568+
stream: (async function* () {
569+
yield {
570+
contentBlockDelta: {
571+
contentBlockIndex: 0,
572+
delta: { text: "Test" },
573+
},
574+
};
575+
yield {
576+
metadata: {
577+
usage: {
578+
inputTokens: 10,
579+
outputTokens: 5,
580+
totalTokens: 15,
581+
},
582+
},
583+
};
584+
})(),
585+
});
586+
587+
const mockClient = {
588+
send: mockSend,
589+
} as any;
590+
591+
const model = new ChatBedrockConverse({
592+
...baseConstructorArgs,
593+
model: "anthropic.claude-3-haiku-20240307-v1:0",
594+
applicationInferenceProfile: testArn,
595+
streaming: true,
596+
client: mockClient,
597+
});
598+
599+
await model.invoke([new HumanMessage("Hello")]);
600+
601+
expect(mockSend).toHaveBeenCalledTimes(1);
602+
603+
const commandArg = mockSend.mock.calls[0][0];
604+
expect(commandArg.input.modelId).toBe(testArn);
605+
expect(commandArg.input.modelId).not.toBe(
606+
"anthropic.claude-3-haiku-20240307-v1:0"
607+
);
608+
});
609+
610+
it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
611+
const mockSend = jest.fn().mockResolvedValue({
612+
stream: (async function* () {
613+
yield {
614+
contentBlockDelta: {
615+
contentBlockIndex: 0,
616+
delta: { text: "Test" },
617+
},
618+
};
619+
yield {
620+
metadata: {
621+
usage: {
622+
inputTokens: 10,
623+
outputTokens: 5,
624+
totalTokens: 15,
625+
},
626+
},
627+
};
628+
})(),
629+
});
630+
631+
const mockClient = {
632+
send: mockSend,
633+
} as any;
634+
635+
const model = new ChatBedrockConverse({
636+
...baseConstructorArgs,
637+
model: "anthropic.claude-3-haiku-20240307-v1:0",
638+
streaming: true,
639+
client: mockClient,
640+
});
641+
642+
await model.invoke([new HumanMessage("Hello")]);
643+
644+
expect(mockSend).toHaveBeenCalledTimes(1);
645+
646+
const commandArg = mockSend.mock.calls[0][0];
647+
expect(commandArg.input.modelId).toBe(
648+
"anthropic.claude-3-haiku-20240307-v1:0"
649+
);
650+
});
651+
});
652+
453653
describe("tool_choice works for supported models", () => {
454654
const tool = {
455655
name: "weather",

0 commit comments

Comments
 (0)