diff --git a/.changeset/wet-taxis-heal.md b/.changeset/wet-taxis-heal.md
new file mode 100644
index 000000000000..fc4ca1d9b669
--- /dev/null
+++ b/.changeset/wet-taxis-heal.md
@@ -0,0 +1,5 @@
+---
+"@langchain/aws": minor
+---
+
+feat(aws): allow Bedrock Application Inference Profile
diff --git a/libs/providers/langchain-aws/src/chat_models.ts b/libs/providers/langchain-aws/src/chat_models.ts
index 43a1e99a1cac..d2604b6d1189 100644
--- a/libs/providers/langchain-aws/src/chat_models.ts
+++ b/libs/providers/langchain-aws/src/chat_models.ts
@@ -98,6 +98,15 @@ export interface ChatBedrockConverseInput
    */
   model?: string;
 
+  /**
+   * ARN of an application inference profile to use in place of the model ID, e.g.
+   * "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx". When set, it overrides `this.model` as the modelId in the outgoing Converse request.
+   * You must still provide `model` with the regular model ID so that model-specific metadata is resolved correctly.
+   * See the link below for details on creating and using application inference profiles.
+   * @link https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
+   */
+  applicationInferenceProfile?: string;
+
   /**
    * The AWS region e.g. `us-west-2`.
    * Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -664,6 +673,8 @@ export class ChatBedrockConverse
 
   model = "anthropic.claude-3-haiku-20240307-v1:0";
 
+  applicationInferenceProfile?: string | undefined = undefined;
+
   streaming = false;
 
   region: string;
@@ -745,6 +756,7 @@ export class ChatBedrockConverse
     this.region = region;
 
     this.model = rest?.model ?? this.model;
+    this.applicationInferenceProfile = rest?.applicationInferenceProfile;
     this.streaming = rest?.streaming ?? this.streaming;
     this.temperature = rest?.temperature;
     this.maxTokens = rest?.maxTokens;
@@ -866,7 +878,7 @@ export class ChatBedrockConverse
     const params = this.invocationParams(options);
 
     const command = new ConverseCommand({
-      modelId: this.model,
+      modelId: this.applicationInferenceProfile ?? this.model,
       messages: converseMessages,
       system: converseSystem,
       requestMetadata: options.requestMetadata,
@@ -907,7 +919,7 @@ export class ChatBedrockConverse
       streamUsage = options.streamUsage;
     }
     const command = new ConverseStreamCommand({
-      modelId: this.model,
+      modelId: this.applicationInferenceProfile ?? this.model,
       messages: converseMessages,
       system: converseSystem,
       requestMetadata: options.requestMetadata,
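For reviewers, a minimal usage sketch of the new option (the ARN, region, and prompt are placeholders; the constructor arguments mirror those exercised by the tests below):

```ts
import { ChatBedrockConverse } from "@langchain/aws";

// Keep `model` set to the underlying model ID so model-specific metadata
// still resolves; `applicationInferenceProfile` only swaps the modelId
// that is sent on the request.
const model = new ChatBedrockConverse({
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  applicationInferenceProfile:
    "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/example-profile",
  region: "eu-west-1",
});

// Both the invoke (ConverseCommand) and streaming (ConverseStreamCommand)
// paths use the profile ARN as the modelId.
const response = await model.invoke("Hello!");
console.log(response.content);
```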
diff --git a/libs/providers/langchain-aws/src/tests/chat_models.test.ts b/libs/providers/langchain-aws/src/tests/chat_models.test.ts
index aeeb0427eb75..2990724c0926 100644
--- a/libs/providers/langchain-aws/src/tests/chat_models.test.ts
+++ b/libs/providers/langchain-aws/src/tests/chat_models.test.ts
@@ -9,11 +9,12 @@ import {
 import { concat } from "@langchain/core/utils/stream";
 import {
   ConversationRole as BedrockConversationRole,
+  BedrockRuntimeClient,
   type Message as BedrockMessage,
   type SystemContentBlock as BedrockSystemContentBlock,
 } from "@aws-sdk/client-bedrock-runtime";
 import { z } from "zod/v3";
-import { describe, expect, test, it } from "vitest";
+import { describe, expect, test, it, vi } from "vitest";
 import { convertToConverseMessages } from "../utils/message_inputs.js";
 import { handleConverseStreamContentBlockDelta } from "../utils/message_outputs.js";
 import { ChatBedrockConverse } from "../chat_models.js";
@@ -451,6 +452,206 @@ test("Streaming supports empty string chunks", async () => {
   expect(finalChunk.content).toBe("Hello world!");
 });
 
+describe("applicationInferenceProfile parameter", () => {
+  const baseConstructorArgs = {
+    region: "us-east-1",
+    credentials: {
+      secretAccessKey: "test-secret-key",
+      accessKeyId: "test-access-key",
+    },
+  };
+
+  it("should initialize applicationInferenceProfile from constructor", () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+    });
+    expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
+    expect(model.applicationInferenceProfile).toBe(testArn);
+  });
+
+  it("should be undefined when not provided in constructor", () => {
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+    });
+
+    expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
+    expect(model.applicationInferenceProfile).toBeUndefined();
+  });
+
+  it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const mockSend = vi.fn().mockResolvedValue({
+      output: {
+        message: {
+          role: "assistant",
+          content: [{ text: "Test response" }],
+        },
+      },
+      stopReason: "end_turn",
+      usage: {
+        inputTokens: 10,
+        outputTokens: 5,
+        totalTokens: 15,
+      },
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    // Verify that send was called
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    // Verify that the command was created with applicationInferenceProfile as modelId
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(testArn);
+    expect(commandArg.input.modelId).not.toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
+    const mockSend = vi.fn().mockResolvedValue({
+      output: {
+        message: {
+          role: "assistant",
+          content: [{ text: "Test response" }],
+        },
+      },
+      stopReason: "end_turn",
+      usage: {
+        inputTokens: 10,
+        outputTokens: 5,
+        totalTokens: 15,
+      },
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    // Verify that send was called
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    // Verify that the command was created with model as modelId
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
+    const testArn =
+      "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
+    const mockSend = vi.fn().mockResolvedValue({
+      stream: (async function* () {
+        yield {
+          contentBlockDelta: {
+            contentBlockIndex: 0,
+            delta: { text: "Test" },
+          },
+        };
+        yield {
+          metadata: {
+            usage: {
+              inputTokens: 10,
+              outputTokens: 5,
+              totalTokens: 15,
+            },
+          },
+        };
+      })(),
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      applicationInferenceProfile: testArn,
+      streaming: true,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(testArn);
+    expect(commandArg.input.modelId).not.toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+
+  it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
+    const mockSend = vi.fn().mockResolvedValue({
+      stream: (async function* () {
+        yield {
+          contentBlockDelta: {
+            contentBlockIndex: 0,
+            delta: { text: "Test" },
+          },
+        };
+        yield {
+          metadata: {
+            usage: {
+              inputTokens: 10,
+              outputTokens: 5,
+              totalTokens: 15,
+            },
+          },
+        };
+      })(),
+    });
+
+    const mockClient = {
+      send: mockSend,
+    } as unknown as BedrockRuntimeClient;
+
+    const model = new ChatBedrockConverse({
+      ...baseConstructorArgs,
+      model: "anthropic.claude-3-haiku-20240307-v1:0",
+      streaming: true,
+      client: mockClient,
+    });
+
+    await model.invoke([new HumanMessage("Hello")]);
+
+    expect(mockSend).toHaveBeenCalledTimes(1);
+
+    const commandArg = mockSend.mock.calls[0][0];
+    expect(commandArg.input.modelId).toBe(
+      "anthropic.claude-3-haiku-20240307-v1:0"
+    );
+  });
+});
+
 describe("tool_choice works for supported models", () => {
   const tool = {
     name: "weather",