66#include "access/ivfflat.h"
77#include "miscadmin.h"
88
9+ static int
10+ GtypeCompareVectors (const void * a , const void * b );
11+
912/*
1013 * Initialize with kmeans++
1114 *
@@ -29,7 +32,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
2932 collation = index -> rd_indcollation [0 ];
3033
3134 // Choose an initial center uniformly at random
32- VectorArraySet (centers , 0 , VectorArrayGet ( samples , RandomInt () % samples -> length ) );
35+ VectorArraySet (centers , 0 , & samples -> items [ RandomInt () % samples -> length ] );
3336 centers -> length ++ ;
3437
3538 for (j = 0 ; j < numSamples ; j ++ )
@@ -41,11 +44,11 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
4144 sum = 0.0 ;
4245
4346 for (j = 0 ; j < numSamples ; j ++ ) {
44- vec = VectorArrayGet ( samples , j ) ;
45-
47+ vec = & samples -> items [ j ] ;
48+
4649 // Only need to compute distance for new center
4750 // TODO Use triangle inequality to reduce distance calculations
48- distance = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet ( centers , i ) )));
51+ distance = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& centers -> items [ i ] )));
4952
5053 // Set lower bound
5154 lowerBound [j * numCenters + i ] = distance ;
@@ -71,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
7174 break ;
7275 }
7376
74- VectorArraySet (centers , i + 1 , VectorArrayGet ( samples , j ) );
77+ VectorArraySet (centers , i + 1 , & samples -> items [ j ] );
7578 centers -> length ++ ;
7679 }
7780
@@ -96,7 +99,7 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, gtype * vec) {
9699 * Compare vectors
97100 */
98101static int
99- CompareVectors (const void * a , const void * b ) {
102+ GtypeCompareVectors (const void * a , const void * b ) {
100103 return gtype_vector_cmp ((Vector * ) a , (Vector * ) b );
101104}
102105
@@ -112,11 +115,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
112115
113116 // Copy existing vectors while avoiding duplicates
114117 if (samples -> length > 0 ) {
115- qsort (samples -> items , samples -> length , VECTOR_SIZE (samples -> dim ), CompareVectors );
118+ qsort (samples -> items , samples -> length , VECTOR_SIZE (samples -> dim ), GtypeCompareVectors );
116119 for (int i = 0 ; i < samples -> length ; i ++ ) {
117- vec = VectorArrayGet (samples , i );
120+ vec = & samples -> items [ i ]; //GTypeVectorArrayGet (samples, i);
118121
119- if (i == 0 || CompareVectors (vec , VectorArrayGet ( samples , i - 1 ) ) != 0 ) {
122+ if (i == 0 || GtypeCompareVectors (vec , & samples -> items [ i - 1 ] ) != 0 ) {
120123 VectorArraySet (centers , centers -> length , vec );
121124 centers -> length ++ ;
122125 }
@@ -125,8 +128,9 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
125128
126129 // Fill remaining with random data
127130 while (centers -> length < centers -> maxlen ) {
128- vec = VectorArrayGet (centers , centers -> length );
129-
131+ //vec = GTypeVectorArrayGet(centers, centers->length);
132+ vec = & (centers -> items [centers -> length ]);
133+
130134 SET_VARSIZE (vec , VECTOR_SIZE (dimensions ));
131135 vec -> root .header = dimensions | GT_FEXTENDED_COMPOSITE ;
132136 vec -> root .children [0 ] = GT_HEADER_VECTOR ;
@@ -221,9 +225,9 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
221225 halfcdist = palloc_extended (halfcdistSize , MCXT_ALLOC_HUGE );
222226 newcdist = palloc (newcdistSize );
223227
224- newCenters = VectorArrayInit (numCenters , dimensions );
228+ newCenters = GtypeVectorArrayInit (numCenters , dimensions );
225229 for (j = 0 ; j < numCenters ; j ++ ) {
226- vec = VectorArrayGet ( newCenters , j ) ;
230+ vec = & newCenters -> items [ j ] ;
227231 SET_VARSIZE (vec , VECTOR_SIZE (dimensions ));
228232 vec -> root .header = dimensions | GT_FEXTENDED_COMPOSITE ;
229233 vec -> root .children [0 ] = GT_HEADER_VECTOR ;
@@ -263,11 +267,10 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
263267 // Step 1: For all centers, compute distance
264268 for (j = 0 ; j < numCenters ; j ++ )
265269 {
266- vec = VectorArrayGet (centers , j );
267-
270+ vec = & (centers -> items [j ]);
268271 for (k = j + 1 ; k < numCenters ; k ++ )
269272 {
270- distance = 0.5 * DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , k ))));
273+ distance = 0.5 * DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ k ] ))));
271274 halfcdist [j * numCenters + k ] = distance ;
272275 halfcdist [k * numCenters + j ] = distance ;
273276 }
@@ -313,12 +316,12 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
313316 if (upperBound [j ] <= halfcdist [closestCenters [j ] * numCenters + k ])
314317 continue ;
315318
316- vec = VectorArrayGet ( samples , j ) ;
317-
319+ vec = & samples -> items [ j ] ;
320+
318321 // Step 3a
319322 if (rj )
320323 {
321- dxcx = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , closestCenters [j ]))));
324+ dxcx = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ closestCenters [j ] ]))));
322325
323326 // d(x,c(x)) computed, which is a form of d(x,c)
324327 lowerBound [j * numCenters + closestCenters [j ]] = dxcx ;
@@ -332,7 +335,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
332335 // Step 3b
333336 if (dxcx > lowerBound [j * numCenters + k ] || dxcx > halfcdist [closestCenters [j ] * numCenters + k ])
334337 {
335- dxc = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , k ))));
338+ dxc = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ k ] ))));
336339
337340 // d(x,c) calculated
338341 lowerBound [j * numCenters + k ] = dxc ;
@@ -354,7 +357,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
354357 // Step 4: For each center c, let m(c) be mean of all points assigned
355358 for (j = 0 ; j < numCenters ; j ++ )
356359 {
357- vec = VectorArrayGet (newCenters , j );
360+ vec = & (newCenters -> items [ j ] );
358361 for (k = 0 ; k < dimensions ; k ++ )
359362 * ((float8 * )& vec -> root .children [1 + (k * sizeof (float8 ))]) = 0.0 ;
360363
@@ -363,11 +366,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
363366
364367 for (j = 0 ; j < numSamples ; j ++ )
365368 {
366- vec = VectorArrayGet ( samples , j );
369+ vec = & samples -> items [ j ];
367370 closestCenter = closestCenters [j ];
368371
369372 // Increment sum and count of closest center
370- newCenter = VectorArrayGet (newCenters , closestCenter );
373+ newCenter = GTypeVectorArrayGet (newCenters , closestCenter );
371374 for (k = 0 ; k < dimensions ; k ++ )
372375 * ((float8 * )& newCenter -> root .children [1 + (k * sizeof (float8 ))]) += * ((float8 * )(& vec -> root .children [1 + (k * sizeof (float8 ))]));
373376
@@ -376,7 +379,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
376379
377380 for (j = 0 ; j < numCenters ; j ++ )
378381 {
379- vec = VectorArrayGet (newCenters , j );
382+ vec = GTypeVectorArrayGet (newCenters , j );
380383
381384 if (centerCounts [j ] > 0 )
382385 {
@@ -405,7 +408,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
405408
406409 // Step 5
407410 for (j = 0 ; j < numCenters ; j ++ )
408- newcdist [j ] = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (VectorArrayGet ( centers , j )) , PointerGetDatum (VectorArrayGet ( newCenters , j ) )));
411+ newcdist [j ] = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (& centers -> items [ j ]) , PointerGetDatum (& newCenters -> items [ j ] )));
409412
410413 for (j = 0 ; j < numSamples ; j ++ )
411414 {
@@ -427,7 +430,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
427430
428431 // Step 7
429432 for (j = 0 ; j < numCenters ; j ++ )
430- memcpy (VectorArrayGet (centers , j ), VectorArrayGet (newCenters , j ), VECTOR_SIZE (dimensions ));
433+ memcpy (& (centers -> items [ j ] ), & (newCenters -> items [ j ] ), VECTOR_SIZE (dimensions ));
431434
432435 if (changes == 0 && iteration != 0 )
433436 break ;
@@ -460,7 +463,8 @@ CheckCenters(Relation index, VectorArray centers)
460463 // Ensure no NaN or infinite values
461464 for (int i = 0 ; i < centers -> length ; i ++ )
462465 {
463- vec = VectorArrayGet (centers , i );
466+ //vec = GTypeVectorArrayGet(centers, i);
467+ vec = & (centers -> items [i ]);
464468
465469 for (int j = 0 ; j < AGT_ROOT_COUNT (vec ); j ++ ) {
466470 if (isnan ((double ) vec -> root .children [1 + (j * sizeof (float8 ))]))
@@ -473,10 +477,11 @@ CheckCenters(Relation index, VectorArray centers)
473477
474478 // Ensure no duplicate centers
475479 // Fine to sort in-place
476- qsort (centers -> items , centers -> length , VECTOR_SIZE (centers -> dim ), CompareVectors );
480+ qsort (centers -> items , centers -> length , VECTOR_SIZE (centers -> dim ), GtypeCompareVectors );
477481 for (int i = 1 ; i < centers -> length ; i ++ )
478482 {
479- if (CompareVectors (VectorArrayGet (centers , i ), VectorArrayGet (centers , i - 1 )) == 0 )
483+ //if (GtypeCompareVectors(GTypeVectorArrayGet(centers, i), GTypeVectorArrayGet(centers, i - 1)) == 0)
484+ if (GtypeCompareVectors (& (centers -> items [i ]), & (centers -> items [i - 1 ])) == 0 )
480485 elog (ERROR , "Duplicate centers detected. Please report a bug." );
481486 }
482487
@@ -489,7 +494,7 @@ CheckCenters(Relation index, VectorArray centers)
489494
490495 for (int i = 0 ; i < centers -> length ; i ++ )
491496 {
492- norm = DatumGetFloat8 (FunctionCall1Coll (normprocinfo , collation , PointerGetDatum (VectorArrayGet (centers , i ))));
497+ norm = DatumGetFloat8 (FunctionCall1Coll (normprocinfo , collation , PointerGetDatum (& (centers -> items [ i ] ))));
493498 if (norm == 0 )
494499 elog (ERROR , "Zero norm detected. Please report a bug." );
495500 }
0 commit comments