Commit 2ed02b2

feat(aws): allow bedrock Application Inference Profile
Implements #7809 and ports #7822 into the @langchain/aws library
1 parent 9e6eb91 commit 2ed02b2
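
For context, this change lets callers route Bedrock Converse requests through an application inference profile. A minimal usage sketch, assuming the constructor options added in this diff (the ARN and region are illustrative, not from the commit):

import { ChatBedrockConverse } from "@langchain/aws";

// Per the new JSDoc, still pass the underlying model id so the integration's
// metadata stays accurate; the profile ARN is what is actually sent as modelId.
const llm = new ChatBedrockConverse({
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  region: "eu-west-1",
  applicationInferenceProfile:
    "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/example-profile",
});

const response = await llm.invoke("Hello!");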

3 files changed, +221 -3 lines changed

.changeset/wet-taxis-heal.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"@langchain/aws": minor
+---
+
+feat(aws): allow bedrock Application Inference Profile

libs/providers/langchain-aws/src/chat_models.ts

Lines changed: 14 additions & 2 deletions
@@ -98,6 +98,15 @@ export interface ChatBedrockConverseInput
    */
   model?: string;
 
+  /**
+   * Application Inference Profile ARN to use for the model.
+   * For example, "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx"; when set, this overrides `this.model` as the modelId in the final /invoke URL call.
+   * Still provide `model` as the normal modelId to benefit from all the metadata.
+   * See the link below for more details on creating and using application inference profiles.
+   * @link https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
+   */
+  applicationInferenceProfile?: string;
+
   /**
    * The AWS region e.g. `us-west-2`.
    * Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -664,6 +673,8 @@ export class ChatBedrockConverse
 
   model = "anthropic.claude-3-haiku-20240307-v1:0";
 
+  applicationInferenceProfile?: string | undefined = undefined;
+
   streaming = false;
 
   region: string;
@@ -745,6 +756,7 @@ export class ChatBedrockConverse
 
     this.region = region;
     this.model = rest?.model ?? this.model;
+    this.applicationInferenceProfile = rest?.applicationInferenceProfile;
     this.streaming = rest?.streaming ?? this.streaming;
     this.temperature = rest?.temperature;
     this.maxTokens = rest?.maxTokens;
@@ -866,7 +878,7 @@ export class ChatBedrockConverse
     const params = this.invocationParams(options);
 
     const command = new ConverseCommand({
-      modelId: this.model,
+      modelId: this.applicationInferenceProfile ?? this.model,
       messages: converseMessages,
       system: converseSystem,
       requestMetadata: options.requestMetadata,
@@ -907,7 +919,7 @@ export class ChatBedrockConverse
       streamUsage = options.streamUsage;
     }
     const command = new ConverseStreamCommand({
-      modelId: this.model,
+      modelId: this.applicationInferenceProfile ?? this.model,
      messages: converseMessages,
      system: converseSystem,
      requestMetadata: options.requestMetadata,
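
Both command constructions share the same one-line resolution. A standalone sketch of that semantics; `resolveModelId` is a hypothetical helper for illustration, not part of the diff:

// Mirrors `this.applicationInferenceProfile ?? this.model` above.
// `??` falls back only when the left side is null or undefined.
const resolveModelId = (
  applicationInferenceProfile: string | undefined,
  model: string
): string => applicationInferenceProfile ?? model;

resolveModelId(undefined, "anthropic.claude-3-haiku-20240307-v1:0");
// -> "anthropic.claude-3-haiku-20240307-v1:0"

resolveModelId(
  "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/p1",
  "anthropic.claude-3-haiku-20240307-v1:0"
);
// -> the profile ARN
// Caveat: an empty string is not nullish, so "" would be sent verbatim.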

libs/providers/langchain-aws/src/tests/chat_models.test.ts

Lines changed: 202 additions & 1 deletion
@@ -9,11 +9,12 @@ import {
 import { concat } from "@langchain/core/utils/stream";
 import {
   ConversationRole as BedrockConversationRole,
+  BedrockRuntimeClient,
   type Message as BedrockMessage,
   type SystemContentBlock as BedrockSystemContentBlock,
 } from "@aws-sdk/client-bedrock-runtime";
 import { z } from "zod/v3";
-import { describe, expect, test, it } from "vitest";
+import { describe, expect, test, it, vi } from "vitest";
 import { convertToConverseMessages } from "../utils/message_inputs.js";
 import { handleConverseStreamContentBlockDelta } from "../utils/message_outputs.js";
 import { ChatBedrockConverse } from "../chat_models.js";
@@ -451,6 +452,206 @@ test("Streaming supports empty string chunks", async () => {
   expect(finalChunk.content).toBe("Hello world!");
 });
 
+describe("applicationInferenceProfile parameter", () => {
+  const baseConstructorArgs = {
+    region: "us-east-1",
+    credentials: {
+      secretAccessKey: "test-secret-key",
+      accessKeyId: "test-access-key",
+    },
+  };
+
+  it("should initialize applicationInferenceProfile from constructor", () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+    });
+    expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
+    expect(model.applicationInferenceProfile).toBe(testArn);
+  });
+
+  it("should be undefined when not provided in constructor", () => {
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+    });
+
+    expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
+    expect(model.applicationInferenceProfile).toBeUndefined();
+  });
+
+  it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const mockSend = vi.fn().mockResolvedValue({
+      output: {
+        message: {
+          role: "assistant",
+          content: [{ text: "Test response" }],
+        },
+      },
+      stopReason: "end_turn",
+      usage: {
+        inputTokens: 10,
+        outputTokens: 5,
+        totalTokens: 15,
+      },
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    // Verify that send was called
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    // Verify that the command was created with applicationInferenceProfile as modelId
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(testArn);
+    expect(commandArg.input.modelId).not.toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
+    const mockSend = vi.fn().mockResolvedValue({
+      output: {
+        message: {
+          role: "assistant",
+          content: [{ text: "Test response" }],
+        },
+      },
+      stopReason: "end_turn",
+      usage: {
+        inputTokens: 10,
+        outputTokens: 5,
+        totalTokens: 15,
+      },
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    // Verify that send was called
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    // Verify that the command was created with model as modelId
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const mockSend = vi.fn().mockResolvedValue({
+      stream: (async function* () {
+        yield {
+          contentBlockDelta: {
+            contentBlockIndex: 0,
+            delta: { text: "Test" },
+          },
+        };
+        yield {
+          metadata: {
+            usage: {
+              inputTokens: 10,
+              outputTokens: 5,
+              totalTokens: 15,
+            },
+          },
+        };
+      })(),
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+      streaming: true,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(testArn);
+    expect(commandArg.input.modelId).not.toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
+    const mockSend = vi.fn().mockResolvedValue({
+      stream: (async function* () {
+        yield {
+          contentBlockDelta: {
+            contentBlockIndex: 0,
+            delta: { text: "Test" },
+          },
+        };
+        yield {
+          metadata: {
+            usage: {
+              inputTokens: 10,
+              outputTokens: 5,
+              totalTokens: 15,
+            },
+          },
+        };
+      })(),
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      streaming: true,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+});
+
 describe("tool_choice works for supported models", () => {
   const tool = {
     name: "weather",
