5 changes: 5 additions & 0 deletions .changeset/wet-taxis-heal.md
@@ -0,0 +1,5 @@
---
"@langchain/aws": minor
---

feat(aws): allow bedrock Application Inference Profile
16 changes: 14 additions & 2 deletions libs/providers/langchain-aws/src/chat_models.ts
@@ -98,6 +98,15 @@ export interface ChatBedrockConverseInput
*/
model?: string;

/**
* ARN of an Application Inference Profile to use for the model, e.g.
* "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx".
* When set, this ARN overrides `model` as the `modelId` sent in the final invoke call.
* `model` must still be set to a regular model ID so that all model metadata remains available.
* See the link below for details on creating and using application inference profiles.
* @see https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
*/
applicationInferenceProfile?: string;

/**
* The AWS region e.g. `us-west-2`.
* Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -664,6 +673,8 @@ export class ChatBedrockConverse

model = "anthropic.claude-3-haiku-20240307-v1:0";

applicationInferenceProfile?: string | undefined = undefined;

streaming = false;

region: string;
@@ -745,6 +756,7 @@ export class ChatBedrockConverse

this.region = region;
this.model = rest?.model ?? this.model;
this.applicationInferenceProfile = rest?.applicationInferenceProfile;
this.streaming = rest?.streaming ?? this.streaming;
this.temperature = rest?.temperature;
this.maxTokens = rest?.maxTokens;
@@ -866,7 +878,7 @@ export class ChatBedrockConverse
const params = this.invocationParams(options);

const command = new ConverseCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
@@ -907,7 +919,7 @@ export class ChatBedrockConverse
streamUsage = options.streamUsage;
}
const command = new ConverseStreamCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
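With these changes, a caller can route Converse requests through an application inference profile while keeping a regular model ID for metadata resolution. A minimal usage sketch (the region, account ID, and profile name in the ARN below are placeholders):

```ts
import { ChatBedrockConverse } from "@langchain/aws";

// `model` stays a regular model ID so model metadata keeps resolving;
// the profile ARN is what is actually sent as `modelId` in the request.
const model = new ChatBedrockConverse({
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  region: "eu-west-1",
  applicationInferenceProfile:
    "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/example-profile",
});

const response = await model.invoke("Hello!");
```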
203 changes: 202 additions & 1 deletion libs/providers/langchain-aws/src/tests/chat_models.test.ts
@@ -9,11 +9,12 @@ import {
import { concat } from "@langchain/core/utils/stream";
import {
ConversationRole as BedrockConversationRole,
BedrockRuntimeClient,
type Message as BedrockMessage,
type SystemContentBlock as BedrockSystemContentBlock,
} from "@aws-sdk/client-bedrock-runtime";
import { z } from "zod/v3";
import { describe, expect, test, it } from "vitest";
import { describe, expect, test, it, vi } from "vitest";
import { convertToConverseMessages } from "../utils/message_inputs.js";
import { handleConverseStreamContentBlockDelta } from "../utils/message_outputs.js";
import { ChatBedrockConverse } from "../chat_models.js";
@@ -451,6 +452,206 @@ test("Streaming supports empty string chunks", async () => {
expect(finalChunk.content).toBe("Hello world!");
});

describe("applicationInferenceProfile parameter", () => {
const baseConstructorArgs = {
region: "us-east-1",
credentials: {
secretAccessKey: "test-secret-key",
accessKeyId: "test-access-key",
},
};

it("should initialize applicationInferenceProfile from constructor", () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
});
expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBe(testArn);
});

it("should be undefined when not provided in constructor", () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
});

expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBeUndefined();
});

it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with applicationInferenceProfile as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with model as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});
});

describe("tool_choice works for supported models", () => {
const tool = {
name: "weather",