Commit 33009f1

Refactoring work to move to Azure.AI.OpenAI v2.1.0 (#328)
1 parent 80bccd7 commit 33009f1

5 files changed: +190 -255 lines changed

shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj (+4 -3)
@@ -21,9 +21,10 @@
   </PropertyGroup>

   <ItemGroup>
-    <PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.17" />
-    <PackageReference Include="Azure.Core" Version="1.39.0" />
-    <PackageReference Include="SharpToken" Version="2.0.3" />
+    <PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
+    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.1" />
+    <PackageReference Include="Microsoft.ML.Tokenizers.Data.O200kBase" Version="1.0.1" />
+    <PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />
   </ItemGroup>

   <ItemGroup>
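
For context, Azure.AI.OpenAI 2.x replaces the Azure.Core-based OpenAIClient with the System.ClientModel-based AzureOpenAIClient, and chat requests go through an OpenAI.Chat.ChatClient. A minimal construction sketch against the 2.1.0 API (the endpoint, key, and deployment name below are placeholders, not values from this commit):

    using System.ClientModel;
    using Azure.AI.OpenAI;
    using OpenAI.Chat;

    // Placeholder endpoint and key; the real values come from the agent's settings.
    AzureOpenAIClient client = new(
        new Uri("https://example.openai.azure.com"),
        new ApiKeyCredential("<api-key>"));

    // In 2.x, chat requests go through a ChatClient bound to a deployment name.
    ChatClient chatClient = client.GetChatClient("gpt-4o");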

shell/agents/AIShell.OpenAI.Agent/Agent.cs (+22 -14)
@@ -1,7 +1,8 @@
+using System.ClientModel;
 using System.Text;
 using System.Text.Json;
-using Azure.AI.OpenAI;
 using AIShell.Abstraction;
+using OpenAI.Chat;

 namespace AIShell.OpenAI.Agent;

@@ -106,37 +107,44 @@ public async Task<bool> ChatAsync(string input, IShell shell)
             return checkPass;
         }

-        string responseContent = null;
-        StreamingResponse<StreamingChatCompletionsUpdate> response = await host.RunWithSpinnerAsync(
-            () => _chatService.GetStreamingChatResponseAsync(input, token)
-        ).ConfigureAwait(false);
+        IAsyncEnumerator<StreamingChatCompletionUpdate> response = await host
+            .RunWithSpinnerAsync(
+                () => _chatService.GetStreamingChatResponseAsync(input, token)
+            ).ConfigureAwait(false);

         if (response is not null)
         {
+            StreamingChatCompletionUpdate update = null;
             using var streamingRender = host.NewStreamRender(token);

             try
             {
-                await foreach (StreamingChatCompletionsUpdate chatUpdate in response)
+                do
                 {
-                    if (string.IsNullOrEmpty(chatUpdate.ContentUpdate))
+                    update = response.Current;
+                    if (update.ContentUpdate.Count > 0)
                     {
-                        continue;
+                        streamingRender.Refresh(update.ContentUpdate[0].Text);
                     }
-
-                    streamingRender.Refresh(chatUpdate.ContentUpdate);
                 }
+                while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false));
             }
             catch (OperationCanceledException)
             {
-                // Ignore the cancellation exception.
+                update = null;
             }

-            responseContent = streamingRender.AccumulatedContent;
+            if (update is null)
+            {
+                _chatService.CalibrateChatHistory(usage: null, response: null);
+            }
+            else
+            {
+                string responseContent = streamingRender.AccumulatedContent;
+                _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent));
+            }
         }

-        _chatService.AddResponseToHistory(responseContent);
-
         return checkPass;
     }
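
Note that the new consuming loop reads response.Current before its first MoveNextAsync call, which implies GetStreamingChatResponseAsync advances the enumerator once and returns null when nothing streams back. A sketch of how such a method might be shaped with the 2.x streaming API (this body is an assumption for illustration, not the commit's ChatService code):

    using System.ClientModel;
    using OpenAI.Chat;

    // Hypothetical shape of the service call; the ChatService internals are not part of this diff.
    static async Task<IAsyncEnumerator<StreamingChatCompletionUpdate>> StartStreamingAsync(
        ChatClient chatClient, string input, CancellationToken token)
    {
        AsyncCollectionResult<StreamingChatCompletionUpdate> stream =
            chatClient.CompleteChatStreamingAsync([new UserChatMessage(input)], cancellationToken: token);

        // Advance once up front so the caller's do/while can read 'Current'
        // before its first MoveNextAsync; return null if the stream produced nothing.
        IAsyncEnumerator<StreamingChatCompletionUpdate> enumerator = stream.GetAsyncEnumerator(token);
        return await enumerator.MoveNextAsync().ConfigureAwait(false) ? enumerator : null;
    }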

shell/agents/AIShell.OpenAI.Agent/Helpers.cs (+14 -61)
@@ -3,10 +3,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Text.Json.Serialization.Metadata;
-
-using Azure;
-using Azure.Core;
-using Azure.Core.Pipeline;
+using System.ClientModel.Primitives;

 namespace AIShell.OpenAI.Agent;

@@ -134,69 +131,25 @@ public override JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions option
     }
 }

-#nullable enable
-
-/// <summary>
-/// Used for setting user key for the Azure.OpenAI.Client.
-/// </summary>
-internal sealed class UserKeyPolicy : HttpPipelineSynchronousPolicy
-{
-    private readonly string _name;
-    private readonly AzureKeyCredential _credential;
-
-    /// <summary>
-    /// Initializes a new instance of the <see cref="UserKeyPolicy"/> class.
-    /// </summary>
-    /// <param name="credential">The <see cref="AzureKeyCredential"/> used to authenticate requests.</param>
-    /// <param name="name">The name of the key header used for the credential.</param>
-    public UserKeyPolicy(AzureKeyCredential credential, string name)
-    {
-        ArgumentNullException.ThrowIfNull(credential);
-        ArgumentException.ThrowIfNullOrEmpty(name);
-
-        _credential = credential;
-        _name = name;
-    }
-
-    /// <inheritdoc/>
-    public override void OnSendingRequest(HttpMessage message)
-    {
-        base.OnSendingRequest(message);
-        message.Request.Headers.SetValue(_name, _credential.Key);
-    }
-}
-
 /// <summary>
-/// Used for configuring the retry policy for Azure.OpenAI.Client.
+/// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
 /// </summary>
-internal sealed class ChatRetryPolicy : RetryPolicy
+/// <param name="maxRetries">The maximum number of retries to attempt.</param>
+/// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
+internal sealed class ChatRetryPolicy(int maxRetries = 2) : ClientRetryPolicy(maxRetries)
 {
     private const string RetryAfterHeaderName = "Retry-After";
     private const string RetryAfterMsHeaderName = "retry-after-ms";
     private const string XRetryAfterMsHeaderName = "x-ms-retry-after-ms";

-    /// <summary>
-    /// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
-    /// </summary>
-    /// <param name="maxRetries">The maximum number of retries to attempt.</param>
-    /// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
-    public ChatRetryPolicy(int maxRetries = 2, DelayStrategy? delayStrategy = default) : base(
-        maxRetries,
-        delayStrategy ?? DelayStrategy.CreateExponentialDelayStrategy(
-            initialDelay: TimeSpan.FromSeconds(0.8),
-            maxDelay: TimeSpan.FromSeconds(5)))
-    {
-        // By default, we retry 2 times at most, and use a delay strategy that waits 5 seconds at most between retries.
-    }
-
-    protected override bool ShouldRetry(HttpMessage message, Exception? exception) => ShouldRetryImpl(message, exception);
-    protected override ValueTask<bool> ShouldRetryAsync(HttpMessage message, Exception? exception) => new(ShouldRetryImpl(message, exception));
+    protected override bool ShouldRetry(PipelineMessage message, Exception exception) => ShouldRetryImpl(message, exception);
+    protected override ValueTask<bool> ShouldRetryAsync(PipelineMessage message, Exception exception) => new(ShouldRetryImpl(message, exception));

-    private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
+    private bool ShouldRetryImpl(PipelineMessage message, Exception exception)
     {
         bool result = base.ShouldRetry(message, exception);

-        if (result && message.HasResponse)
+        if (result && message.Response is not null)
         {
             TimeSpan? retryAfter = GetRetryAfterHeaderValue(message.Response.Headers);
             if (retryAfter > TimeSpan.FromSeconds(5))
@@ -209,22 +162,22 @@ private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
         return result;
     }

-    private static TimeSpan? GetRetryAfterHeaderValue(ResponseHeaders headers)
+    private static TimeSpan? GetRetryAfterHeaderValue(PipelineResponseHeaders headers)
     {
         if (headers.TryGetValue(RetryAfterMsHeaderName, out var retryAfterValue) ||
             headers.TryGetValue(XRetryAfterMsHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInMS))
             {
-                return TimeSpan.FromMilliseconds(delaySeconds);
+                return TimeSpan.FromMilliseconds(delayInMS);
             }
         }

         if (headers.TryGetValue(RetryAfterHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInSec))
             {
-                return TimeSpan.FromSeconds(delaySeconds);
+                return TimeSpan.FromSeconds(delayInSec);
             }

             if (DateTimeOffset.TryParse(retryAfterValue, out DateTimeOffset delayTime))
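
With ChatRetryPolicy now deriving from System.ClientModel.Primitives.ClientRetryPolicy, it plugs into the client through the pipeline options rather than Azure.Core's ClientOptions. A rough wiring sketch (the endpoint and key are placeholders):

    using System.ClientModel;
    using Azure.AI.OpenAI;

    // ClientPipelineOptions.RetryPolicy is inherited by AzureOpenAIClientOptions,
    // so the custom policy replaces the default retry behavior for the whole pipeline.
    AzureOpenAIClientOptions options = new()
    {
        RetryPolicy = new ChatRetryPolicy(maxRetries: 2),
    };

    AzureOpenAIClient client = new(
        new Uri("https://example.openai.azure.com"),  // placeholder
        new ApiKeyCredential("<api-key>"),            // placeholder
        options);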

shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs (+12 -12)
@@ -1,18 +1,17 @@
-using SharpToken;
+using Microsoft.ML.Tokenizers;

 namespace AIShell.OpenAI.Agent;

 internal class ModelInfo
 {
     // Models gpt4, gpt3.5, and the variants of them all use the 'cl100k_base' token encoding.
-    // But the gpt-4o model uses the 'o200k_base' token encoding. For reference:
-    // https://github.com/openai/tiktoken/blob/5d970c1100d3210b42497203d6b5c1e30cfda6cb/tiktoken/model.py#L7
-    // https://github.com/dmitry-brazhenko/SharpToken/blob/main/SharpToken/Lib/Model.cs#L8
+    // But gpt-4o and o1 models use the 'o200k_base' token encoding. For reference:
+    // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py
     private const string Gpt4oEncoding = "o200k_base";
     private const string Gpt34Encoding = "cl100k_base";

     private static readonly Dictionary<string, ModelInfo> s_modelMap;
-    private static readonly Dictionary<string, Task<GptEncoding>> s_encodingMap;
+    private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;

     static ModelInfo()
     {
@@ -21,6 +20,7 @@ static ModelInfo()
         // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
         s_modelMap = new(StringComparer.OrdinalIgnoreCase)
         {
+            ["o1"] = new(tokenLimit: 200_000, encoding: Gpt4oEncoding),
             ["gpt-4o"] = new(tokenLimit: 128_000, encoding: Gpt4oEncoding),
             ["gpt-4"] = new(tokenLimit: 8_192),
             ["gpt-4-32k"] = new(tokenLimit: 32_768),
@@ -35,8 +35,8 @@ static ModelInfo()
         // we don't block the startup and the values will be ready when we really need them.
         s_encodingMap = new(StringComparer.OrdinalIgnoreCase)
         {
-            [Gpt34Encoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt34Encoding)),
-            [Gpt4oEncoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt4oEncoding))
+            [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)),
+            [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding))
         };
     }

@@ -45,24 +45,24 @@ private ModelInfo(int tokenLimit, string encoding = null)
         TokenLimit = tokenLimit;
         _encodingName = encoding ?? Gpt34Encoding;

-        // For gpt4 and gpt3.5-turbo, the following 2 properties are the same.
+        // For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same.
         // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
         TokensPerMessage = 3;
         TokensPerName = 1;
     }

     private readonly string _encodingName;
-    private GptEncoding _gptEncoding;
+    private Tokenizer _gptEncoding;

     internal int TokenLimit { get; }
     internal int TokensPerMessage { get; }
     internal int TokensPerName { get; }
-    internal GptEncoding Encoding
+    internal Tokenizer Encoding
     {
         get {
-            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<GptEncoding> value)
+            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<Tokenizer> value)
                 ? value.Result
-                : GptEncoding.GetEncoding(_encodingName);
+                : TiktokenTokenizer.CreateForEncoding(_encodingName);
             return _gptEncoding;
         }
     }
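
The SharpToken-to-Microsoft.ML.Tokenizers swap keeps the same tiktoken encodings; the new Microsoft.ML.Tokenizers.Data.O200kBase and .Cl100kBase packages carry the vocabulary data that CreateForEncoding loads. A small usage sketch (the sample string is illustrative):

    using Microsoft.ML.Tokenizers;

    // CountTokens replaces SharpToken's GptEncoding-based counting; CreateForEncoding
    // needs the matching Microsoft.ML.Tokenizers.Data.* package available at runtime.
    Tokenizer tokenizer = TiktokenTokenizer.CreateForEncoding("o200k_base");
    int tokens = tokenizer.CountTokens("How do I list files in PowerShell?");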
