Skip to content

Commit e28b5a7

Browse files
authored
Improive Embedding.ConvertToVectorOfFloats (#38)
It's currently incorrect on big endian systems (e.g. IBM s390x), as it blits the little-endian bytes into an array of floats. On big endian the values need to be reversed. It also assumes the data being sent back from the service is always correct. If it's corrupted in certain ways, such as not being as long as was expected, not being quoted, etc., we can silently get bad data. It also more allocation than is necessary. It first allocates a string for the base64 string data. Then it allocates a new string with the quotes stripped off. Then it allocates a byte[] with the decoded bytes. And finally it allocates the resulting float[]. We can avoid the initial string by getting the raw memory from the BinaryData. We can avoid the substring by just slicing those bytes. And we can generally avoid the temporary byte[] by renting one from the array pool. That leaves just the required float[].
1 parent efd76b5 commit e28b5a7

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

src/Custom/Embeddings/Embedding.cs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using System;
2+
using System.Buffers;
3+
using System.Buffers.Binary;
4+
using System.Buffers.Text;
25
using System.Collections.Generic;
6+
using System.Runtime.InteropServices;
37

48
namespace OpenAI.Embeddings;
59

@@ -89,13 +93,43 @@ internal Embedding(int index, BinaryData embeddingProperty, InternalEmbeddingObj
8993
// CUSTOM: Implemented custom logic to transform from BinaryData to ReadOnlyMemory<float>.
9094
private ReadOnlyMemory<float> ConvertToVectorOfFloats(BinaryData binaryData)
9195
{
92-
string base64EncodedVector = binaryData.ToString();
93-
base64EncodedVector = base64EncodedVector.Substring(1, base64EncodedVector.Length - 2);
96+
ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span;
9497

95-
byte[] bytes = Convert.FromBase64String(base64EncodedVector);
96-
float[] vector = new float[bytes.Length / sizeof(float)];
97-
Buffer.BlockCopy(bytes, 0, vector, 0, bytes.Length);
98+
// Remove quotes around base64 string.
99+
if (base64.Length < 2 || base64[0] != (byte)'"' || base64[base64.Length - 1] != (byte)'"')
100+
{
101+
ThrowInvalidData();
102+
}
103+
base64 = base64.Slice(1, base64.Length - 2);
98104

105+
// Decode base64 string to bytes.
106+
byte[] bytes = ArrayPool<byte>.Shared.Rent(Base64.GetMaxDecodedFromUtf8Length(base64.Length));
107+
OperationStatus status = Base64.DecodeFromUtf8(base64, bytes.AsSpan(), out int bytesConsumed, out int bytesWritten);
108+
if (status != OperationStatus.Done || bytesWritten % sizeof(float) != 0)
109+
{
110+
ThrowInvalidData();
111+
}
112+
113+
// Interpret bytes as floats
114+
float[] vector = new float[bytesWritten / sizeof(float)];
115+
bytes.AsSpan(0, bytesWritten).CopyTo(MemoryMarshal.AsBytes(vector.AsSpan()));
116+
if (!BitConverter.IsLittleEndian)
117+
{
118+
Span<int> ints = MemoryMarshal.Cast<float, int>(vector.AsSpan());
119+
#if NET8_0_OR_GREATER
120+
BinaryPrimitives.ReverseEndianness(ints, ints);
121+
#else
122+
for (int i = 0; i < ints.Length; i++)
123+
{
124+
ints[i] = BinaryPrimitives.ReverseEndianness(ints[i]);
125+
}
126+
#endif
127+
}
128+
129+
ArrayPool<byte>.Shared.Return(bytes);
99130
return new ReadOnlyMemory<float>(vector);
131+
132+
static void ThrowInvalidData() =>
133+
throw new FormatException("The input is not a valid Base64 string of encoded floats.");
100134
}
101135
}

0 commit comments

Comments
 (0)