5 changes: 5 additions & 0 deletions .changeset/wet-taxis-heal.md
@@ -0,0 +1,5 @@
---
"@langchain/aws": minor
---

feat(aws): allow bedrock Application Inference Profile
27 changes: 27 additions & 0 deletions libs/providers/langchain-aws/README.md
@@ -71,6 +71,33 @@ const model = new ChatBedrockConverse({
const response = await model.invoke(new HumanMessage("Hello world!"));
```

### Using Application Inference Profiles

AWS Bedrock [Application Inference Profiles](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html) let you define custom model endpoints, for example to track usage and costs per application or to route requests across regions.
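
If you don't already have a profile, you can create one with the Bedrock control-plane SDK. The sketch below is illustrative and assumes the `CreateInferenceProfileCommand` from `@aws-sdk/client-bedrock`; the profile name and source model ARN are placeholders:

```typescript
import {
  BedrockClient,
  CreateInferenceProfileCommand,
} from "@aws-sdk/client-bedrock";

const bedrock = new BedrockClient({ region: "us-east-1" });

// Wrap a foundation model (or a system-defined cross-region inference
// profile) in an application inference profile. The returned ARN is what
// you pass to `applicationInferenceProfile` below.
const { inferenceProfileArn } = await bedrock.send(
  new CreateInferenceProfileCommand({
    inferenceProfileName: "my-app-profile", // placeholder name
    modelSource: {
      copyFrom:
        "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-haiku-20240307-v1:0",
    },
  })
);
```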

You can use an inference profile ARN by passing it to the `applicationInferenceProfile` parameter. When provided, this ARN will be used for the actual inference calls instead of the model ID:

```typescript
import { ChatBedrockConverse } from "@langchain/aws";
import { HumanMessage } from "@langchain/core/messages";

const model = new ChatBedrockConverse({
region: process.env.BEDROCK_AWS_REGION ?? "us-east-1",
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile:
"arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx",
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID,
},
});

const response = await model.invoke(new HumanMessage("Hello world!"));
```

**Important:** You must still provide the `model` parameter with the actual model ID (e.g., `"anthropic.claude-3-haiku-20240307-v1:0"`), even when using an inference profile. This ensures proper metadata tracking in tools like LangSmith, including accurate cost and latency measurements per model. The `applicationInferenceProfile` ARN will override the model ID only for the actual inference API calls.
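
In other words, the `modelId` sent to Bedrock resolves as a simple fallback; this mirrors how `ChatBedrockConverse` builds its `ConverseCommand` and `ConverseStreamCommand` in the diff below:

```typescript
// modelId used for the Converse / ConverseStream API calls:
const modelId = applicationInferenceProfile ?? model;
// Metadata (e.g. LangSmith traces and cost tracking) continues to use `model`.
```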

> **Note:** AWS does not currently provide an API to programmatically retrieve the underlying model from an inference profile ARN, so it's the user's responsibility to ensure the `model` parameter matches the model configured in the inference profile.

### Streaming

```typescript
16 changes: 14 additions & 2 deletions libs/providers/langchain-aws/src/chat_models.ts
@@ -98,6 +98,15 @@ export interface ChatBedrockConverseInput
*/
model?: string;

/**
 * Application inference profile ARN to use for inference calls.
 * When provided (e.g. "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx"),
 * this ARN overrides `this.model` as the `modelId` sent to Bedrock.
 * You must still set `model` to the underlying model ID so that metadata
 * (e.g. per-model cost and latency tracking) remains accurate.
 * See the link below for details on creating and using application inference profiles.
 * @link https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
 */
applicationInferenceProfile?: string;

/**
* The AWS region e.g. `us-west-2`.
* Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -664,6 +673,8 @@ export class ChatBedrockConverse

model = "anthropic.claude-3-haiku-20240307-v1:0";

applicationInferenceProfile?: string;

streaming = false;

region: string;
@@ -745,6 +756,7 @@ export class ChatBedrockConverse

this.region = region;
this.model = rest?.model ?? this.model;
this.applicationInferenceProfile = rest?.applicationInferenceProfile;
this.streaming = rest?.streaming ?? this.streaming;
this.temperature = rest?.temperature;
this.maxTokens = rest?.maxTokens;
@@ -866,7 +878,7 @@
const params = this.invocationParams(options);

const command = new ConverseCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
@@ -907,7 +919,7 @@
streamUsage = options.streamUsage;
}
const command = new ConverseStreamCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
203 changes: 202 additions & 1 deletion libs/providers/langchain-aws/src/tests/chat_models.test.ts
@@ -9,11 +9,12 @@ import {
import { concat } from "@langchain/core/utils/stream";
import {
ConversationRole as BedrockConversationRole,
BedrockRuntimeClient,
type Message as BedrockMessage,
type SystemContentBlock as BedrockSystemContentBlock,
} from "@aws-sdk/client-bedrock-runtime";
import { z } from "zod/v3";
import { describe, expect, test, it } from "vitest";
import { describe, expect, test, it, vi } from "vitest";
import { convertToConverseMessages } from "../utils/message_inputs.js";
import { handleConverseStreamContentBlockDelta } from "../utils/message_outputs.js";
import { ChatBedrockConverse } from "../chat_models.js";
@@ -451,6 +452,206 @@ test("Streaming supports empty string chunks", async () => {
expect(finalChunk.content).toBe("Hello world!");
});

describe("applicationInferenceProfile parameter", () => {
const baseConstructorArgs = {
region: "us-east-1",
credentials: {
secretAccessKey: "test-secret-key",
accessKeyId: "test-access-key",
},
};

it("should initialize applicationInferenceProfile from constructor", () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
});
expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBe(testArn);
});

it("should be undefined when not provided in constructor", () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
});

expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBeUndefined();
});

it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with applicationInferenceProfile as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with model as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});
});

describe("tool_choice works for supported models", () => {
const tool = {
name: "weather",