|
1 | 1 | using Microsoft.EntityFrameworkCore;
|
2 | 2 | using Microsoft.EntityFrameworkCore.Infrastructure;
|
3 | 3 | using Microsoft.EntityFrameworkCore.Storage;
|
| 4 | +using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal; |
4 | 5 | using Pgvector.EntityFrameworkCore;
|
5 | 6 | using System.Collections;
|
6 | 7 | using System.ComponentModel.DataAnnotations.Schema;
|
| 8 | +using System.Security.AccessControl; |
7 | 9 |
|
8 | 10 | namespace Pgvector.Tests;
|
9 | 11 |
|
@@ -66,79 +68,115 @@ public async Task Main()
|
66 | 68 |
|
67 | 69 | var embedding = new Vector(new float[] { 1, 1, 1 });
|
68 | 70 | var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
|
69 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
70 |
| - Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); |
71 |
| - Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); |
| 71 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
| 72 | + Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray()); |
| 73 | + Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray()); |
72 | 74 | Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
|
73 |
| - Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray()); |
| 75 | + Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray()); |
74 | 76 |
|
75 | 77 | // vector distance functions
|
76 | 78 |
|
77 | 79 | items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
|
78 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
79 |
| - Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); |
| 80 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
| 81 | + Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray()); |
80 | 82 |
|
81 | 83 | items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
|
82 |
| - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); |
| 84 | + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); |
83 | 85 |
|
84 | 86 | items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
|
85 | 87 | Assert.Equal(3, items[2].Id);
|
86 | 88 |
|
87 | 89 | items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
|
88 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
| 90 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
89 | 91 |
|
90 | 92 | // halfvec distance functions
|
91 | 93 |
|
92 | 94 | var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
|
93 | 95 | items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
|
94 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
| 96 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
95 | 97 |
|
96 | 98 | items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
|
97 |
| - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); |
| 99 | + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); |
98 | 100 |
|
99 | 101 | items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
|
100 | 102 | Assert.Equal(3, items[2].Id);
|
101 | 103 |
|
102 | 104 | items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
|
103 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
| 105 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
104 | 106 |
|
105 | 107 | // sparsevec distance functions
|
106 | 108 |
|
107 | 109 | var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
|
108 | 110 | items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
|
109 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
| 111 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
110 | 112 |
|
111 | 113 | items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
|
112 |
| - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); |
| 114 | + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); |
113 | 115 |
|
114 | 116 | items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
|
115 | 117 | Assert.Equal(3, items[2].Id);
|
116 | 118 |
|
117 | 119 | items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
|
118 |
| - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); |
| 120 | + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); |
119 | 121 |
|
120 | 122 | // bit distance functions
|
121 | 123 |
|
122 | 124 | var binaryEmbedding = new BitArray(new bool[] { true, false, true });
|
123 | 125 | items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
|
124 |
| - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); |
| 126 | + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); |
125 | 127 |
|
126 | 128 | items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
|
127 |
| - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); |
| 129 | + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); |
128 | 130 |
|
129 | 131 | // additional
|
130 | 132 |
|
131 | 133 | items = await ctx.Items
|
132 | 134 | .OrderBy(x => x.Id)
|
133 | 135 | .Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
|
134 | 136 | .ToListAsync();
|
135 |
| - Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray()); |
| 137 | + Assert.Equal([1, 3], items.Select(v => v.Id).ToArray()); |
136 | 138 |
|
137 | 139 | var neighbors = await ctx.Items
|
138 | 140 | .OrderBy(x => x.Embedding!.L2Distance(embedding))
|
139 | 141 | .Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
|
140 | 142 | .ToListAsync();
|
141 |
| - Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray()); |
142 |
| - Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray()); |
| 143 | + Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray()); |
| 144 | + Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => v.Distance).ToArray()); |
| 145 | + } |
| 146 | + |
| 147 | + [Theory] |
| 148 | + [InlineData(typeof(Vector), null, "vector")] |
| 149 | + [InlineData(typeof(Vector), 3, "vector(3)")] |
| 150 | + [InlineData(typeof(HalfVector), null, "halfvec")] |
| 151 | + [InlineData(typeof(HalfVector), 3, "halfvec(3)")] |
| 152 | + [InlineData(typeof(SparseVector), null, "sparsevec")] |
| 153 | + [InlineData(typeof(SparseVector), 3, "sparsevec(3)")] |
| 154 | + public void By_StoreType(Type type, int? size, string expectedStoreType) |
| 155 | + { |
| 156 | + using var ctx = new ItemContext(); |
| 157 | + var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>(); |
| 158 | + |
| 159 | + var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!; |
| 160 | + Assert.Equal(expectedStoreType, typeMapping.StoreType); |
| 161 | + Assert.Same(type, typeMapping.ClrType); |
| 162 | + Assert.Equal(size, typeMapping.Size); |
| 163 | + } |
| 164 | + |
| 165 | + [Theory] |
| 166 | + [InlineData("vector", typeof(Vector), null)] |
| 167 | + [InlineData("vector(3)", typeof(Vector), 3)] |
| 168 | + [InlineData("halfvec", typeof(HalfVector), null)] |
| 169 | + [InlineData("halfvec(3)", typeof(HalfVector), 3)] |
| 170 | + [InlineData("sparsevec", typeof(SparseVector), null)] |
| 171 | + [InlineData("sparsevec(3)", typeof(SparseVector), 3)] |
| 172 | + public void By_ClrType(string storeType, Type expectedType, int? expectedSize) |
| 173 | + { |
| 174 | + using var ctx = new ItemContext(); |
| 175 | + var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>(); |
| 176 | + |
| 177 | + var typeMapping = typeMappingSource.FindMapping(storeType)!; |
| 178 | + Assert.Equal(storeType, typeMapping.StoreType); |
| 179 | + Assert.Same(expectedType, typeMapping.ClrType); |
| 180 | + Assert.Equal(expectedSize, typeMapping.Size); |
143 | 181 | }
|
144 | 182 | }
|
0 commit comments