Skip to content

Commit 641dc11

Browse files
committed
Improvements to EF type mapping
* Support scaffolding of vector types * Consolidated different mappings and plugins into the same files * A bit of test code cleanup Closes #44
1 parent 4de753b commit 641dc11

8 files changed

+94
-88
lines changed

src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs

-19
This file was deleted.

src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs

-11
This file was deleted.

src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs

-19
This file was deleted.

src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs

-11
This file was deleted.

src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs

-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ public void ApplyServices(IServiceCollection services)
1717
.TryAdd<IMethodCallTranslatorPlugin, VectorDbFunctionsTranslatorPlugin>();
1818

1919
services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
20-
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
21-
services.AddSingleton<IRelationalTypeMappingSourcePlugin, SparsevecTypeMappingSourcePlugin>();
2220
}
2321

2422
public void Validate(IDbContextOptions options) { }

src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs

+11-4
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,18 @@ namespace Pgvector.EntityFrameworkCore;
66

77
public class VectorTypeMapping : RelationalTypeMapping
88
{
9-
public static VectorTypeMapping Default { get; } = new();
9+
public static VectorTypeMapping Default { get; } = new("vector", typeof(Vector));
1010

11-
public VectorTypeMapping() : base("vector", typeof(Vector)) { }
12-
13-
public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
11+
public VectorTypeMapping(string storeType, Type clrType, int? size = null)
12+
: this(
13+
new RelationalTypeMappingParameters(
14+
new CoreTypeMappingParameters(clrType),
15+
storeType,
16+
StoreTypePostfix.Size,
17+
size: size,
18+
fixedLength: true))
19+
{
20+
}
1421

1522
protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
1623

src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs

+26-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,30 @@ namespace Pgvector.EntityFrameworkCore;
55
public class VectorTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
66
{
77
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
8-
=> mappingInfo.ClrType == typeof(Vector)
9-
? new VectorTypeMapping(mappingInfo.StoreTypeName ?? "vector")
10-
: null;
8+
{
9+
if (mappingInfo.StoreTypeName is not null)
10+
{
11+
VectorTypeMapping? mapping = (mappingInfo.StoreTypeNameBase ?? mappingInfo.StoreTypeName) switch
12+
{
13+
"vector" => new(mappingInfo.StoreTypeName, typeof(Vector), mappingInfo.Size),
14+
"halfvec" => new(mappingInfo.StoreTypeName, typeof(HalfVector), mappingInfo.Size),
15+
"sparsevec" => new(mappingInfo.StoreTypeName, typeof(SparseVector), mappingInfo.Size),
16+
_ => null,
17+
};
18+
19+
// If the caller hasn't specified a CLR type (this is scaffolding), or if the user has specified
20+
// the one matching the store type, return the mapping.
21+
return mappingInfo.ClrType is null || mappingInfo.ClrType == mapping?.ClrType
22+
? mapping : null;
23+
}
24+
25+
// No store type specified, look up by the CLR type only
26+
return mappingInfo.ClrType switch
27+
{
28+
var t when t == typeof(Vector) => new VectorTypeMapping("vector", typeof(Vector), mappingInfo.Size),
29+
var t when t == typeof(HalfVector) => new VectorTypeMapping("halfvec", typeof(HalfVector), mappingInfo.Size),
30+
var t when t == typeof(SparseVector) => new VectorTypeMapping("sparsevec", typeof(SparseVector), mappingInfo.Size),
31+
_ => null,
32+
};
33+
}
1134
}

tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs

+57-19
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using Microsoft.EntityFrameworkCore;
22
using Microsoft.EntityFrameworkCore.Infrastructure;
33
using Microsoft.EntityFrameworkCore.Storage;
4+
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal;
45
using Pgvector.EntityFrameworkCore;
56
using System.Collections;
67
using System.ComponentModel.DataAnnotations.Schema;
8+
using System.Security.AccessControl;
79

810
namespace Pgvector.Tests;
911

@@ -66,79 +68,115 @@ public async Task Main()
6668

6769
var embedding = new Vector(new float[] { 1, 1, 1 });
6870
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
69-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
70-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
71-
Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());
71+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
72+
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
73+
Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray());
7274
Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
73-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray());
75+
Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray());
7476

7577
// vector distance functions
7678

7779
items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
78-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
79-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
80+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
81+
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
8082

8183
items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
82-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
84+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
8385

8486
items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
8587
Assert.Equal(3, items[2].Id);
8688

8789
items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
88-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
90+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
8991

9092
// halfvec distance functions
9193

9294
var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
9395
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
94-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
96+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
9597

9698
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
97-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
99+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
98100

99101
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
100102
Assert.Equal(3, items[2].Id);
101103

102104
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
103-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
105+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
104106

105107
// sparsevec distance functions
106108

107109
var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
108110
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
109-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
111+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
110112

111113
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
112-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
114+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
113115

114116
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
115117
Assert.Equal(3, items[2].Id);
116118

117119
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
118-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
120+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
119121

120122
// bit distance functions
121123

122124
var binaryEmbedding = new BitArray(new bool[] { true, false, true });
123125
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
124-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
126+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
125127

126128
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
127-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
129+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
128130

129131
// additional
130132

131133
items = await ctx.Items
132134
.OrderBy(x => x.Id)
133135
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
134136
.ToListAsync();
135-
Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray());
137+
Assert.Equal([1, 3], items.Select(v => v.Id).ToArray());
136138

137139
var neighbors = await ctx.Items
138140
.OrderBy(x => x.Embedding!.L2Distance(embedding))
139141
.Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
140142
.ToListAsync();
141-
Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray());
142-
Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray());
143+
Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray());
144+
Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => v.Distance).ToArray());
145+
}
146+
147+
[Theory]
148+
[InlineData(typeof(Vector), null, "vector")]
149+
[InlineData(typeof(Vector), 3, "vector(3)")]
150+
[InlineData(typeof(HalfVector), null, "halfvec")]
151+
[InlineData(typeof(HalfVector), 3, "halfvec(3)")]
152+
[InlineData(typeof(SparseVector), null, "sparsevec")]
153+
[InlineData(typeof(SparseVector), 3, "sparsevec(3)")]
154+
public void By_StoreType(Type type, int? size, string expectedStoreType)
155+
{
156+
using var ctx = new ItemContext();
157+
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();
158+
159+
var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!;
160+
Assert.Equal(expectedStoreType, typeMapping.StoreType);
161+
Assert.Same(type, typeMapping.ClrType);
162+
Assert.Equal(size, typeMapping.Size);
163+
}
164+
165+
[Theory]
166+
[InlineData("vector", typeof(Vector), null)]
167+
[InlineData("vector(3)", typeof(Vector), 3)]
168+
[InlineData("halfvec", typeof(HalfVector), null)]
169+
[InlineData("halfvec(3)", typeof(HalfVector), 3)]
170+
[InlineData("sparsevec", typeof(SparseVector), null)]
171+
[InlineData("sparsevec(3)", typeof(SparseVector), 3)]
172+
public void By_ClrType(string storeType, Type expectedType, int? expectedSize)
173+
{
174+
using var ctx = new ItemContext();
175+
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();
176+
177+
var typeMapping = typeMappingSource.FindMapping(storeType)!;
178+
Assert.Equal(storeType, typeMapping.StoreType);
179+
Assert.Same(expectedType, typeMapping.ClrType);
180+
Assert.Equal(expectedSize, typeMapping.Size);
143181
}
144182
}

0 commit comments

Comments
 (0)