diff --git a/package.json b/package.json index 32107908..a902cf59 100644 --- a/package.json +++ b/package.json @@ -100,7 +100,8 @@ "lint-staged": { "*.{ts,js}": [ "npm run format:check", - "npm run lint -- --cache" + "npm run lint -- --cache", + "npm run prepack" ] } } diff --git a/src/collections/aggregate/index.ts b/src/collections/aggregate/index.ts index 1dad6460..a4cce65e 100644 --- a/src/collections/aggregate/index.ts +++ b/src/collections/aggregate/index.ts @@ -9,7 +9,7 @@ import { WeaviateInvalidInputError, WeaviateQueryError } from '../../errors.js'; import { Aggregator } from '../../graphql/index.js'; import { PrimitiveKeys, toBase64FromMedia } from '../../index.js'; import { Deserialize } from '../deserialize/index.js'; -import { Bm25QueryProperty, NearVectorInputType } from '../query/types.js'; +import { Bm25QueryProperty, NearVectorInputType, TargetVector } from '../query/types.js'; import { NearVectorInputGuards } from '../query/utils.js'; import { Serialize } from '../serialize/index.js'; @@ -31,27 +31,27 @@ export type GroupByAggregate = { export type AggregateOverAllOptions = AggregateBaseOptions; -export type AggregateNearOptions = AggregateBaseOptions & { +export type AggregateNearOptions = AggregateBaseOptions & { certainty?: number; distance?: number; objectLimit?: number; - targetVector?: string; + targetVector?: TargetVector; }; -export type AggregateHybridOptions = AggregateBaseOptions & { +export type AggregateHybridOptions = AggregateBaseOptions & { alpha?: number; maxVectorDistance?: number; objectLimit?: number; queryProperties?: (PrimitiveKeys | Bm25QueryProperty)[]; - targetVector?: string; + targetVector?: TargetVector; vector?: number[]; }; -export type AggregateGroupByHybridOptions = AggregateHybridOptions & { +export type AggregateGroupByHybridOptions = AggregateHybridOptions & { groupBy: PropertyOf | GroupByAggregate; }; -export type AggregateGroupByNearOptions = AggregateNearOptions & { +export type AggregateGroupByNearOptions = AggregateNearOptions & { groupBy: PropertyOf | GroupByAggregate; }; @@ -346,9 +346,9 @@ export type AggregateGroupByResult< }; }; -class AggregateManager implements Aggregate { +class AggregateManager implements Aggregate { connection: Connection; - groupBy: AggregateGroupBy; + groupBy: AggregateGroupBy; name: string; dbVersionSupport: DbVersionSupport; consistencyLevel?: ConsistencyLevel; @@ -373,14 +373,14 @@ class AggregateManager implements Aggregate { this.groupBy = { hybrid: async >( query: string, - opts: AggregateGroupByHybridOptions + opts: AggregateGroupByHybridOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; return this.grpc() - .then((aggregate) => + .then(async (aggregate) => aggregate.withHybrid({ - ...Serialize.aggregate.hybrid(query, opts), + ...(await Serialize.aggregate.hybrid(query, opts)), groupBy: Serialize.aggregate.groupBy(group), limit: group.limit, }) @@ -402,7 +402,7 @@ class AggregateManager implements Aggregate { }, nearImage: async >( image: string | Buffer, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]); if (usesGrpc) { @@ -430,7 +430,7 @@ class AggregateManager implements Aggregate { }, nearObject: async >( id: string, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -457,7 +457,7 @@ class AggregateManager implements Aggregate { }, nearText: async >( query: string | string[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; @@ -484,14 +484,14 @@ class AggregateManager implements Aggregate { }, nearVector: async >( vector: number[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]> => { if (await this.grpcChecker) { const group = typeof opts.groupBy === 'string' ? { property: opts.groupBy } : opts.groupBy; return this.grpc() - .then((aggregate) => + .then(async (aggregate) => aggregate.withNearVector({ - ...Serialize.aggregate.nearVector(vector, opts), + ...(await Serialize.aggregate.nearVector(vector, opts)), groupBy: Serialize.aggregate.groupBy(group), limit: group.limit, }) @@ -593,23 +593,23 @@ class AggregateManager implements Aggregate { return `${propertyName} { ${body} }`; } - static use( + static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): AggregateManager { - return new AggregateManager(connection, name, dbVersionSupport, consistencyLevel, tenant); + ): AggregateManager { + return new AggregateManager(connection, name, dbVersionSupport, consistencyLevel, tenant); } async hybrid>( query: string, - opts?: AggregateHybridOptions + opts?: AggregateHybridOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() - .then((aggregate) => aggregate.withHybrid(Serialize.aggregate.hybrid(query, opts))) + .then(async (aggregate) => aggregate.withHybrid(await Serialize.aggregate.hybrid(query, opts))) .then((reply) => Deserialize.aggregate(reply)); } let builder = this.base(opts?.returnMetrics, opts?.filters).withHybrid({ @@ -628,7 +628,7 @@ class AggregateManager implements Aggregate { async nearImage>( image: string | Buffer, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { const [b64, usesGrpc] = await Promise.all([await toBase64FromMedia(image), await this.grpcChecker]); if (usesGrpc) { @@ -650,7 +650,7 @@ class AggregateManager implements Aggregate { async nearObject>( id: string, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -671,7 +671,7 @@ class AggregateManager implements Aggregate { async nearText>( query: string | string[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() @@ -692,14 +692,16 @@ class AggregateManager implements Aggregate { async nearVector>( vector: NearVectorInputType, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise> { if (await this.grpcChecker) { return this.grpc() - .then((aggregate) => aggregate.withNearVector(Serialize.aggregate.nearVector(vector, opts))) + .then(async (aggregate) => + aggregate.withNearVector(await Serialize.aggregate.nearVector(vector, opts)) + ) .then((reply) => Deserialize.aggregate(reply)); } - if (!NearVectorInputGuards.is1DArray(vector)) { + if (!NearVectorInputGuards.is1D(vector)) { throw new WeaviateInvalidInputError( 'Vector can only be a 1D array of numbers when using `nearVector` with <1.29 Weaviate versions.' ); @@ -768,9 +770,9 @@ class AggregateManager implements Aggregate { }; } -export interface Aggregate { +export interface Aggregate { /** This namespace contains methods perform a group by search while aggregating metrics. */ - groupBy: AggregateGroupBy; + groupBy: AggregateGroupBy; /** * Aggregate metrics over the objects returned by a hybrid search on this collection. * @@ -782,7 +784,7 @@ export interface Aggregate { */ hybrid>( query: string, - opts?: AggregateHybridOptions + opts?: AggregateHybridOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near image vector search on this collection. @@ -797,7 +799,7 @@ export interface Aggregate { */ nearImage>( image: string | Buffer, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near object search on this collection. @@ -812,7 +814,7 @@ export interface Aggregate { */ nearObject>( id: string, - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near vector search on this collection. @@ -827,7 +829,7 @@ export interface Aggregate { */ nearText>( query: string | string[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over the objects returned by a near vector search on this collection. @@ -842,7 +844,7 @@ export interface Aggregate { */ nearVector>( vector: number[], - opts?: AggregateNearOptions + opts?: AggregateNearOptions ): Promise>; /** * Aggregate metrics over all the objects in this collection without any vector search. @@ -853,7 +855,7 @@ export interface Aggregate { overAll>(opts?: AggregateOverAllOptions): Promise>; } -export interface AggregateGroupBy { +export interface AggregateGroupBy { /** * Aggregate metrics over the objects grouped by a specified property and returned by a hybrid search on this collection. * @@ -865,7 +867,7 @@ export interface AggregateGroupBy { */ hybrid>( query: string, - opts: AggregateGroupByHybridOptions + opts: AggregateGroupByHybridOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near image vector search on this collection. @@ -880,7 +882,7 @@ export interface AggregateGroupBy { */ nearImage>( image: string | Buffer, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near object search on this collection. @@ -895,7 +897,7 @@ export interface AggregateGroupBy { */ nearObject>( id: string, - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near text vector search on this collection. @@ -910,7 +912,7 @@ export interface AggregateGroupBy { */ nearText>( query: string | string[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over the objects grouped by a specified property and returned by a near vector search on this collection. @@ -925,7 +927,7 @@ export interface AggregateGroupBy { */ nearVector>( vector: number[], - opts: AggregateGroupByNearOptions + opts: AggregateGroupByNearOptions ): Promise[]>; /** * Aggregate metrics over all the objects in this collection grouped by a specified property without any vector search. diff --git a/src/collections/aggregate/integration.test.ts b/src/collections/aggregate/integration.test.ts index 284606b6..fc3d0c43 100644 --- a/src/collections/aggregate/integration.test.ts +++ b/src/collections/aggregate/integration.test.ts @@ -485,7 +485,7 @@ describe('Testing of collection.aggregate search methods', () => { it('should return an aggregation on a nearVector search', async () => { const obj = await collection.query.fetchObjectById(uuid, { includeVector: true }); - const result = await collection.aggregate.nearVector(obj?.vectors.default!, { + const result = await collection.aggregate.nearVector(obj?.vectors.default as number[], { objectLimit: 1000, returnMetrics: collection.metrics.aggregate('text').text(['count']), }); @@ -494,7 +494,7 @@ describe('Testing of collection.aggregate search methods', () => { it('should return a grouped aggregation on a nearVector search', async () => { const obj = await collection.query.fetchObjectById(uuid, { includeVector: true }); - const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default!, { + const result = await collection.aggregate.groupBy.nearVector(obj?.vectors.default as number[], { objectLimit: 1000, groupBy: 'text', returnMetrics: collection.metrics.aggregate('text').text(['count']), diff --git a/src/collections/collection/index.ts b/src/collections/collection/index.ts index c4f7ee31..c4164352 100644 --- a/src/collections/collection/index.ts +++ b/src/collections/collection/index.ts @@ -14,12 +14,13 @@ import { Iterator } from '../iterator/index.js'; import query, { Query } from '../query/index.js'; import sort, { Sort } from '../sort/index.js'; import tenants, { TenantBase, Tenants } from '../tenants/index.js'; -import { QueryMetadata, QueryProperty, QueryReference } from '../types/index.js'; +import { QueryMetadata, QueryProperty, QueryReference, ReturnVectors } from '../types/index.js'; +import { IncludeVector } from '../types/internal.js'; import multiTargetVector, { MultiTargetVector } from '../vectors/multiTargetVector.js'; -export interface Collection { +export interface Collection { /** This namespace includes all the querying methods available to you when using Weaviate's standard aggregation capabilities. */ - aggregate: Aggregate; + aggregate: Aggregate; /** This namespace includes all the backup methods available to you when backing up a collection in Weaviate. */ backup: BackupCollection; /** This namespace includes all the CRUD methods available to you when modifying the configuration of the collection in Weaviate. */ @@ -29,19 +30,19 @@ export interface Collection { /** This namespace includes the methods by which you can create the `FilterValue` values for use when filtering queries over your collection. */ filter: Filter; /** This namespace includes all the querying methods available to you when using Weaviate's generative capabilities. */ - generate: Generate; + generate: Generate; /** This namespace includes the methods by which you can create the `MetricsX` values for use when aggregating over your collection. */ metrics: Metrics; /** The name of the collection. */ name: N; /** This namespace includes all the querying methods available to you when using Weaviate's standard query capabilities. */ - query: Query; + query: Query; /** This namespaces includes the methods by which you can create the `Sorting` values for use when sorting queries over your collection. */ sort: Sort; /** This namespace includes all the CRUD methods available to you when modifying the tenants of a multi-tenancy-enabled collection in Weaviate. */ tenants: Tenants; /** This namespaces includes the methods by which you cna create the `MultiTargetVectorJoin` values for use when performing multi-target vector searches over your collection. */ - multiTargetVector: MultiTargetVector; + multiTargetVector: MultiTargetVector; /** * Use this method to check if the collection exists in Weaviate. * @@ -54,15 +55,19 @@ export interface Collection { * This iterator keeps a record of the last object that it returned to be used in each subsequent call to Weaviate. * Once the collection is exhausted, the iterator exits. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {IteratorOptions} opts The options to use when fetching objects from Weaviate. * @returns {Iterator} An iterator over the objects in the collection as an async generator. * * @description If `return_properties` is not provided, all the properties of each object will be * requested from Weaviate except for its vector as this is an expensive operation. Specify `include_vector` - * to request the vector back as well. In addition, if `return_references=None` then none of the references + * to request the vectors back as well. In addition, if `return_references=None` then none of the references * are returned. Use `wvc.QueryReference` to specify which references to return. */ - iterator: (opts?: IteratorOptions) => Iterator; + iterator: , RV extends ReturnVectors>( + opts?: IteratorOptions + ) => Iterator; /** * Use this method to return the total number of objects in the collection. * @@ -77,9 +82,9 @@ export interface Collection { * This method does not send a request to Weaviate. It only returns a new collection object that is specific to the consistency level you specify. * * @param {ConsistencyLevel} consistencyLevel The consistency level to use. - * @returns {Collection} A new collection object specific to the consistency level you specified. + * @returns {Collection} A new collection object specific to the consistency level you specified. */ - withConsistency: (consistencyLevel: ConsistencyLevel) => Collection; + withConsistency: (consistencyLevel: ConsistencyLevel) => Collection; /** * Use this method to return a collection object specific to a single tenant. * @@ -89,13 +94,13 @@ export interface Collection { * * @typedef {TenantBase} TT A type that extends TenantBase. * @param {string | TT} tenant The tenant name or tenant object to use. - * @returns {Collection} A new collection object specific to the tenant you specified. + * @returns {Collection} A new collection object specific to the tenant you specified. */ - withTenant: (tenant: string | TT) => Collection; + withTenant: (tenant: string | TT) => Collection; } -export type IteratorOptions = { - includeVector?: boolean | string[]; +export type IteratorOptions = { + includeVector?: I; returnMetadata?: QueryMetadata; returnProperties?: QueryProperty[]; returnReferences?: QueryReference[]; @@ -106,43 +111,49 @@ const isString = (value: any): value is string => typeof value === 'string'; const capitalizeCollectionName = (name: N): N => (name.charAt(0).toUpperCase() + name.slice(1)) as N; -const collection = ( +const collection = ( connection: Connection, name: N, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string -): Collection => { +): Collection => { if (!isString(name)) { throw new WeaviateInvalidInputError(`The collection name must be a string, got: ${typeof name}`); } const capitalizedName = capitalizeCollectionName(name); - const aggregateCollection = aggregate( + const aggregateCollection = aggregate( + connection, + capitalizedName, + dbVersionSupport, + consistencyLevel, + tenant + ); + const queryCollection = query( connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant ); - const queryCollection = query(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant); return { aggregate: aggregateCollection, backup: backupCollection(connection, capitalizedName), config: config(connection, capitalizedName, dbVersionSupport, tenant), data: data(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), filter: filter(), - generate: generate(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), + generate: generate(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), metrics: metrics(), - multiTargetVector: multiTargetVector(), + multiTargetVector: multiTargetVector(), name: name, query: queryCollection, sort: sort(), tenants: tenants(connection, capitalizedName, dbVersionSupport), exists: () => new ClassExists(connection).withClassName(capitalizedName).do(), - iterator: (opts?: IteratorOptions) => - new Iterator((limit: number, after?: string) => + iterator: , RV extends ReturnVectors>(opts?: IteratorOptions) => + new Iterator((limit: number, after?: string) => queryCollection - .fetchObjects({ + .fetchObjects({ limit, after, includeVector: opts?.includeVector, @@ -154,9 +165,9 @@ const collection = ( ), length: () => aggregateCollection.overAll().then(({ totalCount }) => totalCount), withConsistency: (consistencyLevel: ConsistencyLevel) => - collection(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), + collection(connection, capitalizedName, dbVersionSupport, consistencyLevel, tenant), withTenant: (tenant: string | TT) => - collection( + collection( connection, capitalizedName, dbVersionSupport, diff --git a/src/collections/config/integration.test.ts b/src/collections/config/integration.test.ts index 9b363d06..d8535206 100644 --- a/src/collections/config/integration.test.ts +++ b/src/collections/config/integration.test.ts @@ -73,6 +73,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: undefined, type: 'hnsw', }); @@ -128,6 +129,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: undefined, type: 'hnsw', }); @@ -387,11 +389,7 @@ describe('Testing of the collection.config namespace', () => { ]); }); - requireAtLeast( - 1, - 31, - 0 - )('Mutable named vectors', () => { + requireAtLeast(1, 31, 0)(describe)('Mutable named vectors', () => { it('should be able to add named vectors to a collection', async () => { const collectionName = 'TestCollectionConfigAddVector' as const; const collection = await client.collections.create({ @@ -537,6 +535,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, quantizer: { bitCompression: false, segments: 0, @@ -646,6 +645,7 @@ describe('Testing of the collection.config namespace', () => { filterStrategy: 'sweeping', flatSearchCutoff: 40000, distance: 'cosine', + multiVector: undefined, type: 'hnsw', quantizer: { bitCompression: false, @@ -711,4 +711,51 @@ describe('Testing of the collection.config namespace', () => { }, }); }); + + requireAtLeast(1, 31, 0)(it)( + 'should be able to create and get a multi-vector collection with encoding', + async () => { + const collectionName = 'TestCollectionConfigCreateWithMuveraEncoding'; + const collection = await client.collections.create({ + name: collectionName, + vectorizers: weaviate.configure.vectorizer.none({ + vectorIndexConfig: weaviate.configure.vectorIndex.hnsw({ + multiVector: weaviate.configure.vectorIndex.multiVector.multiVector({ + aggregation: 'maxSim', + encoding: weaviate.configure.vectorIndex.multiVector.encoding.muvera(), + }), + }), + }), + }); + const config = await collection.config.get(); + expect(config.name).toEqual(collectionName); + + const indexConfig = config.vectorizers.default.indexConfig as VectorIndexConfigHNSW; + expect(indexConfig.multiVector).toBeDefined(); + expect(indexConfig.multiVector?.aggregation).toEqual('maxSim'); + expect(indexConfig.multiVector?.encoding).toBeDefined(); + } + ); + + requireAtLeast(1, 31, 0)(it)( + 'should be able to create and get a multi-vector collection without encoding', + async () => { + const collectionName = 'TestCollectionConfigCreateWithoutMuveraEncoding'; + const collection = await client.collections.create({ + name: collectionName, + vectorizers: weaviate.configure.vectorizer.none({ + vectorIndexConfig: weaviate.configure.vectorIndex.hnsw({ + multiVector: weaviate.configure.vectorIndex.multiVector.multiVector(), + }), + }), + }); + const config = await collection.config.get(); + expect(config.name).toEqual(collectionName); + + const indexConfig = config.vectorizers.default.indexConfig as VectorIndexConfigHNSW; + expect(indexConfig.multiVector).toBeDefined(); + expect(indexConfig.multiVector?.aggregation).toEqual('maxSim'); + expect(indexConfig.multiVector?.encoding).toBeUndefined(); + } + ); }); diff --git a/src/collections/config/types/vectorIndex.ts b/src/collections/config/types/vectorIndex.ts index ddd6ea90..1ce29a22 100644 --- a/src/collections/config/types/vectorIndex.ts +++ b/src/collections/config/types/vectorIndex.ts @@ -9,6 +9,7 @@ export type VectorIndexConfigHNSW = { filterStrategy: VectorIndexFilterStrategy; flatSearchCutoff: number; maxConnections: number; + multiVector: MultiVectorConfig | undefined; quantizer: PQConfig | BQConfig | SQConfig | undefined; skip: boolean; vectorCacheMaxObjects: number; @@ -61,6 +62,20 @@ export type PQConfig = { type: 'pq'; }; +export type MultiVectorConfig = { + aggregation: 'maxSim' | string; + encoding?: MultiVectorEncodingConfig; +}; + +export type MuveraEncodingConfig = { + ksim?: number; + dprojections?: number; + repetitions?: number; + type: 'muvera'; +}; + +export type MultiVectorEncodingConfig = MuveraEncodingConfig | Record; + export type PQEncoderConfig = { type: PQEncoderType; distribution: PQEncoderDistribution; diff --git a/src/collections/config/utils.ts b/src/collections/config/utils.ts index a66a1313..e5c92651 100644 --- a/src/collections/config/utils.ts +++ b/src/collections/config/utils.ts @@ -4,7 +4,6 @@ import { WeaviateUnsupportedFeatureError, } from '../../errors.js'; import { - Properties, WeaviateBM25Config, WeaviateClass, WeaviateInvertedIndexConfig, @@ -19,14 +18,13 @@ import { WeaviateVectorsConfig, } from '../../openapi/types.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; -import { QuantizerGuards } from '../configure/parsing.js'; +import { MultiVectorEncodingGuards, QuantizerGuards, VectorIndexGuards } from '../configure/parsing.js'; import { PropertyConfigCreate, ReferenceConfigCreate, ReferenceMultiTargetConfigCreate, ReferenceSingleTargetConfigCreate, VectorIndexConfigCreate, - VectorIndexConfigDynamicCreate, VectorIndexConfigFlatCreate, VectorIndexConfigHNSWCreate, VectorizersConfigAdd, @@ -40,6 +38,8 @@ import { InvertedIndexConfig, ModuleConfig, MultiTenancyConfig, + MultiVectorConfig, + MultiVectorEncodingConfig, PQConfig, PQEncoderConfig, PQEncoderDistribution, @@ -139,18 +139,43 @@ export const classToCollection = (cls: WeaviateClass): CollectionConfig => { export const parseVectorIndex = (module: ModuleConfig): any => { if (module.config === undefined) return undefined; - if (module.name === 'dynamic') { - const { hnsw, flat, ...conf } = module.config as VectorIndexConfigDynamicCreate; + if (VectorIndexGuards.isDynamic(module.config)) { + const { hnsw, flat, ...conf } = module.config; return { ...conf, hnsw: parseVectorIndex({ name: 'hnsw', config: hnsw }), flat: parseVectorIndex({ name: 'flat', config: flat }), }; } - const { quantizer, ...conf } = module.config as + + let multivector: any; + if (VectorIndexGuards.isHNSW(module.config) && module.config.multiVector !== undefined) { + multivector = { + aggregation: module.config.multiVector.aggregation, + enabled: true, + }; + if ( + module.config.multiVector.encoding !== undefined && + MultiVectorEncodingGuards.isMuvera(module.config.multiVector.encoding) + ) { + multivector.muvera = { + enabled: true, + ksim: module.config.multiVector.encoding.ksim, + dprojections: module.config.multiVector.encoding.dprojections, + repetitions: module.config.multiVector.encoding.repetitions, + }; + } + } + + const { quantizer, ...rest } = module.config as | VectorIndexConfigFlatCreate | VectorIndexConfigHNSWCreate | Record; + + const conf = { + ...rest, + multivector, + }; if (quantizer === undefined) return conf; if (QuantizerGuards.isBQCreate(quantizer)) { const { type, ...quant } = quantizer; @@ -183,8 +208,8 @@ export const parseVectorizerConfig = (config?: VectorizerConfig): any => { }; }; -export const makeVectorsConfig = ( - configVectorizers: VectorizersConfigCreate | VectorizersConfigAdd, +export const makeVectorsConfig = ( + configVectorizers: VectorizersConfigCreate | VectorizersConfigAdd, supportsDynamicVectorIndex: Awaited> ) => { let vectorizers: string[] = []; @@ -468,6 +493,7 @@ class ConfigMapping { } else { quantizer = undefined; } + return { cleanupIntervalSeconds: v.cleanupIntervalSeconds, distance: v.distance, @@ -479,12 +505,42 @@ class ConfigMapping { filterStrategy: exists(v.filterStrategy) ? v.filterStrategy : 'sweeping', flatSearchCutoff: v.flatSearchCutoff, maxConnections: v.maxConnections, + multiVector: exists(v.multivector) + ? ConfigMapping.multiVector(v.multivector) + : undefined, quantizer: quantizer, skip: v.skip, vectorCacheMaxObjects: v.vectorCacheMaxObjects, type: 'hnsw', }; } + static multiVector(v: Record): MultiVectorConfig | undefined { + if (!exists(v.enabled)) + throw new WeaviateDeserializationError('Multi vector enabled was not returned by Weaviate'); + if (v.enabled === false) return undefined; + if (!exists(v.aggregation)) + throw new WeaviateDeserializationError('Multi vector aggregation was not returned by Weaviate'); + let encoding: MultiVectorEncodingConfig | undefined; + if ( + exists<{ + ksim: number; + dprojections: number; + repetitions: number; + enabled: boolean; + }>(v.muvera) + ) { + encoding = v.muvera.enabled + ? { + type: 'muvera', + ...v.muvera, + } + : undefined; + } + return { + aggregation: v.aggregation, + encoding, + }; + } static bq(v?: Record): BQConfig | undefined { if (v === undefined) throw new WeaviateDeserializationError('BQ was not returned by Weaviate'); if (!exists(v.enabled)) diff --git a/src/collections/configure/parsing.ts b/src/collections/configure/parsing.ts index 09319424..dbd2690b 100644 --- a/src/collections/configure/parsing.ts +++ b/src/collections/configure/parsing.ts @@ -1,3 +1,4 @@ +import { MuveraEncodingConfigCreate } from '../index.js'; import { BQConfigCreate, BQConfigUpdate, @@ -5,6 +6,9 @@ import { PQConfigUpdate, SQConfigCreate, SQConfigUpdate, + VectorIndexConfigDynamicCreate, + VectorIndexConfigFlatCreate, + VectorIndexConfigHNSWCreate, } from './types/index.js'; type QuantizerConfig = @@ -36,6 +40,30 @@ export class QuantizerGuards { } } +type VectorIndexConfig = + | VectorIndexConfigHNSWCreate + | VectorIndexConfigFlatCreate + | VectorIndexConfigDynamicCreate + | Record; + +export class VectorIndexGuards { + static isHNSW(config?: VectorIndexConfig): config is VectorIndexConfigHNSWCreate { + return (config as VectorIndexConfigHNSWCreate)?.type === 'hnsw'; + } + static isFlat(config?: VectorIndexConfig): config is VectorIndexConfigFlatCreate { + return (config as VectorIndexConfigFlatCreate)?.type === 'flat'; + } + static isDynamic(config?: VectorIndexConfig): config is VectorIndexConfigDynamicCreate { + return (config as VectorIndexConfigDynamicCreate)?.type === 'dynamic'; + } +} + +export class MultiVectorEncodingGuards { + static isMuvera(config?: Record): config is MuveraEncodingConfigCreate { + return (config as { type: string })?.type === 'muvera'; + } +} + export function parseWithDefault(value: D | undefined, defaultValue: D): D { return value !== undefined ? value : defaultValue; } diff --git a/src/collections/configure/types/vectorIndex.ts b/src/collections/configure/types/vectorIndex.ts index 4f759a7f..41bf7800 100644 --- a/src/collections/configure/types/vectorIndex.ts +++ b/src/collections/configure/types/vectorIndex.ts @@ -1,6 +1,8 @@ import { BQConfig, ModuleConfig, + MultiVectorConfig, + MuveraEncodingConfig, PQConfig, PQEncoderDistribution, PQEncoderType, @@ -46,6 +48,10 @@ export type SQConfigUpdate = { type: 'sq'; }; +export type MultiVectorConfigCreate = RecursivePartial; + +export type MuveraEncodingConfigCreate = RecursivePartial; + export type VectorIndexConfigHNSWCreate = RecursivePartial; export type VectorIndexConfigDynamicCreate = RecursivePartial; @@ -130,6 +136,8 @@ export type VectorIndexConfigHNSWCreateOptions = { filterStrategy?: VectorIndexFilterStrategy; /** The maximum number of connections. Default is 64. */ maxConnections?: number; + /** The multi-vector configuration to use. Use `vectorIndex.multiVector` to make one. */ + multiVector?: MultiVectorConfigCreate; /** The quantizer configuration to use. Use `vectorIndex.quantizer.bq` or `vectorIndex.quantizer.pq` to make one. */ quantizer?: PQConfigCreate | BQConfigCreate | SQConfigCreate; /** Whether to skip the index. Default is false. */ diff --git a/src/collections/configure/types/vectorizer.ts b/src/collections/configure/types/vectorizer.ts index 14f32da7..2e30923d 100644 --- a/src/collections/configure/types/vectorizer.ts +++ b/src/collections/configure/types/vectorizer.ts @@ -54,9 +54,13 @@ export type VectorConfigUpdate>; }; -export type VectorizersConfigCreate = - | VectorConfigCreate, string | undefined, VectorIndexType, Vectorizer> - | VectorConfigCreate, string, VectorIndexType, Vectorizer>[]; +export type VectorizersConfigCreate = V extends undefined + ? + | VectorConfigCreate, string | undefined, VectorIndexType, Vectorizer> + | VectorConfigCreate, string, VectorIndexType, Vectorizer>[] + : + | VectorConfigCreate, (keyof V & string) | undefined, VectorIndexType, Vectorizer> + | VectorConfigCreate, keyof V & string, VectorIndexType, Vectorizer>[]; export type VectorizersConfigAdd = | VectorConfigCreate, string, VectorIndexType, Vectorizer> diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 0eddb278..ae8ee503 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -135,6 +135,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { quantizer: { type: 'pq', }, + type: 'hnsw', }, }); }); @@ -189,6 +190,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { type: 'pq', }, skip: true, + type: 'hnsw', vectorCacheMaxObjects: 2000000000000, }, }); @@ -202,6 +204,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { quantizer: { type: 'bq', }, + type: 'flat', }, }); }); @@ -226,6 +229,7 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { rescoreLimit: 100, type: 'bq', }, + type: 'flat', }, }); }); @@ -245,6 +249,22 @@ describe('Unit testing of the configure & reconfigure factory classes', () => { trainingLimit: 200, type: 'sq', }, + type: 'hnsw', + }, + }); + }); + + it('should create an hnsw VectorIndexConfig type with multivector enabled', () => { + const config = configure.vectorIndex.hnsw({ + multiVector: configure.vectorIndex.multiVector.multiVector({ aggregation: 'maxSim' }), + }); + expect(config).toEqual>({ + name: 'hnsw', + config: { + multiVector: { + aggregation: 'maxSim', + }, + type: 'hnsw', }, }); }); diff --git a/src/collections/configure/vectorIndex.ts b/src/collections/configure/vectorIndex.ts index a9e8790e..74e69b49 100644 --- a/src/collections/configure/vectorIndex.ts +++ b/src/collections/configure/vectorIndex.ts @@ -7,6 +7,8 @@ import { import { BQConfigCreate, BQConfigUpdate, + MultiVectorConfigCreate, + MuveraEncodingConfigCreate, PQConfigCreate, PQConfigUpdate, SQConfigCreate, @@ -44,6 +46,7 @@ const configure = { distance, vectorCacheMaxObjects, quantizer: quantizer, + type: 'flat', }, }; }, @@ -66,6 +69,7 @@ const configure = { ...rest, distance: distanceMetric, quantizer: rest.quantizer, + type: 'hnsw', } : undefined, }; @@ -89,10 +93,57 @@ const configure = { threshold: opts.threshold, hnsw: isModuleConfig(opts.hnsw) ? opts.hnsw.config : configure.hnsw(opts.hnsw).config, flat: isModuleConfig(opts.flat) ? opts.flat.config : configure.flat(opts.flat).config, + type: 'dynamic', } : undefined, }; }, + /** + * Define the configuration for a multi-vector index. + */ + multiVector: { + /** + * Specify the encoding configuration for a multi-vector index. + */ + encoding: { + /** + * Create an object of type `MuveraEncodingConfigCreate` to be used when defining the encoding configuration of a multi-vector index using MUVERA. + * + * @param {number} [options.ksim] The number of nearest neighbors to consider for similarity. Default is undefined. + * @param {number} [options.dprojections] The number of projections to use. Default is undefined. + * @param {number} [options.repetitions] The number of repetitions to use. Default is undefined. + * @returns {MuveraEncodingConfigCreate} The object of type `MuveraEncodingConfigCreate`. + */ + muvera: (options?: { + ksim?: number; + dprojections?: number; + repetitions?: number; + }): MuveraEncodingConfigCreate => { + return { + ksim: options?.ksim, + dprojections: options?.dprojections, + repetitions: options?.repetitions, + type: 'muvera', + }; + }, + }, + /** + * Create an object of type `MultiVectorConfigCreate` to be used when defining the configuration of a multi-vector index. + * + * @param {string} [options.aggregation] The aggregation method to use. Default is 'maxSim'. + * @param {MultiVectorConfig['encoding']} [options.encoding] The encoding configuration for the multi-vector index. Default is undefined. + * @returns {MultiVectorConfigCreate} The object of type `MultiVectorConfigCreate`. + */ + multiVector: (options?: { + aggregation?: 'maxSim' | string; + encoding?: MultiVectorConfigCreate['encoding']; + }): MultiVectorConfigCreate => { + return { + aggregation: options?.aggregation, + encoding: options?.encoding, + }; + }, + }, /** * Define the quantizer configuration to use when creating a vector index. */ diff --git a/src/collections/data/integration.test.ts b/src/collections/data/integration.test.ts index 681609a2..0be9ccd9 100644 --- a/src/collections/data/integration.test.ts +++ b/src/collections/data/integration.test.ts @@ -328,7 +328,7 @@ describe('Testing of the collection.data methods with a single target reference' collection.query.fetchObjectById(toBeReplacedID, { returnReferences: [{ linkOn: 'ref' }], }); - const assert = (obj: WeaviateObject | null, id: string) => { + const assert = (obj: WeaviateObject | null, id: string) => { expect(obj).not.toBeNull(); expect(obj?.references?.ref?.objects[0].uuid).toEqual(id); }; diff --git a/src/collections/deserialize/index.ts b/src/collections/deserialize/index.ts index fdf37f3d..0b998d86 100644 --- a/src/collections/deserialize/index.ts +++ b/src/collections/deserialize/index.ts @@ -11,12 +11,14 @@ import { AggregateReply_Aggregations_Aggregation_Text, AggregateReply_Group_GroupedBy, } from '../../proto/v1/aggregate.js'; +import { Vectors_VectorType } from '../../proto/v1/base.js'; import { BatchObject as BatchObjectGRPC, BatchObjectsReply } from '../../proto/v1/batch.js'; import { BatchDeleteReply } from '../../proto/v1/batch_delete.js'; import { ListValue, Properties as PropertiesGrpc, Value } from '../../proto/v1/properties.js'; import { MetadataResult, PropertiesResult, SearchReply } from '../../proto/v1/search_get.js'; import { TenantActivityStatus, TenantsGetReply } from '../../proto/v1/tenants.js'; import { DbVersionSupport } from '../../utils/dbVersion.js'; +import { yieldToEventLoop } from '../../utils/yield.js'; import { AggregateBoolean, AggregateDate, @@ -28,7 +30,10 @@ import { GenerativeConfigRuntime, GenerativeMetadata, PropertiesMetrics, + Vectors, + WeaviateObject, } from '../index.js'; +import { MultiVectorType, SingleVectorType } from '../query/types.js'; import { referenceFromObjects } from '../references/utils.js'; import { Tenant } from '../tenants/index.js'; import { @@ -47,6 +52,9 @@ import { WeaviateReturn, } from '../types/index.js'; +const UINT16LEN = 2; +const UINT32LEN = 4; + export class Deserialize { private supports125ListValue: boolean; @@ -195,49 +203,54 @@ export class Deserialize { }); } - public query(reply: SearchReply): WeaviateReturn { + public async query(reply: SearchReply): Promise> { return { - objects: reply.results.map((result) => { - return { - metadata: Deserialize.metadata(result.metadata), - properties: this.properties(result.properties), - references: this.references(result.properties), - uuid: Deserialize.uuid(result.metadata), - vectors: Deserialize.vectors(result.metadata), - } as any; - }), + objects: await Promise.all( + reply.results.map(async (result) => { + return { + metadata: Deserialize.metadata(result.metadata), + properties: this.properties(result.properties), + references: await this.references(result.properties), + uuid: Deserialize.uuid(result.metadata), + vectors: await Deserialize.vectors(result.metadata), + } as unknown as WeaviateObject; + }) + ), }; } - public generate( + public async generate( reply: SearchReply - ): GenerativeReturn { + ): Promise> { return { - objects: reply.results.map((result) => { - return { - generated: result.metadata?.generativePresent - ? result.metadata?.generative - : result.generative - ? result.generative.values[0].result - : undefined, - generative: result.generative - ? { - text: result.generative.values[0].result, - debug: result.generative.values[0].debug, - metadata: result.generative.values[0].metadata as GenerativeMetadata, - } - : result.metadata?.generativePresent - ? { - text: result.metadata?.generative, - } - : undefined, - metadata: Deserialize.metadata(result.metadata), - properties: this.properties(result.properties), - references: this.references(result.properties), - uuid: Deserialize.uuid(result.metadata), - vectors: Deserialize.vectors(result.metadata), - } as any; - }), + objects: await Promise.all( + reply.results.map( + async (result) => + ({ + generated: result.metadata?.generativePresent + ? result.metadata?.generative + : result.generative + ? result.generative.values[0].result + : undefined, + generative: result.generative + ? { + text: result.generative.values[0].result, + debug: result.generative.values[0].debug, + metadata: result.generative.values[0].metadata as GenerativeMetadata, + } + : result.metadata?.generativePresent + ? { + text: result.metadata?.generative, + } + : undefined, + metadata: Deserialize.metadata(result.metadata), + properties: this.properties(result.properties), + references: await this.references(result.properties), + uuid: Deserialize.uuid(result.metadata), + vectors: await Deserialize.vectors(result.metadata), + } as any) + ) + ), generated: reply.generativeGroupedResult !== '' ? reply.generativeGroupedResult @@ -257,20 +270,23 @@ export class Deserialize { }; } - public queryGroupBy(reply: SearchReply): GroupByReturn { - const objects: GroupByObject[] = []; - const groups: Record> = {}; - reply.groupByResults.forEach((result) => { - const objs = result.objects.map((object) => { - return { - belongsToGroup: result.name, - metadata: Deserialize.metadata(object.metadata), - properties: this.properties(object.properties), - references: this.references(object.properties), - uuid: Deserialize.uuid(object.metadata), - vectors: Deserialize.vectors(object.metadata), - } as any; - }); + public async queryGroupBy(reply: SearchReply): Promise> { + const objects: GroupByObject[] = []; + const groups: Record> = {}; + for (const result of reply.groupByResults) { + // eslint-disable-next-line no-await-in-loop + const objs = await Promise.all( + result.objects.map(async (object) => { + return { + belongsToGroup: result.name, + metadata: Deserialize.metadata(object.metadata), + properties: this.properties(object.properties), + references: await this.references(object.properties), + uuid: Deserialize.uuid(object.metadata), + vectors: await Deserialize.vectors(object.metadata), + } as unknown as GroupByObject; + }) + ); groups[result.name] = { maxDistance: result.maxDistance, minDistance: result.minDistance, @@ -279,27 +295,30 @@ export class Deserialize { objects: objs, }; objects.push(...objs); - }); + } return { objects: objects, groups: groups, }; } - public generateGroupBy(reply: SearchReply): GenerativeGroupByReturn { - const objects: GroupByObject[] = []; - const groups: Record> = {}; - reply.groupByResults.forEach((result) => { - const objs = result.objects.map((object) => { - return { - belongsToGroup: result.name, - metadata: Deserialize.metadata(object.metadata), - properties: this.properties(object.properties), - references: this.references(object.properties), - uuid: Deserialize.uuid(object.metadata), - vectors: Deserialize.vectors(object.metadata), - } as any; - }); + public async generateGroupBy(reply: SearchReply): Promise> { + const objects: GroupByObject[] = []; + const groups: Record> = {}; + for (const result of reply.groupByResults) { + // eslint-disable-next-line no-await-in-loop + const objs = await Promise.all( + result.objects.map(async (object) => { + return { + belongsToGroup: result.name, + metadata: Deserialize.metadata(object.metadata), + properties: this.properties(object.properties), + references: await this.references(object.properties), + uuid: Deserialize.uuid(object.metadata), + vectors: await Deserialize.vectors(object.metadata), + } as unknown as GroupByObject; + }) + ); groups[result.name] = { maxDistance: result.maxDistance, minDistance: result.minDistance, @@ -309,7 +328,7 @@ export class Deserialize { generated: result.generative?.result, }; objects.push(...objs); - }); + } return { objects: objects, groups: groups, @@ -322,28 +341,31 @@ export class Deserialize { return this.objectProperties(properties.nonRefProps); } - private references(properties?: PropertiesResult) { + private async references(properties?: PropertiesResult) { if (!properties) return undefined; if (properties.refProps.length === 0) return properties.refPropsRequested ? {} : undefined; const out: any = {}; - properties.refProps.forEach((property) => { + for (const property of properties.refProps) { const uuids: string[] = []; out[property.propName] = referenceFromObjects( - property.properties.map((property) => { - const uuid = Deserialize.uuid(property.metadata); - uuids.push(uuid); - return { - metadata: Deserialize.metadata(property.metadata), - properties: this.properties(property), - references: this.references(property), - uuid: uuid, - vectors: Deserialize.vectors(property.metadata), - }; - }), + // eslint-disable-next-line no-await-in-loop + await Promise.all( + property.properties.map(async (property) => { + const uuid = Deserialize.uuid(property.metadata); + uuids.push(uuid); + return { + metadata: Deserialize.metadata(property.metadata), + properties: this.properties(property), + references: await this.references(property), + uuid: uuid, + vectors: await Deserialize.vectors(property.metadata), + }; + }) + ), property.properties.length > 0 ? property.properties[0].targetCollection : '', uuids ); - }); + } return out; } @@ -409,7 +431,33 @@ export class Deserialize { return metadata.id; } - private static vectorFromBytes(bytes: Uint8Array) { + /** + * Convert an Uint8Array into a 2D vector array. + * + * Defined as an async method so that control can be relinquished back to the event loop on each outer loop for large vectors. + */ + private static vectorsFromBytes(bytes: Uint8Array): Promise { + const dimOffset = UINT16LEN; + const dimBytes = Buffer.from(bytes.slice(0, dimOffset)); + const vectorDimension = dimBytes.readUInt16LE(0); + + const vecByteLength = UINT32LEN * vectorDimension; + const howMany = (bytes.byteLength - dimOffset) / vecByteLength; + + return Promise.all( + Array(howMany) + .fill(0) + .map((_, i) => + yieldToEventLoop().then(() => + Deserialize.vectorFromBytes( + bytes.slice(dimOffset + i * vecByteLength, dimOffset + (i + 1) * vecByteLength) + ) + ) + ) + ); + } + + private static vectorFromBytes(bytes: Uint8Array): SingleVectorType { const buffer = Buffer.from(bytes); const view = new Float32Array(buffer.buffer, buffer.byteOffset, buffer.byteLength / 4); // vector is float32 in weaviate return Array.from(view); @@ -427,14 +475,21 @@ export class Deserialize { return Array.from(view); } - private static vectors(metadata?: MetadataResult): Record { + private static async vectors(metadata?: MetadataResult): Promise { if (!metadata) return {}; if (metadata.vectorBytes.length === 0 && metadata.vector.length === 0 && metadata.vectors.length === 0) return {}; if (metadata.vectorBytes.length > 0) return { default: Deserialize.vectorFromBytes(metadata.vectorBytes) }; return Object.fromEntries( - metadata.vectors.map((vector) => [vector.name, Deserialize.vectorFromBytes(vector.vectorBytes)]) + await Promise.all( + metadata.vectors.map(async (vector) => [ + vector.name, + vector.type === Vectors_VectorType.VECTOR_TYPE_MULTI_FP32 + ? await Deserialize.vectorsFromBytes(vector.vectorBytes) + : Deserialize.vectorFromBytes(vector.vectorBytes), + ]) + ) ); } diff --git a/src/collections/filters/integration.test.ts b/src/collections/filters/integration.test.ts index d4a72d05..6d159c1a 100644 --- a/src/collections/filters/integration.test.ts +++ b/src/collections/filters/integration.test.ts @@ -95,7 +95,7 @@ describe('Testing of the filter class with a simple collection', () => { return uuids; }); const res = await collection.query.fetchObjectById(ids[0], { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should filter a fetch objects query with a single filter and generic collection', async () => { diff --git a/src/collections/generate/index.ts b/src/collections/generate/index.ts index 3a5fb1b3..db532b71 100644 --- a/src/collections/generate/index.ts +++ b/src/collections/generate/index.ts @@ -33,103 +33,127 @@ import { GenerativeGroupByReturn, GenerativeReturn, GroupByOptions, + ReturnVectors, } from '../types/index.js'; +import { IncludeVector } from '../types/internal.js'; import { Generate } from './types.js'; -class GenerateManager implements Generate { - private check: Check; +class GenerateManager implements Generate { + private check: Check; - private constructor(check: Check) { + private constructor(check: Check) { this.check = check; } - public static use( + public static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): GenerateManager { - return new GenerateManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); + ): GenerateManager { + return new GenerateManager( + new Check(connection, name, dbVersionSupport, consistencyLevel, tenant) + ); } - private async parseReply(reply: SearchReply) { + private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.generate(reply); + return deserialize.generate(reply); } - private async parseGroupByReply( - opts: SearchOptions | GroupByOptions | undefined, + private async parseGroupByReply( + opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) - ? deserialize.generateGroupBy(reply) - : deserialize.generate(reply); + ? deserialize.generateGroupBy(reply) + : deserialize.generate(reply); } public fetchObjects( generate: GenerateOptions, - opts?: FetchObjectsOptions - ): Promise> { + opts?: FetchObjectsOptions + ): Promise> { return Promise.all([ this.check.fetchObjects(opts), this.check.supportForSingleGroupedGenerative(), this.check.supportForGenerativeConfigRuntime(generate.config), ]) - .then(async ([{ search }, supportsSingleGrouped]) => - search.withFetch({ + .then(async ([{ search }, supportsSingleGrouped]) => ({ + search, + args: { ...Serialize.search.fetchObjects(opts), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withFetch(args)) .then((reply) => this.parseReply(reply)); } - public bm25( - query: string, - generate: GenerateOptions, - opts?: BaseBm25Options - ): Promise>; - public bm25( + public bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts: GroupByBm25Options - ): Promise>; - public bm25( + opts?: BaseBm25Options + ): Promise>; + public bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: Bm25Options - ): GenerateReturn { + opts: GroupByBm25Options + ): Promise>; + public bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >(query: string, generate: GenerateOptions, opts?: Bm25Options): GenerateReturn { return Promise.all([ this.check.bm25(opts), this.check.supportForSingleGroupedGenerative(), this.check.supportForGenerativeConfigRuntime(generate.config), ]) - .then(async ([{ search }, supportsSingleGrouped]) => - search.withBm25({ + .then(async ([{ search }, supportsSingleGrouped]) => ({ + search, + args: { ...Serialize.search.bm25(query, opts), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withBm25(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid( - query: string, - generate: GenerateOptions, - opts?: BaseHybridOptions - ): Promise>; - public hybrid( + public hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts: GroupByHybridOptions - ): Promise>; - public hybrid( + opts?: BaseHybridOptions + ): Promise>; + public hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: HybridOptions - ): GenerateReturn { + opts: GroupByHybridOptions + ): Promise>; + public hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >(query: string, generate: GenerateOptions, opts?: HybridOptions): GenerateReturn { return Promise.all([ this.check.hybridSearch(opts), this.check.supportForSingleGroupedGenerative(), @@ -137,88 +161,110 @@ class GenerateManager implements Generate { ]) .then( async ([ - { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, supportsVectors }, supportsSingleGrouped, - ]) => - search.withHybrid({ - ...Serialize.search.hybrid( + ]) => ({ + search, + args: { + ...(await Serialize.search.hybrid( { query, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, }, opts - ), + )), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) + }, + }) ) + .then(({ search, args }) => search.withHybrid(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage( + public nearImage< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( image: string | Buffer, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; - public nearImage( + opts?: BaseNearOptions + ): Promise>; + public nearImage< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( image: string | Buffer, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - public nearImage( + opts: GroupByNearOptions + ): Promise>; + public nearImage< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( image: string | Buffer, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return Promise.all([ this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative(), this.check.supportForGenerativeConfigRuntime(generate.config), ]) - .then(([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => - Promise.all([ - toBase64FromMedia(image), - Serialize.generative({ supportsSingleGrouped }, generate), - ]).then(([image, generative]) => - search.withNearImage({ - ...Serialize.search.nearImage( - { - image, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ), - generative, - }) - ) - ) + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => ({ + search, + args: { + ...Serialize.search.nearImage( + { + image: await toBase64FromMedia(image), + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + generative: await Serialize.generative({ supportsSingleGrouped }, generate), + }, + })) + .then(({ search, args }) => search.withNearImage(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject( + public nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( id: string, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; - public nearObject( + opts?: BaseNearOptions + ): Promise>; + public nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( id: string, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - public nearObject( - id: string, - generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts: GroupByNearOptions + ): Promise>; + public nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >(id: string, generate: GenerateOptions, opts?: NearOptions): GenerateReturn { return Promise.all([ this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative(), this.check.supportForGenerativeConfigRuntime(generate.config), ]) - .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => - search.withNearObject({ + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => ({ + search, + args: { ...Serialize.search.nearObject( { id, @@ -228,33 +274,47 @@ class GenerateManager implements Generate { opts ), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withNearObject(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText( + public nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts?: BaseNearTextOptions - ): Promise>; - public nearText( + opts?: BaseNearTextOptions + ): Promise>; + public nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts: GroupByNearTextOptions - ): Promise>; - public nearText( + opts: GroupByNearTextOptions + ): Promise>; + public nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return Promise.all([ this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative(), this.check.supportForGenerativeConfigRuntime(generate.config), ]) - .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => - search.withNearText({ + .then(async ([{ search, supportsTargets, supportsWeightsForTargets }, supportsSingleGrouped]) => ({ + search, + args: { ...Serialize.search.nearText( { query, @@ -264,26 +324,39 @@ class GenerateManager implements Generate { opts ), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) - ) + }, + })) + .then(({ search, args }) => search.withNearText(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector( + public nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: number[], generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; - public nearVector( + opts?: BaseNearOptions + ): Promise>; + public nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: number[], generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - public nearVector( + opts: GroupByNearOptions + ): Promise>; + public nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: number[], generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return Promise.all([ this.check.nearVector(vector, opts), this.check.supportForSingleGroupedGenerative(), @@ -291,43 +364,59 @@ class GenerateManager implements Generate { ]) .then( async ([ - { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }, + { search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, supportsVectors }, supportsSingleGrouped, - ]) => - search.withNearVector({ - ...Serialize.search.nearVector( + ]) => ({ + search, + args: { + ...(await Serialize.search.nearVector( { vector, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, }, opts - ), + )), generative: await Serialize.generative({ supportsSingleGrouped }, generate), - }) + }, + }) ) + .then(({ search, args }) => search.withNearVector(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearMedia( + public nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; - public nearMedia( + opts?: BaseNearOptions + ): Promise>; + public nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; - public nearMedia( + opts: GroupByNearOptions + ): Promise>; + public nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn { + opts?: NearOptions + ): GenerateReturn { return Promise.all([ this.check.nearSearch(opts), this.check.supportForSingleGroupedGenerative(), diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index 39e8e351..add4bc7d 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -66,7 +66,7 @@ maybe('Testing of the collection.generate methods with a simple collection', () }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); describe('using a non-generic collection', () => { @@ -227,7 +227,7 @@ maybe('Testing of the groupBy collection.generate methods with a simple collecti }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); // it('should groupBy without search', async () => { @@ -387,8 +387,8 @@ maybe('Testing of the collection.generate methods with a multi vector collection }, }); const res = await collection.query.fetchObjectById(id1, { includeVector: true }); - titleVector = res!.vectors.title!; - title2Vector = res!.vectors.title2!; + titleVector = res!.vectors.title as number[]; + title2Vector = res!.vectors.title2 as number[]; }); if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 24, 0))) { await expect(query()).rejects.toThrow(WeaviateUnsupportedFeatureError); diff --git a/src/collections/generate/types.ts b/src/collections/generate/types.ts index 27548bfb..b9022f7c 100644 --- a/src/collections/generate/types.ts +++ b/src/collections/generate/types.ts @@ -21,9 +21,11 @@ import { GenerativeConfigRuntime, GenerativeGroupByReturn, GenerativeReturn, + ReturnVectors, } from '../types/index.js'; +import { IncludeVector } from '../types/internal.js'; -interface Bm25 { +interface Bm25 { /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -33,14 +35,18 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseBm25Options} [opts] - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data. */ - bm25( + bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: BaseBm25Options - ): Promise>; + opts?: BaseBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -50,14 +56,18 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByBm25Options} opts - The available options for performing the BM25 search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - bm25( + bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts: GroupByBm25Options - ): Promise>; + opts: GroupByBm25Options + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a keyword-based BM25 search of objects in this collection. * @@ -67,17 +77,21 @@ interface Bm25 { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {Bm25Options} [opts] - The available options for performing the BM25 search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {Bm25Options} [opts] - The available options for performing the BM25 search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - bm25( + bm25< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: Bm25Options - ): GenerateReturn; + opts?: Bm25Options + ): GenerateReturn; } -interface Hybrid { +interface Hybrid { /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -87,14 +101,18 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseHybridOptions} [opts] - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data. */ - hybrid( + hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: BaseHybridOptions - ): Promise>; + opts?: BaseHybridOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -104,14 +122,18 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByHybridOptions} opts - The available options for performing the hybrid search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - hybrid( + hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts: GroupByHybridOptions - ): Promise>; + opts: GroupByHybridOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of an object search in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -121,17 +143,21 @@ interface Hybrid { * * @param {string} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {HybridOptions} [opts] - The available options for performing the hybrid search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {HybridOptions} [opts] - The available options for performing the hybrid search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - hybrid( + hybrid< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string, generate: GenerateOptions, - opts?: HybridOptions - ): GenerateReturn; + opts?: HybridOptions + ): GenerateReturn; } -interface NearMedia { +interface NearMedia { /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -144,15 +170,19 @@ interface NearMedia { * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data. */ - nearMedia( + nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -165,15 +195,19 @@ interface NearMedia { * @param {string | Buffer} media - The media file to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-media search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-media search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearMedia( + nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-audio object search in this collection using an audio-capable vectorization module and vector-based similarity search. * @@ -186,18 +220,22 @@ interface NearMedia { * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search on. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-media search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-media search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearMedia( + nearMedia< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( media: string | Buffer, type: NearMediaType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; + opts?: NearOptions + ): GenerateReturn; } -interface NearObject { +interface NearObject { /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -207,14 +245,18 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data. */ - nearObject( + nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( id: string, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -224,14 +266,18 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-object search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-object search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearObject( + nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( id: string, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-object object search in this collection using a vector-based similarity search. * @@ -241,17 +287,21 @@ interface NearObject { * * @param {string} id - The ID of the object to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-object search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-object search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearObject( + nearObject< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( id: string, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; + opts?: NearOptions + ): GenerateReturn; } -interface NearText { +interface NearText { /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -263,14 +313,18 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearTextOptions} [opts] - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data. */ - nearText( + nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts?: BaseNearTextOptions - ): Promise>; + opts?: BaseNearTextOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -282,14 +336,18 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearTextOptions} opts - The available options for performing the near-text search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearText( + nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts: GroupByNearTextOptions - ): Promise>; + opts: GroupByNearTextOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-image object search in this collection using the image-capable vectorization module and vector-based similarity search. * @@ -301,17 +359,21 @@ interface NearText { * * @param {string | string[]} query - The query to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearTextOptions} [opts] - The available options for performing the near-text search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearTextOptions} [opts] - The available options for performing the near-text search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearText( + nearText< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( query: string | string[], generate: GenerateOptions, - opts?: NearTextOptions - ): GenerateReturn; + opts?: NearTextOptions + ): GenerateReturn; } -interface NearVector { +interface NearVector { /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -321,14 +383,18 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data. + * @param {BaseNearOptions} [opts] - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data. */ - nearVector( + nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: NearVectorInputType, generate: GenerateOptions, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -338,14 +404,18 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. - * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. + * @param {GroupByNearOptions} opts - The available options for performing the near-vector search. + * @return {Promise>} - The results of the search including the generated data grouped by the specified properties. */ - nearVector( + nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: NearVectorInputType, generate: GenerateOptions, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Perform retrieval-augmented generation (RaG) on the results of a by-vector object search in this collection using vector-based similarity search. * @@ -355,25 +425,29 @@ interface NearVector { * * @param {NearVectorInputType} vector - The vector(s) to search for. * @param {GenerateOptions} generate - The available options for performing the generation. - * @param {NearOptions} [opts] - The available options for performing the near-vector search. - * @return {GenerateReturn} - The results of the search including the generated data. + * @param {NearOptions} [opts] - The available options for performing the near-vector search. + * @return {GenerateReturn} - The results of the search including the generated data. */ - nearVector( + nearVector< + I extends IncludeVector, + RV extends ReturnVectors, + C extends GenerativeConfigRuntime | undefined = undefined + >( vector: NearVectorInputType, generate: GenerateOptions, - opts?: NearOptions - ): GenerateReturn; + opts?: NearOptions + ): GenerateReturn; } -export interface Generate - extends Bm25, - Hybrid, - NearMedia, - NearObject, - NearText, - NearVector { +export interface Generate + extends Bm25, + Hybrid, + NearMedia, + NearObject, + NearText, + NearVector { fetchObjects: ( generate: GenerateOptions, - opts?: FetchObjectsOptions - ) => Promise>; + opts?: FetchObjectsOptions + ) => Promise>; } diff --git a/src/collections/index.ts b/src/collections/index.ts index cd5707a7..3e344691 100644 --- a/src/collections/index.ts +++ b/src/collections/index.ts @@ -31,6 +31,7 @@ import { VectorConfigCreate, Vectorizer, VectorizersConfigCreate, + Vectors, } from './types/index.js'; import { PrimitiveKeys } from './types/internal.js'; @@ -40,7 +41,7 @@ import { PrimitiveKeys } from './types/internal.js'; * Inspect [the docs](https://weaviate.io/developers/weaviate/configuration) for more information on the * different configuration options and how they affect the behavior of your collection. */ -export type CollectionConfigCreate = { +export type CollectionConfigCreate = { /** The name of the collection. */ name: N; /** The description of the collection. */ @@ -62,7 +63,7 @@ export type CollectionConfigCreate = { /** The configuration for Weaviate's sharding strategy. Is mutually exclusive with `replication`. */ sharding?: ShardingConfigCreate; /** The configuration for Weaviate's vectorizer(s) capabilities. */ - vectorizers?: VectorizersConfigCreate; + vectorizers?: VectorizersConfigCreate; }; const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) => { @@ -72,9 +73,11 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) .then((schema) => (schema.classes ? schema.classes.map(classToCollection) : [])); const deleteCollection = (name: string) => new ClassDeleter(connection).withClassName(name).do(); return { - create: async function ( - config: CollectionConfigCreate - ) { + create: async function < + TProperties extends Properties | undefined = undefined, + TName = string, + TVectors extends Vectors | undefined = undefined + >(config: CollectionConfigCreate) { const { name, invertedIndex, multiTenancy, replication, sharding, ...rest } = config; const supportsDynamicVectorIndex = await dbVersionSupport.supportsDynamicVectorIndex(); @@ -175,11 +178,11 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) schema.properties = [...properties, ...references]; await new ClassCreator(connection).withClass(schema).do(); - return collection(connection, name, dbVersionSupport); + return collection(connection, name, dbVersionSupport); }, createFromSchema: async function (config: WeaviateClass) { const { class: name } = await new ClassCreator(connection).withClass(config).do(); - return collection(connection, name as string, dbVersionSupport); + return collection(connection, name as string, dbVersionSupport); }, delete: deleteCollection, deleteAll: () => listAll().then((configs) => Promise.all(configs?.map((c) => deleteCollection(c.name)))), @@ -192,17 +195,25 @@ const collections = (connection: Connection, dbVersionSupport: DbVersionSupport) listAll: listAll, get: ( name: TName - ) => collection(connection, name, dbVersionSupport), - use: ( + ) => collection(connection, name, dbVersionSupport), + use: < + TProperties extends Properties | undefined = undefined, + TName extends string = string, + TVectors extends Vectors | undefined = undefined + >( name: TName - ) => collection(connection, name, dbVersionSupport), + ) => collection(connection, name, dbVersionSupport), }; }; export interface Collections { - create( - config: CollectionConfigCreate - ): Promise>; + create< + TProperties extends Properties | undefined = undefined, + TName = string, + TVectors extends Vectors | undefined = undefined + >( + config: CollectionConfigCreate + ): Promise>; createFromSchema(config: WeaviateClass): Promise>; delete(collection: string): Promise; deleteAll(): Promise; @@ -212,9 +223,13 @@ export interface Collections { name: TName ): Collection; listAll(): Promise; - use( + use< + TName extends string = string, + TProperties extends Properties | undefined = undefined, + TVectors extends Vectors | undefined = undefined + >( name: TName - ): Collection; + ): Collection; } export default collections; diff --git a/src/collections/iterator/index.ts b/src/collections/iterator/index.ts index 0e631bb9..edb35762 100644 --- a/src/collections/iterator/index.ts +++ b/src/collections/iterator/index.ts @@ -3,16 +3,16 @@ import { WeaviateObject } from '../types/index.js'; const ITERATOR_CACHE_SIZE = 100; -export class Iterator { - private cache: WeaviateObject[] = []; +export class Iterator { + private cache: WeaviateObject[] = []; private last: string | undefined = undefined; - constructor(private query: (limit: number, after?: string) => Promise[]>) { + constructor(private query: (limit: number, after?: string) => Promise[]>) { this.query = query; } [Symbol.asyncIterator]() { return { - next: async (): Promise>> => { + next: async (): Promise>> => { const objects = await this.query(ITERATOR_CACHE_SIZE, this.last); this.cache = objects; if (this.cache.length == 0) { diff --git a/src/collections/iterator/integration.test.ts b/src/collections/iterator/integration.test.ts index c4757083..22401a05 100644 --- a/src/collections/iterator/integration.test.ts +++ b/src/collections/iterator/integration.test.ts @@ -45,7 +45,7 @@ describe('Testing of the collection.iterator method with a simple collection', ( }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should iterate through the collection with no options returning the objects', async () => { diff --git a/src/collections/journey.test.ts b/src/collections/journey.test.ts index 85da3a8f..5a545f9d 100644 --- a/src/collections/journey.test.ts +++ b/src/collections/journey.test.ts @@ -187,6 +187,7 @@ describe('Journey testing of the client using a WCD cluster', () => { maxConnections: (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) ? 64 : 32, + multiVector: undefined, skip: false, vectorCacheMaxObjects: 1000000000000, quantizer: undefined, diff --git a/src/collections/query/check.ts b/src/collections/query/check.ts index 2e562a2b..8cc1230f 100644 --- a/src/collections/query/check.ts +++ b/src/collections/query/check.ts @@ -17,7 +17,7 @@ import { SearchOptions, } from './types.js'; -export class Check { +export class Check { private connection: Connection; private name: string; public dbVersionSupport: DbVersionSupport; @@ -40,7 +40,7 @@ export class Check { private getSearcher = () => this.connection.search(this.name, this.consistencyLevel, this.tenant); - private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { + private checkSupportForNamedVectors = async (opts?: BaseNearOptions) => { if (!Serialize.isNamedVectors(opts)) return; const check = await this.dbVersionSupport.supportsNamedVectors(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -48,20 +48,22 @@ export class Check { private checkSupportForBm25AndHybridGroupByQueries = async ( query: 'Bm25' | 'Hybrid', - opts?: SearchOptions | GroupByOptions + opts?: SearchOptions | GroupByOptions ) => { if (!Serialize.search.isGroupBy(opts)) return; const check = await this.dbVersionSupport.supportsBm25AndHybridGroupByQueries(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query)); }; - private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions) => { + private checkSupportForHybridNearTextAndNearVectorSubSearches = async ( + opts?: HybridOptions + ) => { if (opts?.vector === undefined || Array.isArray(opts.vector)) return; const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); }; - private checkSupportForMultiTargetSearch = async (opts?: BaseNearOptions) => { + private checkSupportForMultiTargetSearch = async (opts?: BaseNearOptions) => { if (!Serialize.isMultiTarget(opts)) return false; const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -79,7 +81,7 @@ export class Check { return check.supports; }; - private checkSupportForMultiWeightPerTargetSearch = async (opts?: BaseNearOptions) => { + private checkSupportForMultiWeightPerTargetSearch = async (opts?: BaseNearOptions) => { if (!Serialize.isMultiWeightPerTarget(opts)) return false; const check = await this.dbVersionSupport.supportsMultiWeightsPerTargetSearch(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -98,6 +100,14 @@ export class Check { return check.supports; }; + private checkSupportForVectors = async ( + vec?: NearVectorInputType | HybridNearVectorSubSearch | HybridNearTextSubSearch + ) => { + if (vec === undefined || Serialize.isHybridNearTextSearch(vec)) return false; + const check = await this.dbVersionSupport.supportsVectorsFieldInGRPC(); + return check.supports; + }; + public supportForSingleGroupedGenerative = async () => { const check = await this.dbVersionSupport.supportsSingleGrouped(); if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message); @@ -111,7 +121,7 @@ export class Check { return check.supports; }; - public nearSearch = (opts?: BaseNearOptions) => { + public nearSearch = (opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), @@ -124,13 +134,14 @@ export class Check { }); }; - public nearVector = (vec: NearVectorInputType, opts?: BaseNearOptions) => { + public nearVector = (vec: NearVectorInputType, opts?: BaseNearOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), this.checkSupportForMultiVectorSearch(vec), this.checkSupportForMultiVectorPerTargetSearch(vec), this.checkSupportForMultiWeightPerTargetSearch(opts), + this.checkSupportForVectors(vec), this.checkSupportForNamedVectors(opts), ]).then( ([ @@ -139,26 +150,30 @@ export class Check { supportsMultiVector, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, ]) => { const is126 = supportsMultiTarget || supportsMultiVector; const is127 = supportsVectorsForTargets || supportsWeightsForTargets; + const is129 = supportsVectors; return { search, supportsTargets: is126 || is127, supportsVectorsForTargets: is127, supportsWeightsForTargets: is127, + supportsVectors: is129, }; } ); }; - public hybridSearch = (opts?: BaseHybridOptions) => { + public hybridSearch = (opts?: BaseHybridOptions) => { return Promise.all([ this.getSearcher(), this.checkSupportForMultiTargetSearch(opts), this.checkSupportForMultiVectorSearch(opts?.vector), this.checkSupportForMultiVectorPerTargetSearch(opts?.vector), this.checkSupportForMultiWeightPerTargetSearch(opts), + this.checkSupportForVectors(opts?.vector), this.checkSupportForNamedVectors(opts), this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts), this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts), @@ -169,32 +184,35 @@ export class Check { supportsMultiVector, supportsWeightsForTargets, supportsVectorsForTargets, + supportsVectors, ]) => { const is126 = supportsMultiTarget || supportsMultiVector; const is127 = supportsVectorsForTargets || supportsWeightsForTargets; + const is129 = supportsVectors; return { search, supportsTargets: is126 || is127, supportsWeightsForTargets: is127, supportsVectorsForTargets: is127, + supportsVectors: is129, }; } ); }; - public fetchObjects = (opts?: FetchObjectsOptions) => { + public fetchObjects = (opts?: FetchObjectsOptions) => { return Promise.all([this.getSearcher(), this.checkSupportForNamedVectors(opts)]).then(([search]) => { return { search }; }); }; - public fetchObjectById = (opts?: FetchObjectByIdOptions) => { + public fetchObjectById = (opts?: FetchObjectByIdOptions) => { return Promise.all([this.getSearcher(), this.checkSupportForNamedVectors(opts)]).then(([search]) => { return { search }; }); }; - public bm25 = (opts?: BaseBm25Options) => { + public bm25 = (opts?: BaseBm25Options) => { return Promise.all([ this.getSearcher(), this.checkSupportForNamedVectors(opts), diff --git a/src/collections/query/factories.ts b/src/collections/query/factories.ts new file mode 100644 index 00000000..10d5c4b7 --- /dev/null +++ b/src/collections/query/factories.ts @@ -0,0 +1,22 @@ +import { ListOfVectors, PrimitiveVectorType } from './types.js'; +import { NearVectorInputGuards } from './utils.js'; + +const hybridVector = { + nearText: () => {}, + nearVector: () => {}, +}; + +const nearVector = { + listOfVectors: (...vectors: V[]): ListOfVectors => { + return { + kind: 'listOfVectors', + dimensionality: NearVectorInputGuards.is1D(vectors[0]) ? '1D' : '2D', + vectors, + }; + }, +}; + +export const queryFactory = { + hybridVector, + nearVector, +}; diff --git a/src/collections/query/index.ts b/src/collections/query/index.ts index 9fb5b58e..d288f0af 100644 --- a/src/collections/query/index.ts +++ b/src/collections/query/index.ts @@ -8,9 +8,16 @@ import { DbVersionSupport } from '../../utils/dbVersion.js'; import { SearchReply } from '../../proto/v1/search_get.js'; import { Deserialize } from '../deserialize/index.js'; import { Serialize } from '../serialize/index.js'; -import { GroupByOptions, GroupByReturn, WeaviateObject, WeaviateReturn } from '../types/index.js'; +import { + GroupByOptions, + GroupByReturn, + ReturnVectors, + WeaviateObject, + WeaviateReturn, +} from '../types/index.js'; import { WeaviateInvalidInputError } from '../../errors.js'; +import { IncludeVector } from '../types/internal.js'; import { Check } from './check.js'; import { BaseBm25Options, @@ -34,111 +41,163 @@ import { SearchOptions, } from './types.js'; -class QueryManager implements Query { - private check: Check; +class QueryManager implements Query { + private check: Check; - private constructor(check: Check) { + private constructor(check: Check) { this.check = check; } - public static use( + public static use( connection: Connection, name: string, dbVersionSupport: DbVersionSupport, consistencyLevel?: ConsistencyLevel, tenant?: string - ): QueryManager { - return new QueryManager(new Check(connection, name, dbVersionSupport, consistencyLevel, tenant)); + ): QueryManager { + return new QueryManager( + new Check(connection, name, dbVersionSupport, consistencyLevel, tenant) + ); } - private async parseReply(reply: SearchReply) { + private async parseReply(reply: SearchReply) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); - return deserialize.query(reply); + return deserialize.query(reply); } - private async parseGroupByReply( - opts: SearchOptions | GroupByOptions | undefined, + private async parseGroupByReply( + opts: SearchOptions | GroupByOptions | undefined, reply: SearchReply ) { const deserialize = await Deserialize.use(this.check.dbVersionSupport); return Serialize.search.isGroupBy(opts) - ? deserialize.queryGroupBy(reply) - : deserialize.query(reply); + ? deserialize.queryGroupBy(reply) + : deserialize.query(reply); } - public fetchObjectById(id: string, opts?: FetchObjectByIdOptions): Promise | null> { + public fetchObjectById, RV extends ReturnVectors>( + id: string, + opts?: FetchObjectByIdOptions + ): Promise | null> { return this.check .fetchObjectById(opts) .then(({ search }) => search.withFetch(Serialize.search.fetchObjectById({ id, ...opts }))) - .then((reply) => this.parseReply(reply)) + .then((reply) => this.parseReply(reply)) .then((ret) => (ret.objects.length === 1 ? ret.objects[0] : null)); } - public fetchObjects(opts?: FetchObjectsOptions): Promise> { + public fetchObjects, RV extends ReturnVectors>( + opts?: FetchObjectsOptions + ): Promise> { return this.check .fetchObjects(opts) .then(({ search }) => search.withFetch(Serialize.search.fetchObjects(opts))) .then((reply) => this.parseReply(reply)); } - public bm25(query: string, opts?: BaseBm25Options): Promise>; - public bm25(query: string, opts: GroupByBm25Options): Promise>; - public bm25(query: string, opts?: Bm25Options): QueryReturn { + public bm25, RV extends ReturnVectors>( + query: string, + opts?: BaseBm25Options + ): Promise>; + public bm25, RV extends ReturnVectors>( + query: string, + opts: GroupByBm25Options + ): Promise>; + public bm25, RV extends ReturnVectors>( + query: string, + opts?: Bm25Options + ): QueryReturn { return this.check .bm25(opts) .then(({ search }) => search.withBm25(Serialize.search.bm25(query, opts))) .then((reply) => this.parseGroupByReply(opts, reply)); } - public hybrid(query: string, opts?: BaseHybridOptions): Promise>; - public hybrid(query: string, opts: GroupByHybridOptions): Promise>; - public hybrid(query: string, opts?: HybridOptions): QueryReturn { + public hybrid, RV extends ReturnVectors>( + query: string, + opts?: BaseHybridOptions + ): Promise>; + public hybrid, RV extends ReturnVectors>( + query: string, + opts: GroupByHybridOptions + ): Promise>; + public hybrid, RV extends ReturnVectors>( + query: string, + opts?: HybridOptions + ): QueryReturn { return this.check .hybridSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets }) => - search.withHybrid( - Serialize.search.hybrid( - { query, supportsTargets, supportsWeightsForTargets, supportsVectorsForTargets }, + .then( + async ({ + search, + supportsTargets, + supportsWeightsForTargets, + supportsVectorsForTargets, + supportsVectors, + }) => ({ + search, + args: await Serialize.search.hybrid( + { + query, + supportsTargets, + supportsWeightsForTargets, + supportsVectorsForTargets, + supportsVectors, + }, opts - ) - ) + ), + }) ) + .then(({ search, args }) => search.withHybrid(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; - public nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; - public nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn { + public nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts?: BaseNearOptions + ): Promise>; + public nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts: GroupByNearOptions + ): Promise>; + public nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts?: NearOptions + ): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { - return toBase64FromMedia(image).then((image) => - search.withNearImage( - Serialize.search.nearImage( - { - image, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ); + return toBase64FromMedia(image).then((image) => ({ + search, + args: Serialize.search.nearImage( + { + image, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })); }) + .then(({ search, args }) => search.withNearImage(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearMedia( + public nearMedia, RV extends ReturnVectors>( + media: string | Buffer, + type: NearMediaType, + opts?: BaseNearOptions + ): Promise>; + public nearMedia, RV extends ReturnVectors>( media: string | Buffer, type: NearMediaType, - opts?: BaseNearOptions - ): Promise>; - public nearMedia( + opts: GroupByNearOptions + ): Promise>; + public nearMedia, RV extends ReturnVectors>( media: string | Buffer, type: NearMediaType, - opts: GroupByNearOptions - ): Promise>; - public nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn { + opts?: NearOptions + ): QueryReturn { return this.check .nearSearch(opts) .then(({ search, supportsTargets, supportsWeightsForTargets }) => { @@ -178,70 +237,106 @@ class QueryManager implements Query { .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearObject(id: string, opts?: BaseNearOptions): Promise>; - public nearObject(id: string, opts: GroupByNearOptions): Promise>; - public nearObject(id: string, opts?: NearOptions): QueryReturn { + public nearObject, RV extends ReturnVectors>( + id: string, + opts?: BaseNearOptions + ): Promise>; + public nearObject, RV extends ReturnVectors>( + id: string, + opts: GroupByNearOptions + ): Promise>; + public nearObject, RV extends ReturnVectors>( + id: string, + opts?: NearOptions + ): QueryReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearObject( - Serialize.search.nearObject( - { - id, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ) + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: Serialize.search.nearObject( + { + id, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })) + .then(({ search, args }) => search.withNearObject(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; - public nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; - public nearText(query: string | string[], opts?: NearTextOptions): QueryReturn { + public nearText, RV extends ReturnVectors>( + query: string | string[], + opts?: BaseNearTextOptions + ): Promise>; + public nearText, RV extends ReturnVectors>( + query: string | string[], + opts: GroupByNearTextOptions + ): Promise>; + public nearText, RV extends ReturnVectors>( + query: string | string[], + opts?: NearTextOptions + ): QueryReturn { return this.check .nearSearch(opts) - .then(({ search, supportsTargets, supportsWeightsForTargets }) => - search.withNearText( - Serialize.search.nearText( - { - query, - supportsTargets, - supportsWeightsForTargets, - }, - opts - ) - ) - ) + .then(({ search, supportsTargets, supportsWeightsForTargets }) => ({ + search, + args: Serialize.search.nearText( + { + query, + supportsTargets, + supportsWeightsForTargets, + }, + opts + ), + })) + .then(({ search, args }) => search.withNearText(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } - public nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; - public nearVector(vector: NearVectorInputType, opts: GroupByNearOptions): Promise>; - public nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn { + public nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts?: BaseNearOptions + ): Promise>; + public nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts: GroupByNearOptions + ): Promise>; + public nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts?: NearOptions + ): QueryReturn { return this.check .nearVector(vector, opts) - .then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) => - search.withNearVector( - Serialize.search.nearVector( + .then( + async ({ + search, + supportsTargets, + supportsVectorsForTargets, + supportsWeightsForTargets, + supportsVectors, + }) => ({ + search, + args: await Serialize.search.nearVector( { vector, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets, + supportsVectors, }, opts - ) - ) + ), + }) ) + .then(({ search, args }) => search.withNearVector(args)) .then((reply) => this.parseGroupByReply(opts, reply)); } } export default QueryManager.use; - +export { queryFactory } from './factories.js'; export { BaseBm25Options, BaseHybridOptions, diff --git a/src/collections/query/integration.test.ts b/src/collections/query/integration.test.ts index 965afa32..355530f5 100644 --- a/src/collections/query/integration.test.ts +++ b/src/collections/query/integration.test.ts @@ -64,7 +64,7 @@ describe('Testing of the collection.query methods with a simple collection', () }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); it('should fetch an object by its id', async () => { @@ -134,11 +134,7 @@ describe('Testing of the collection.query methods with a simple collection', () expect(ret.objects[0].uuid).toEqual(id); }); - requireAtLeast( - 1, - 31, - 0 - )('bm25 search operator (minimum_should_match)', () => { + requireAtLeast(1, 31, 0)(describe)('bm25 search operator (minimum_should_match)', () => { it('should query with bm25 + operator', async () => { const ret = await collection.query.bm25('carrot', { limit: 1, @@ -584,16 +580,25 @@ describe('Testing of the collection.query methods with a collection with a neste describe('Testing of the collection.query methods with a collection with a multiple vectors', () => { let client: WeaviateClient; - let collection: Collection; + let collection: Collection< + TestCollectionQueryWithMultiVectorProps, + 'TestCollectionQueryWithMultiVector', + TestCollectionQueryWithMultiVectorVectors + >; const collectionName = 'TestCollectionQueryWithMultiVector'; let id1: string; let id2: string; - type TestCollectionQueryWithMultiVector = { + type TestCollectionQueryWithMultiVectorProps = { title: string; }; + type TestCollectionQueryWithMultiVectorVectors = { + title: number[]; + title2: number[]; + }; + afterAll(() => { return client.collections.delete(collectionName).catch((err) => { console.error(err); @@ -606,7 +611,7 @@ describe('Testing of the collection.query methods with a collection with a multi collection = client.collections.use(collectionName); const query = () => client.collections - .create({ + .create({ name: collectionName, properties: [ { @@ -1135,7 +1140,7 @@ describe('Testing of the groupBy collection.query methods with a simple collecti }); }); const res = await collection.query.fetchObjectById(id, { includeVector: true }); - vector = res?.vectors.default!; + vector = res?.vectors.default as number[]; }); // it('should groupBy without search', async () => { diff --git a/src/collections/query/types.ts b/src/collections/query/types.ts index 5eaf0f53..b1d82d35 100644 --- a/src/collections/query/types.ts +++ b/src/collections/query/types.ts @@ -1,5 +1,5 @@ import { FilterValue } from '../filters/index.js'; -import { MultiTargetVectorJoin } from '../index.js'; +import { MultiTargetVectorJoin, ReturnVectors } from '../index.js'; import { Sorting } from '../sort/classes.js'; import { GroupByOptions, @@ -11,12 +11,12 @@ import { WeaviateObject, WeaviateReturn, } from '../types/index.js'; -import { PrimitiveKeys } from '../types/internal.js'; +import { IncludeVector, PrimitiveKeys } from '../types/internal.js'; /** Options available in the `query.fetchObjectById` method */ -export type FetchObjectByIdOptions = { +export type FetchObjectByIdOptions = { /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: I; /** * Which properties of the object to return. Can be primitive, in which case specify their names, or nested, in which case * use the QueryNested type. If not specified, all properties are returned. @@ -27,7 +27,7 @@ export type FetchObjectByIdOptions = { }; /** Options available in the `query.fetchObjects` method */ -export type FetchObjectsOptions = { +export type FetchObjectsOptions = { /** How many objects to return in the query */ limit?: number; /** How many objects to skip in the query. Incompatible with the `after` cursor */ @@ -39,7 +39,7 @@ export type FetchObjectsOptions = { /** The sorting to be applied to the query. Use `weaviate.sort.*` to create sorting */ sort?: Sorting; /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: I; /** Which metadata of the object to return. If not specified, no metadata is returned. */ returnMetadata?: QueryMetadata; /** @@ -52,7 +52,7 @@ export type FetchObjectsOptions = { }; /** Base options available to all the query methods that involve searching. */ -export type SearchOptions = { +export type SearchOptions = { /** How many objects to return in the query */ limit?: number; /** How many objects to skip in the query. Incompatible with the `after` cursor */ @@ -64,7 +64,7 @@ export type SearchOptions = { /** How to rerank the query results. Requires a configured [reranking](https://weaviate.io/developers/weaviate/concepts/reranking) module. */ rerank?: RerankOptions; /** Whether to include the vector of the object in the response. If using named vectors, pass an array of strings to include only specific vectors. */ - includeVector?: boolean | string[]; + includeVector?: I; /** Which metadata of the object to return. If not specified, no metadata is returned. */ returnMetadata?: QueryMetadata; /** @@ -96,20 +96,20 @@ export type Bm25SearchOptions = { }; /** Base options available in the `query.bm25` method */ -export type BaseBm25Options = SearchOptions & Bm25SearchOptions; +export type BaseBm25Options = SearchOptions & Bm25SearchOptions; /** Options available in the `query.bm25` method when specifying the `groupBy` parameter. */ -export type GroupByBm25Options = BaseBm25Options & { +export type GroupByBm25Options = BaseBm25Options & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; /** Options available in the `query.bm25` method */ -export type Bm25Options = BaseBm25Options | GroupByBm25Options | undefined; +export type Bm25Options = BaseBm25Options | GroupByBm25Options | undefined; /** Options available to the hybrid search type only */ -export type HybridSearchOptions = { - /** The weight of the vector search score. If not specified, the default weight specified by the server is used. */ +export type HybridSearchOptions = { + /** The weight of the vector score. If not specified, the default weight specified by the server is used. */ alpha?: number; /** The type of fusion to apply. If not specified, the default fusion type specified by the server is used. */ fusionType?: 'Ranked' | 'RelativeScore'; @@ -118,14 +118,14 @@ export type HybridSearchOptions = { /** The properties to search in. If not specified, all properties are searched. */ queryProperties?: (PrimitiveKeys | Bm25QueryProperty)[]; /** Specify which vector(s) to search on if using named vectors. */ - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; /** The specific vector to search for or a specific vector subsearch. If not specified, the query is vectorized and used in the similarity search. */ vector?: NearVectorInputType | HybridNearTextSubSearch | HybridNearVectorSubSearch; bm25Operator?: Bm25OperatorOptions; }; /** Base options available in the `query.hybrid` method */ -export type BaseHybridOptions = SearchOptions & HybridSearchOptions; +export type BaseHybridOptions = SearchOptions & HybridSearchOptions; export type HybridSubSearchBase = { certainty?: number; @@ -143,28 +143,28 @@ export type HybridNearVectorSubSearch = HybridSubSearchBase & { }; /** Options available in the `query.hybrid` method when specifying the `groupBy` parameter. */ -export type GroupByHybridOptions = BaseHybridOptions & { +export type GroupByHybridOptions = BaseHybridOptions & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; /** Options available in the `query.hybrid` method */ -export type HybridOptions = BaseHybridOptions | GroupByHybridOptions | undefined; +export type HybridOptions = BaseHybridOptions | GroupByHybridOptions | undefined; -export type NearSearchOptions = { +export type NearSearchOptions = { /** The minimum similarity score to return. Incompatible with the `distance` param. */ certainty?: number; /** The maximum distance to search. Incompatible with the `certainty` param. */ distance?: number; /** Specify which vector to search on if using named vectors. */ - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; }; /** Base options for the near search queries. */ -export type BaseNearOptions = SearchOptions & NearSearchOptions; +export type BaseNearOptions = SearchOptions & NearSearchOptions; /** Options available in the near search queries when specifying the `groupBy` parameter. */ -export type GroupByNearOptions = BaseNearOptions & { +export type GroupByNearOptions = BaseNearOptions & { /** The group by options to apply to the search. */ groupBy: GroupByOptions; }; @@ -177,25 +177,44 @@ export type MoveOptions = { }; /** Base options for the `query.nearText` method. */ -export type BaseNearTextOptions = BaseNearOptions & { +export type BaseNearTextOptions = BaseNearOptions & { moveTo?: MoveOptions; moveAway?: MoveOptions; }; /** Options available in the near text search queries when specifying the `groupBy` parameter. */ -export type GroupByNearTextOptions = BaseNearTextOptions & { +export type GroupByNearTextOptions = BaseNearTextOptions & { groupBy: GroupByOptions; }; /** The type of the media to search for in the `query.nearMedia` method */ export type NearMediaType = 'audio' | 'depth' | 'image' | 'imu' | 'thermal' | 'video'; +export type SingleVectorType = number[]; + +export type MultiVectorType = number[][]; + +/** The allowed types of primitive vectors as stored in Weaviate. + * + * These correspond to 1-dimensional vectors, created by modules named `x2vec-`, and 2-dimensional vectors, created by modules named `x2colbert-`. + */ +export type PrimitiveVectorType = SingleVectorType | MultiVectorType; + +export type ListOfVectors = { + kind: 'listOfVectors'; + dimensionality: '1D' | '2D'; + vectors: V[]; +}; + /** * The vector(s) to search for in the `query/generate.nearVector` and `query/generate.hybrid` methods. One of: - * - a single vector, in which case pass a single number array. - * - multiple named vectors, in which case pass an object of type `Record`. + * - a single 1-dimensional vector, in which case pass a single number array. + * - a single 2-dimensional vector, in which case pas a single array of number arrays. + * - multiple named vectors, in which case pass an object of type `Record`. */ -export type NearVectorInputType = number[] | Record; +export type NearVectorInputType = + | PrimitiveVectorType + | Record | ListOfVectors>; /** * Over which vector spaces to perform the vector search query in the `nearX` search method. One of: @@ -203,9 +222,11 @@ export type NearVectorInputType = number[] | Record = TargetVector | TargetVector[] | MultiTargetVectorJoin; + +export type TargetVector = V extends undefined ? string : keyof V & string; -interface Bm25 { +interface Bm25 { /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -213,11 +234,16 @@ interface Bm25 { * * This overload is for performing a search without the `groupBy` param. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {string} query - The query to search for. - * @param {BaseBm25Options} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseBm25Options} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>>} - The result of the search within the fetched collection. */ - bm25(query: string, opts?: BaseBm25Options): Promise>; + bm25, RV extends ReturnVectors>( + query: string, + opts?: BaseBm25Options + ): Promise>; /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -225,11 +251,16 @@ interface Bm25 { * * This overload is for performing a search with the `groupBy` param. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {string} query - The query to search for. - * @param {GroupByBm25Options} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {GroupByBm25Options} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - bm25(query: string, opts: GroupByBm25Options): Promise>; + bm25, RV extends ReturnVectors>( + query: string, + opts: GroupByBm25Options + ): Promise>; /** * Search for objects in this collection using the keyword-based BM25 algorithm. * @@ -237,14 +268,19 @@ interface Bm25 { * * This overload is for performing a search with a programmatically defined `opts` param. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {string} query - The query to search for. - * @param {Bm25Options} [opts] - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {Bm25Options} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - bm25(query: string, opts?: Bm25Options): QueryReturn; + bm25, RV extends ReturnVectors>( + query: string, + opts?: Bm25Options + ): QueryReturn; } -interface Hybrid { +interface Hybrid { /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -252,11 +288,16 @@ interface Hybrid { * * This overload is for performing a search without the `groupBy` param. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {BaseHybridOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseHybridOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts?: BaseHybridOptions): Promise>; + hybrid, RV extends ReturnVectors>( + query: string, + opts?: BaseHybridOptions + ): Promise>; /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -264,11 +305,16 @@ interface Hybrid { * * This overload is for performing a search with the `groupBy` param. * + * @typeParam I - The vector(s) to include in the response. If using named vectors, pass an array of strings to include only specific vectors. + * @typeParam RV - The vectors(s) to be returned in the response depending on the input in opts.includeVector. * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {GroupByHybridOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {GroupByHybridOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts: GroupByHybridOptions): Promise>; + hybrid, RV extends ReturnVectors>( + query: string, + opts: GroupByHybridOptions + ): Promise>; /** * Search for objects in this collection using the hybrid algorithm blending keyword-based BM25 and vector-based similarity. * @@ -277,13 +323,16 @@ interface Hybrid { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string} query - The query to search for in the BM25 keyword search.. - * @param {HybridOptions} [opts] - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {HybridOptions} [opts] - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - hybrid(query: string, opts?: HybridOptions): QueryReturn; + hybrid, RV extends ReturnVectors>( + query: string, + opts?: HybridOptions + ): QueryReturn; } -interface NearImage { +interface NearImage { /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -294,10 +343,13 @@ interface NearImage { * This overload is for performing a search without the `groupBy` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts?: BaseNearOptions): Promise>; + nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts?: BaseNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -308,10 +360,13 @@ interface NearImage { * This overload is for performing a search with the `groupBy` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts: GroupByNearOptions): Promise>; + nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts: GroupByNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have an image-capable vectorization module installed in order to use this method, @@ -322,13 +377,16 @@ interface NearImage { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | Buffer} image - The image to search on. This can be a base64 string, a file path string, or a buffer. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearImage(image: string | Buffer, opts?: NearOptions): QueryReturn; + nearImage, RV extends ReturnVectors>( + image: string | Buffer, + opts?: NearOptions + ): QueryReturn; } -interface NearMedia { +interface NearMedia { /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -339,14 +397,14 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearMedia( + nearMedia, RV extends ReturnVectors>( media: string | Buffer, type: NearMediaType, - opts?: BaseNearOptions - ): Promise>; + opts?: BaseNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -357,14 +415,14 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearMedia( + nearMedia, RV extends ReturnVectors>( media: string | Buffer, type: NearMediaType, - opts: GroupByNearOptions - ): Promise>; + opts: GroupByNearOptions + ): Promise>; /** * Search for objects by image in this collection using an image-capable vectorization module and vector-based similarity search. * You must have a multi-media-capable vectorization module installed in order to use this method, e.g. `multi2vec-bind` or `multi2vec-palm`. @@ -375,13 +433,17 @@ interface NearMedia { * * @param {string | Buffer} media - The media to search on. This can be a base64 string, a file path string, or a buffer. * @param {NearMediaType} type - The type of media to search for, e.g. 'audio'. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearMedia(media: string | Buffer, type: NearMediaType, opts?: NearOptions): QueryReturn; + nearMedia, RV extends ReturnVectors>( + media: string | Buffer, + type: NearMediaType, + opts?: NearOptions + ): QueryReturn; } -interface NearObject { +interface NearObject { /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -390,10 +452,13 @@ interface NearObject { * This overload is for performing a search without the `groupBy` param. * * @param {string} id - The UUID of the object to search for. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearObject(id: string, opts?: BaseNearOptions): Promise>; + nearObject, RV extends ReturnVectors>( + id: string, + opts?: BaseNearOptions + ): Promise>; /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -402,10 +467,13 @@ interface NearObject { * This overload is for performing a search with the `groupBy` param. * * @param {string} id - The UUID of the object to search for. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearObject(id: string, opts: GroupByNearOptions): Promise>; + nearObject, RV extends ReturnVectors>( + id: string, + opts: GroupByNearOptions + ): Promise>; /** * Search for objects in this collection by another object using a vector-based similarity search. * @@ -414,13 +482,16 @@ interface NearObject { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {number[]} id - The UUID of the object to search for. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearObject(id: string, opts?: NearOptions): QueryReturn; + nearObject, RV extends ReturnVectors>( + id: string, + opts?: NearOptions + ): QueryReturn; } -interface NearText { +interface NearText { /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -431,10 +502,13 @@ interface NearText { * This overload is for performing a search without the `groupBy` param. * * @param {string | string[]} query - The text query to search for. - * @param {BaseNearTextOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearTextOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearText(query: string | string[], opts?: BaseNearTextOptions): Promise>; + nearText, RV extends ReturnVectors>( + query: string | string[], + opts?: BaseNearTextOptions + ): Promise>; /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -445,10 +519,13 @@ interface NearText { * This overload is for performing a search with the `groupBy` param. * * @param {string | string[]} query - The text query to search for. - * @param {GroupByNearTextOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearTextOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearText(query: string | string[], opts: GroupByNearTextOptions): Promise>; + nearText, RV extends ReturnVectors>( + query: string | string[], + opts: GroupByNearTextOptions + ): Promise>; /** * Search for objects in this collection by text using text-capable vectorization module and vector-based similarity search. * You must have a text-capable vectorization module installed in order to use this method, @@ -459,13 +536,16 @@ interface NearText { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {string | string[]} query - The text query to search for. - * @param {NearTextOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearTextOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearText(query: string | string[], opts?: NearTextOptions): QueryReturn; + nearText, RV extends ReturnVectors>( + query: string | string[], + opts?: NearTextOptions + ): QueryReturn; } -interface NearVector { +interface NearVector { /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -474,10 +554,13 @@ interface NearVector { * This overload is for performing a search without the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search on. - * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. - * @returns {Promise>} - The result of the search within the fetched collection. + * @param {BaseNearOptions} [opts] - The available options for the search excluding the `groupBy` param. + * @returns {Promise>} - The result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts?: BaseNearOptions): Promise>; + nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts?: BaseNearOptions + ): Promise>; /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -486,10 +569,13 @@ interface NearVector { * This overload is for performing a search with the `groupBy` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. - * @returns {Promise>} - The group by result of the search within the fetched collection. + * @param {GroupByNearOptions} opts - The available options for the search including the `groupBy` param. + * @returns {Promise>} - The group by result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts: GroupByNearOptions): Promise>; + nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts: GroupByNearOptions + ): Promise>; /** * Search for objects by vector in this collection using a vector-based similarity search. * @@ -498,43 +584,54 @@ interface NearVector { * This overload is for performing a search with a programmatically defined `opts` param. * * @param {NearVectorInputType} vector - The vector(s) to search for. - * @param {NearOptions} [opts] - The available options for the search. - * @returns {QueryReturn} - The result of the search within the fetched collection. + * @param {NearOptions} [opts] - The available options for the search. + * @returns {QueryReturn} - The result of the search within the fetched collection. */ - nearVector(vector: NearVectorInputType, opts?: NearOptions): QueryReturn; + nearVector, RV extends ReturnVectors>( + vector: NearVectorInputType, + opts?: NearOptions + ): QueryReturn; } /** All the available methods on the `.query` namespace. */ -export interface Query - extends Bm25, - Hybrid, - NearImage, - NearMedia, - NearObject, - NearText, - NearVector { +export interface Query + extends Bm25, + Hybrid, + NearImage, + NearMedia, + NearObject, + NearText, + NearVector { /** * Retrieve an object from the server by its UUID. * * @param {string} id - The UUID of the object to retrieve. * @param {FetchObjectByIdOptions} [opts] - The available options for fetching the object. - * @returns {Promise | null>} - The object with the given UUID, or null if it does not exist. + * @returns {Promise | null>} - The object with the given UUID, or null if it does not exist. */ - fetchObjectById: (id: string, opts?: FetchObjectByIdOptions) => Promise | null>; + fetchObjectById: >( + id: string, + opts?: FetchObjectByIdOptions + ) => Promise> | null>; /** * Retrieve objects from the server without searching. * * @param {FetchObjectsOptions} [opts] - The available options for fetching the objects. - * @returns {Promise>} - The objects within the fetched collection. + * @returns {Promise>} - The objects within the fetched collection. */ - fetchObjects: (opts?: FetchObjectsOptions) => Promise>; + fetchObjects: >( + opts?: FetchObjectsOptions + ) => Promise>>; } /** Options available in the `query.nearImage`, `query.nearMedia`, `query.nearObject`, and `query.nearVector` methods */ -export type NearOptions = BaseNearOptions | GroupByNearOptions | undefined; +export type NearOptions = BaseNearOptions | GroupByNearOptions | undefined; /** Options available in the `query.nearText` method */ -export type NearTextOptions = BaseNearTextOptions | GroupByNearTextOptions | undefined; +export type NearTextOptions = + | BaseNearTextOptions + | GroupByNearTextOptions + | undefined; /** The return type of the `query` methods. It is a union of a standard query and a group by query due to function overloading. */ -export type QueryReturn = Promise> | Promise>; +export type QueryReturn = Promise> | Promise>; diff --git a/src/collections/query/utils.ts b/src/collections/query/utils.ts index 0450679d..a46334da 100644 --- a/src/collections/query/utils.ts +++ b/src/collections/query/utils.ts @@ -1,14 +1,46 @@ -import { MultiTargetVectorJoin } from '../index.js'; -import { Bm25OperatorOptions, Bm25OperatorOr, NearVectorInputType, TargetVectorInputType } from './types.js'; +import { MultiTargetVectorJoin, Vectors } from '../index.js'; +import { + Bm25OperatorOptions, + Bm25OperatorOr, + ListOfVectors, + MultiVectorType, + NearVectorInputType, + PrimitiveVectorType, + SingleVectorType, + TargetVectorInputType, +} from './types.js'; export class NearVectorInputGuards { - public static is1DArray(input: NearVectorInputType): input is number[] { + public static is1D(input: NearVectorInputType): input is SingleVectorType { return Array.isArray(input) && input.length > 0 && !Array.isArray(input[0]); } - public static isObject(input: NearVectorInputType): input is Record { + public static is2D(input: NearVectorInputType): input is MultiVectorType { + return Array.isArray(input) && input.length > 0 && Array.isArray(input[0]) && input[0].length > 0; + } + + public static isObject( + input: NearVectorInputType + ): input is Record< + string, + PrimitiveVectorType | ListOfVectors | ListOfVectors + > { return !Array.isArray(input); } + + public static isListOf1D( + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; + return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '1D'; + } + + public static isListOf2D( + input: PrimitiveVectorType | ListOfVectors | ListOfVectors + ): input is ListOfVectors { + const i = input as ListOfVectors; + return !Array.isArray(input) && i.kind === 'listOfVectors' && i.dimensionality == '2D'; + } } export class ArrayInputGuards { @@ -21,16 +53,16 @@ export class ArrayInputGuards { } export class TargetVectorInputGuards { - public static isSingle(input: TargetVectorInputType): input is string { + public static isSingle(input: TargetVectorInputType): input is string { return typeof input === 'string'; } - public static isMulti(input: TargetVectorInputType): input is string[] { + public static isMulti(input: TargetVectorInputType): input is string[] { return Array.isArray(input); } - public static isMultiJoin(input: TargetVectorInputType): input is MultiTargetVectorJoin { - const i = input as MultiTargetVectorJoin; + public static isMultiJoin(input: TargetVectorInputType): input is MultiTargetVectorJoin { + const i = input as MultiTargetVectorJoin; return i.combination !== undefined && i.targetVectors !== undefined; } } diff --git a/src/collections/references/classes.ts b/src/collections/references/classes.ts index 0200b7dc..b762d336 100644 --- a/src/collections/references/classes.ts +++ b/src/collections/references/classes.ts @@ -1,13 +1,19 @@ -import { Properties, ReferenceInput, ReferenceToMultiTarget, WeaviateObject } from '../types/index.js'; +import { + Properties, + ReferenceInput, + ReferenceToMultiTarget, + Vectors, + WeaviateObject, +} from '../types/index.js'; import { Beacon } from './types.js'; import { uuidToBeacon } from './utils.js'; export class ReferenceManager { - public objects: WeaviateObject[]; + public objects: WeaviateObject[]; public targetCollection: string; public uuids?: string[]; - constructor(targetCollection: string, objects?: WeaviateObject[], uuids?: string[]) { + constructor(targetCollection: string, objects?: WeaviateObject[], uuids?: string[]) { this.objects = objects ?? []; this.targetCollection = targetCollection; this.uuids = uuids; diff --git a/src/collections/serialize/index.ts b/src/collections/serialize/index.ts index ddebb8e4..d9df1847 100644 --- a/src/collections/serialize/index.ts +++ b/src/collections/serialize/index.ts @@ -92,8 +92,11 @@ import { ObjectPropertiesValue, TextArray, TextArrayProperties, + Vectors, Vectors as VectorsGrpc, + Vectors_VectorType, } from '../../proto/v1/base.js'; +import { yieldToEventLoop } from '../../utils/yield.js'; import { FilterId } from '../filters/classes.js'; import { FilterValue, Filters } from '../filters/index.js'; import { @@ -130,10 +133,14 @@ import { HybridNearVectorSubSearch, HybridOptions, HybridSearchOptions, + ListOfVectors, + MultiVectorType, NearOptions, NearTextOptions, NearVectorInputType, + PrimitiveVectorType, SearchOptions, + SingleVectorType, TargetVectorInputType, } from '../query/types.js'; import { ArrayInputGuards, NearVectorInputGuards, TargetVectorInputGuards } from '../query/utils.js'; @@ -410,26 +417,27 @@ class Aggregate { }); }; - public static hybrid = ( + public static hybrid = async ( query: string, - opts?: AggregateHybridOptions> - ): AggregateHybridArgs => { + opts?: AggregateHybridOptions, V> + ): Promise => { return { ...Aggregate.common(opts), objectLimit: opts?.objectLimit, - hybrid: Serialize.hybridSearch({ + hybrid: await Serialize.hybridSearch({ query: query, supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, ...opts, }), }; }; - public static nearImage = ( + public static nearImage = ( image: string, - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearImageArgs => { return { ...Aggregate.common(opts), @@ -443,9 +451,9 @@ class Aggregate { }; }; - public static nearObject = ( + public static nearObject = ( id: string, - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearObjectArgs => { return { ...Aggregate.common(opts), @@ -459,9 +467,9 @@ class Aggregate { }; }; - public static nearText = ( + public static nearText = ( query: string | string[], - opts?: AggregateNearOptions> + opts?: AggregateNearOptions, V> ): AggregateNearTextArgs => { return { ...Aggregate.common(opts), @@ -475,18 +483,19 @@ class Aggregate { }; }; - public static nearVector = ( + public static nearVector = async ( vector: NearVectorInputType, - opts?: AggregateNearOptions> - ): AggregateNearVectorArgs => { + opts?: AggregateNearOptions, V> + ): Promise => { return { ...Aggregate.common(opts), objectLimit: opts?.objectLimit, - nearVector: Serialize.nearVectorSearch({ + nearVector: await Serialize.nearVectorSearch({ vector, supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, ...opts, }), }; @@ -550,10 +559,7 @@ class Search { }; }; - private static metadata = ( - includeVector?: boolean | string[], - metadata?: QueryMetadata - ): MetadataRequest => { + private static metadata = (includeVector?: I, metadata?: QueryMetadata): MetadataRequest => { const out: any = { uuid: true, vector: typeof includeVector === 'boolean' ? includeVector : false, @@ -614,7 +620,7 @@ class Search { return args.groupBy !== undefined; }; - private static common = (args?: SearchOptions): BaseSearchArgs => { + private static common = (args?: SearchOptions): BaseSearchArgs => { const out: BaseSearchArgs = { autocut: args?.autoLimit, limit: args?.limit, @@ -632,15 +638,15 @@ class Search { return out; }; - public static bm25 = (query: string, opts?: Bm25Options): SearchBm25Args => { + public static bm25 = (query: string, opts?: Bm25Options): SearchBm25Args => { return { ...Search.common(opts), bm25Search: Serialize.bm25Search({ query, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static fetchObjects = (args?: FetchObjectsOptions): SearchFetchArgs => { + public static fetchObjects = (args?: FetchObjectsOptions): SearchFetchArgs => { return { ...Search.common(args), after: args?.after, @@ -648,7 +654,9 @@ class Search { }; }; - public static fetchObjectById = (args: { id: string } & FetchObjectByIdOptions): SearchFetchArgs => { + public static fetchObjectById = ( + args: { id: string } & FetchObjectByIdOptions + ): SearchFetchArgs => { return Search.common({ filters: new FilterId().equal(args.id), includeVector: args.includeVector, @@ -658,142 +666,146 @@ class Search { }); }; - public static hybrid = ( + public static hybrid = async ( args: { query: string; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; }, - opts?: HybridOptions - ): SearchHybridArgs => { + opts?: HybridOptions + ): Promise => { return { - ...Search.common(opts), - hybridSearch: Serialize.hybridSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + ...Search.common(opts), + hybridSearch: await Serialize.hybridSearch({ ...args, ...opts }), + groupBy: Search.isGroupBy>(opts) + ? Search.groupBy(opts.groupBy) + : undefined, }; }; - public static nearAudio = ( + public static nearAudio = ( args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearAudioArgs => { return { ...Search.common(opts), nearAudio: Serialize.nearAudioSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearDepth = ( + public static nearDepth = ( args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearDepthArgs => { return { ...Search.common(opts), nearDepth: Serialize.nearDepthSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearImage = ( + public static nearImage = ( args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearImageArgs => { return { ...Search.common(opts), nearImage: Serialize.nearImageSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearIMU = ( + public static nearIMU = ( args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearOptions + opts?: NearOptions ): SearchNearIMUArgs => { return { ...Search.common(opts), nearIMU: Serialize.nearIMUSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearObject = ( + public static nearObject = ( args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearObjectArgs => { return { ...Search.common(opts), nearObject: Serialize.nearObjectSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearText = ( + public static nearText = ( args: { query: string | string[]; supportsTargets: boolean; supportsWeightsForTargets: boolean; }, - opts?: NearTextOptions + opts?: NearTextOptions ): SearchNearTextArgs => { return { ...Search.common(opts), nearText: Serialize.nearTextSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearThermal = ( + public static nearThermal = ( args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearThermalArgs => { return { ...Search.common(opts), nearThermal: Serialize.nearThermalSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearVector = ( + public static nearVector = async ( args: { vector: NearVectorInputType; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; }, - opts?: NearOptions - ): SearchNearVectorArgs => { + opts?: NearOptions + ): Promise => { return { ...Search.common(opts), - nearVector: Serialize.nearVectorSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + nearVector: await Serialize.nearVectorSearch({ ...args, ...opts }), + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; - public static nearVideo = ( + public static nearVideo = ( args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean }, - opts?: NearOptions + opts?: NearOptions ): SearchNearVideoArgs => { return { ...Search.common(opts), nearVideo: Serialize.nearVideoSearch({ ...args, ...opts }), - groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, + groupBy: Search.isGroupBy>(opts) ? Search.groupBy(opts.groupBy) : undefined, }; }; } @@ -802,15 +814,15 @@ export class Serialize { static aggregate = Aggregate; static search = Search; - public static isNamedVectors = (opts?: BaseNearOptions): boolean => { + public static isNamedVectors = (opts?: BaseNearOptions): boolean => { return Array.isArray(opts?.includeVector) || opts?.targetVector !== undefined; }; - public static isMultiTarget = (opts?: BaseNearOptions): boolean => { + public static isMultiTarget = (opts?: BaseNearOptions): boolean => { return opts?.targetVector !== undefined && !TargetVectorInputGuards.isSingle(opts.targetVector); }; - public static isMultiWeightPerTarget = (opts?: BaseNearOptions): boolean => { + public static isMultiWeightPerTarget = (opts?: BaseNearOptions): boolean => { return ( opts?.targetVector !== undefined && TargetVectorInputGuards.isMultiJoin(opts.targetVector) && @@ -986,9 +998,14 @@ export class Serialize { }); }; - public static isHybridVectorSearch = ( - vector: BaseHybridOptions['vector'] - ): vector is number[] | Record => { + public static isHybridVectorSearch = ( + vector: BaseHybridOptions['vector'] + ): vector is + | PrimitiveVectorType + | Record< + string, + PrimitiveVectorType | ListOfVectors | ListOfVectors + > => { return ( vector !== undefined && !Serialize.isHybridNearTextSearch(vector) && @@ -996,40 +1013,46 @@ export class Serialize { ); }; - public static isHybridNearTextSearch = ( - vector: BaseHybridOptions['vector'] + public static isHybridNearTextSearch = ( + vector: BaseHybridOptions['vector'] ): vector is HybridNearTextSubSearch => { return (vector as HybridNearTextSubSearch)?.query !== undefined; }; - public static isHybridNearVectorSearch = ( - vector: BaseHybridOptions['vector'] + public static isHybridNearVectorSearch = ( + vector: BaseHybridOptions['vector'] ): vector is HybridNearVectorSubSearch => { return (vector as HybridNearVectorSubSearch)?.vector !== undefined; }; - private static hybridVector = (args: { + private static hybridVector = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; - vector?: BaseHybridOptions['vector']; + supportsVectors: boolean; + vector?: BaseHybridOptions['vector']; }) => { const vector = args.vector; if (Serialize.isHybridVectorSearch(vector)) { - const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'vector', - vector: vector, - }); + const { targets, targetVectors, vectorBytes, vectorPerTarget, vectorForTargets, vectors } = + await Serialize.vectors({ + ...args, + argumentName: 'vector', + vector: vector, + }); return vectorBytes !== undefined ? { vectorBytes, targetVectors, targets } : { targetVectors, targets, - nearVector: NearVector.fromPartial({ - vectorForTargets, - vectorPerTarget, - }), + nearVector: + vectorForTargets != undefined || vectorPerTarget != undefined + ? NearVector.fromPartial({ + vectorForTargets, + vectorPerTarget, + }) + : undefined, + vectors, }; } else if (Serialize.isHybridNearTextSearch(vector)) { const { targetVectors, targets } = Serialize.targetVector(args); @@ -1045,11 +1068,12 @@ export class Serialize { }), }; } else if (Serialize.isHybridNearVectorSearch(vector)) { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'vector', - vector: vector.vector, - }); + const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets, vectors } = + await Serialize.vectors({ + ...args, + argumentName: 'vector', + vector: vector.vector, + }); return { targetVectors, targets, @@ -1059,6 +1083,7 @@ export class Serialize { vectorBytes, vectorPerTarget, vectorForTargets, + vectors, }), }; } else { @@ -1067,14 +1092,15 @@ export class Serialize { } }; - public static hybridSearch = ( + public static hybridSearch = async ( args: { query: string; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; - } & HybridSearchOptions - ): Hybrid => { + supportsVectors: boolean; + } & HybridSearchOptions + ): Promise => { const fusionType = (fusionType?: string): Hybrid_FusionType => { switch (fusionType) { case 'Ranked': @@ -1085,7 +1111,8 @@ export class Serialize { return Hybrid_FusionType.FUSION_TYPE_UNSPECIFIED; } }; - const { targets, targetVectors, vectorBytes, nearText, nearVector } = Serialize.hybridVector(args); + const { targets, targetVectors, vectorBytes, nearText, nearVector, vectors } = + await Serialize.hybridVector(args); return Hybrid.fromPartial({ query: args.query, alpha: args.alpha ? args.alpha : 0.5, @@ -1098,11 +1125,16 @@ export class Serialize { targets, nearText, nearVector, + vectors, }); }; - public static nearAudioSearch = ( - args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearAudioSearch = ( + args: { audio: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V, + I + > ): NearAudioSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearAudioSearch.fromPartial({ @@ -1114,8 +1146,12 @@ export class Serialize { }); }; - public static nearDepthSearch = ( - args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearDepthSearch = ( + args: { depth: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V, + I + > ): NearDepthSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearDepthSearch.fromPartial({ @@ -1127,8 +1163,12 @@ export class Serialize { }); }; - public static nearImageSearch = ( - args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearImageSearch = ( + args: { image: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V, + I + > ): NearImageSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearImageSearch.fromPartial({ @@ -1140,8 +1180,8 @@ export class Serialize { }); }; - public static nearIMUSearch = ( - args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearIMUSearch = ( + args: { imu: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearIMUSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearIMUSearch.fromPartial({ @@ -1153,8 +1193,8 @@ export class Serialize { }); }; - public static nearObjectSearch = ( - args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearObjectSearch = ( + args: { id: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions ): NearObject => { const { targets, targetVectors } = Serialize.targetVector(args); return NearObject.fromPartial({ @@ -1166,11 +1206,11 @@ export class Serialize { }); }; - public static nearTextSearch = (args: { + public static nearTextSearch = (args: { query: string | string[]; supportsTargets: boolean; supportsWeightsForTargets: boolean; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; certainty?: number; distance?: number; moveAway?: { concepts?: string[]; force?: number; objects?: string[] }; @@ -1200,8 +1240,12 @@ export class Serialize { }); }; - public static nearThermalSearch = ( - args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearThermalSearch = ( + args: { thermal: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V, + I + > ): NearThermalSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearThermalSearch.fromPartial({ @@ -1213,38 +1257,71 @@ export class Serialize { }); }; + private static vectorToBuffer = (vector: number[]): ArrayBufferLike => { + return new Float32Array(vector).buffer; + }; + private static vectorToBytes = (vector: number[]): Uint8Array => { - return new Uint8Array(new Float32Array(vector).buffer); + const uint32len = 4; + const dv = new DataView(new ArrayBuffer(vector.length * uint32len)); + vector.forEach((v, i) => dv.setFloat32(i * uint32len, v, true)); + return new Uint8Array(dv.buffer); + }; + + /** + * Convert a 2D array of numbers to a Uint8Array + * + * Defined as an async method so that control can be relinquished back to the event loop on each outer loop for large vectors + */ + private static vectorsToBytes = async (vectors: number[][]): Promise => { + if (vectors.length === 0) { + return new Uint8Array(); + } + if (vectors[0].length === 0) { + return new Uint8Array(); + } + + const uint16Len = 2; + const uint32len = 4; + const dim = vectors[0].length; + + const dv = new DataView(new ArrayBuffer(uint16Len + vectors.length * dim * uint32len)); + dv.setUint16(0, dim, true); + dv.setUint16(uint16Len, vectors.length, true); + await Promise.all( + vectors.map((vector, i) => + yieldToEventLoop().then(() => + vector.forEach((v, j) => dv.setFloat32(uint16Len + i * dim * uint32len + j * uint32len, v, true)) + ) + ) + ); + + return new Uint8Array(dv.buffer); }; - public static nearVectorSearch = (args: { + public static nearVectorSearch = async (args: { vector: NearVectorInputType; supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; certainty?: number; distance?: number; - targetVector?: TargetVectorInputType; - }): NearVector => { - const { targetVectors, targets, vectorBytes, vectorPerTarget, vectorForTargets } = Serialize.vectors({ - ...args, - argumentName: 'nearVector', - }); - return NearVector.fromPartial({ + targetVector?: TargetVectorInputType; + }): Promise => + NearVector.fromPartial({ certainty: args.certainty, distance: args.distance, - targetVectors, - targets, - vectorPerTarget, - vectorBytes, - vectorForTargets, + ...(await Serialize.vectors({ + ...args, + argumentName: 'nearVector', + })), }); - }; - public static targetVector = (args: { + public static targetVector = (args: { supportsTargets: boolean; supportsWeightsForTargets: boolean; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; }): { targets?: Targets; targetVectors?: string[] } => { if (args.targetVector === undefined) { return {}; @@ -1269,20 +1346,22 @@ export class Serialize { } }; - private static vectors = (args: { + static vectors = async (args: { supportsTargets: boolean; supportsVectorsForTargets: boolean; supportsWeightsForTargets: boolean; + supportsVectors: boolean; argumentName: 'nearVector' | 'vector'; - targetVector?: TargetVectorInputType; + targetVector?: TargetVectorInputType; vector?: NearVectorInputType; - }): { + }): Promise<{ targetVectors?: string[]; targets?: Targets; vectorBytes?: Uint8Array; + vectors?: Vectors[]; vectorPerTarget?: Record; vectorForTargets?: VectorForTarget[]; - } => { + }> => { const invalidVectorError = new WeaviateInvalidInputError(`${args.argumentName} argument must be populated and: - an array of numbers (number[]) @@ -1296,38 +1375,16 @@ export class Serialize { if (Object.keys(args.vector).length === 0) { throw invalidVectorError; } - if (args.supportsVectorsForTargets) { - const vectorForTargets: VectorForTarget[] = Object.entries(args.vector) - .map(([target, vector]) => { - return { - target, - vector: vector, - }; - }) - .reduce((acc, { target, vector }) => { - return ArrayInputGuards.is2DArray(vector) - ? acc.concat( - vector.map((v) => ({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] })) - ) - : acc.concat([{ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }]); - }, [] as VectorForTarget[]); - return args.targetVector !== undefined - ? { - ...Serialize.targetVector(args), - vectorForTargets, - } - : { - targetVectors: undefined, - targets: Targets.fromPartial({ - targetVectors: vectorForTargets.map((v) => v.name), - }), - vectorForTargets, - }; - } else { + if (!args.supportsVectorsForTargets) { const vectorPerTarget: Record = {}; Object.entries(args.vector).forEach(([k, v]) => { if (ArrayInputGuards.is2DArray(v)) { - return; + throw new WeaviateUnsupportedFeatureError('Multi-vectors are not supported in Weaviate <1.29.0'); + } + if (NearVectorInputGuards.isListOf1D(v) || NearVectorInputGuards.isListOf2D(v)) { + throw new WeaviateUnsupportedFeatureError( + 'Lists of vectors are not supported in Weaviate <1.29.0' + ); } vectorPerTarget[k] = Serialize.vectorToBytes(v); }); @@ -1352,25 +1409,119 @@ export class Serialize { }; } } - } else { - if (args.vector.length === 0) { - throw invalidVectorError; - } - if (NearVectorInputGuards.is1DArray(args.vector)) { - const { targetVectors, targets } = Serialize.targetVector(args); - const vectorBytes = Serialize.vectorToBytes(args.vector); - return { - targetVectors, - targets, - vectorBytes, + const vectorForTargets: VectorForTarget[] = []; + for (const [target, vector] of Object.entries(args.vector)) { + if (!args.supportsVectors) { + if (NearVectorInputGuards.isListOf2D(vector)) { + throw new WeaviateUnsupportedFeatureError( + 'Lists of multi-vectors are not supported in Weaviate <1.29.0' + ); + } + if (ArrayInputGuards.is2DArray(vector)) { + vector.forEach((v) => + vectorForTargets.push({ name: target, vectorBytes: Serialize.vectorToBytes(v), vectors: [] }) + ); + continue; + } + if (NearVectorInputGuards.isListOf1D(vector)) { + vector.vectors.forEach((v) => + vectorForTargets.push({ + name: target, + vectorBytes: Serialize.vectorToBytes(v), + vectors: [], + }) + ); + continue; + } + vectorForTargets.push({ name: target, vectorBytes: Serialize.vectorToBytes(vector), vectors: [] }); + continue; + } + const vectorForTarget: VectorForTarget = { + name: target, + vectorBytes: new Uint8Array(), + vectors: [], }; + if (NearVectorInputGuards.isListOf1D(vector)) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: await Serialize.vectorsToBytes(vector.vectors), // eslint-disable-line no-await-in-loop + }) + ); + } else if (NearVectorInputGuards.isListOf2D(vector)) { + for (const v of vector.vectors) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, + vectorBytes: await Serialize.vectorsToBytes(v), // eslint-disable-line no-await-in-loop + }) + ); + } + } else if (ArrayInputGuards.is2DArray(vector)) { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, + vectorBytes: await Serialize.vectorsToBytes(vector), // eslint-disable-line no-await-in-loop + }) + ); + } else { + vectorForTarget.vectors.push( + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: Serialize.vectorToBytes(vector), + }) + ); + } + vectorForTargets.push(vectorForTarget); } + return args.targetVector !== undefined + ? { + ...Serialize.targetVector(args), + vectorForTargets, + } + : { + targetVectors: undefined, + targets: Targets.fromPartial({ + targetVectors: vectorForTargets.map((v) => v.name), + }), + vectorForTargets, + }; + } + if (args.vector.length === 0) { throw invalidVectorError; } + if (NearVectorInputGuards.is1D(args.vector)) { + const { targetVectors, targets } = Serialize.targetVector(args); + const vectorBytes = Serialize.vectorToBytes(args.vector); + return args.supportsVectors + ? { + targets, + targetVectors, + vectors: [Vectors.fromPartial({ type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, vectorBytes })], + } + : { + targets, + targetVectors, + vectorBytes, + }; + } + if (NearVectorInputGuards.is2D(args.vector)) { + if (!args.supportsVectors) { + throw new WeaviateUnsupportedFeatureError('Multi-vectors are not supported in Weaviate <1.29.0'); + } + const { targetVectors, targets } = Serialize.targetVector(args); + const vectorBytes = await Serialize.vectorsToBytes(args.vector); + return { + targets, + targetVectors, + vectors: [Vectors.fromPartial({ type: Vectors_VectorType.VECTOR_TYPE_MULTI_FP32, vectorBytes })], + }; + } + throw invalidVectorError; }; - private static targets = ( - targets: MultiTargetVectorJoin, + private static targets = ( + targets: MultiTargetVectorJoin, supportsWeightsForTargets: boolean ): { combination: CombinationMethod; @@ -1403,7 +1554,7 @@ export class Serialize { .map(([target, weight]) => { return { target, - weight, + weight: weight as number | number[], }; }) .reduce((acc, { target, weight }) => { @@ -1439,8 +1590,12 @@ export class Serialize { } }; - public static nearVideoSearch = ( - args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions + public static nearVideoSearch = ( + args: { video: string; supportsTargets: boolean; supportsWeightsForTargets: boolean } & NearOptions< + T, + V, + I + > ): NearVideoSearch => { const { targets, targetVectors } = Serialize.targetVector(args); return NearVideoSearch.fromPartial({ @@ -1853,11 +2008,20 @@ export class Serialize { let vectorBytes: Uint8Array | undefined; let vectors: VectorsGrpc[] | undefined; if (obj.vectors !== undefined && !Array.isArray(obj.vectors)) { - vectors = Object.entries(obj.vectors).map(([k, v]) => - VectorsGrpc.fromPartial({ - vectorBytes: Serialize.vectorToBytes(v), - name: k, - }) + vectors = Object.entries(obj.vectors).flatMap(([k, v]) => + NearVectorInputGuards.is1D(v) + ? [ + VectorsGrpc.fromPartial({ + vectorBytes: Serialize.vectorToBytes(v), + name: k, + }), + ] + : v.map((vv) => + VectorsGrpc.fromPartial({ + vectorBytes: Serialize.vectorToBytes(vv), + name: k, + }) + ) ); } else if (Array.isArray(obj.vectors) && requiresInsertFix) { vectors = [ diff --git a/src/collections/serialize/unit.test.ts b/src/collections/serialize/unit.test.ts index 6d9f9612..e6f12764 100644 --- a/src/collections/serialize/unit.test.ts +++ b/src/collections/serialize/unit.test.ts @@ -12,7 +12,7 @@ import { SearchNearVectorArgs, SearchNearVideoArgs, } from '../../grpc/searcher.js'; -import { Filters, Filters_Operator } from '../../proto/v1/base.js'; +import { Filters, Filters_Operator, Vectors, Vectors_VectorType } from '../../proto/v1/base.js'; import { BM25, CombinationMethod, @@ -143,13 +143,14 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for simple hybrid', () => { - const args = Serialize.search.hybrid( + it('should parse args for simple hybrid <1.29', async () => { + const args = await Serialize.search.hybrid( { query: 'test', supportsTargets: false, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }, { queryProperties: ['name'], @@ -174,13 +175,53 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for multi-vector & multi-target hybrid', () => { - const args = Serialize.search.hybrid( + it('should parse args for simple hybrid >=1.29', async () => { + const args = await Serialize.search.hybrid( { query: 'test', supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: true, + }, + { + queryProperties: ['name'], + alpha: 0.6, + vector: [1, 2, 3], + targetVector: 'title', + fusionType: 'Ranked', + maxVectorDistance: 0.4, + } + ); + expect(args).toEqual({ + hybridSearch: Hybrid.fromPartial({ + query: 'test', + properties: ['name'], + alpha: 0.6, + vectors: [ + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: new Uint8Array(new Float32Array([1, 2, 3]).buffer), + }), + ], + targets: { + targetVectors: ['title'], + }, + fusionType: Hybrid_FusionType.FUSION_TYPE_RANKED, + vectorDistance: 0.4, + }), + metadata: MetadataRequest.fromPartial({ uuid: true }), + }); + }); + + it('should parse args for multi-vector & multi-target hybrid', async () => { + const args = await Serialize.search.hybrid( + { + query: 'test', + supportsTargets: true, + supportsVectorsForTargets: true, + supportsWeightsForTargets: true, + supportsVectors: false, }, { queryProperties: ['name'], @@ -364,12 +405,13 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with single vector', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with single vector <1.29', async () => { + const args = await Serialize.search.nearVector({ vector: [1, 2, 3], supportsTargets: false, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -379,8 +421,29 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with two named vectors and supportsTargets (<1.27.0)', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with single vector >=1.29', async () => { + const args = await Serialize.search.nearVector({ + vector: [1, 2, 3], + supportsTargets: false, + supportsVectorsForTargets: false, + supportsWeightsForTargets: false, + supportsVectors: true, + }); + expect(args).toEqual({ + nearVector: NearVector.fromPartial({ + vectors: [ + Vectors.fromPartial({ + type: Vectors_VectorType.VECTOR_TYPE_SINGLE_FP32, + vectorBytes: new Uint8Array(new Float32Array([1, 2, 3]).buffer), + }), + ], + }), + metadata: MetadataRequest.fromPartial({ uuid: true }), + }); + }); + + it('should parse args for nearVector with two named vectors and supportsTargets (<1.27.0)', async () => { + const args = await Serialize.search.nearVector({ vector: { a: [1, 2, 3], b: [4, 5, 6], @@ -388,6 +451,7 @@ describe('Unit testing of Serialize', () => { supportsTargets: true, supportsVectorsForTargets: false, supportsWeightsForTargets: false, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -401,8 +465,8 @@ describe('Unit testing of Serialize', () => { }); }); - it('should parse args for nearVector with two named vectors and all supports (==1.27.x)', () => { - const args = Serialize.search.nearVector({ + it('should parse args for nearVector with two named vectors and all supports (==1.27.x)', async () => { + const args = await Serialize.search.nearVector({ vector: { a: [ [1, 2, 3], @@ -413,6 +477,7 @@ describe('Unit testing of Serialize', () => { supportsTargets: true, supportsVectorsForTargets: true, supportsWeightsForTargets: true, + supportsVectors: false, }); expect(args).toEqual({ nearVector: NearVector.fromPartial({ @@ -659,7 +724,7 @@ describe('Unit testing of Serialize', () => { }; type Test = { name: string; - targetVector: TargetVectorInputType; + targetVector: TargetVectorInputType; supportsTargets: boolean; supportsWeightsForTargets: boolean; out: Out; @@ -685,7 +750,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().average(['a', 'b']), + targetVector: multiTargetVector().average(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -727,7 +792,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin minimum', - targetVector: multiTargetVector().minimum(['a', 'b']), + targetVector: multiTargetVector().minimum(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -739,7 +804,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().average(['a', 'b']), + targetVector: multiTargetVector().average(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { @@ -781,7 +846,7 @@ describe('Unit testing of Serialize', () => { }, { name: 'should parse MultiTargetJoin sum', - targetVector: multiTargetVector().sum(['a', 'b']), + targetVector: multiTargetVector().sum(['a', 'b']), supportsTargets: true, supportsWeightsForTargets: false, out: { diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index a330002a..4cee21db 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -32,8 +32,9 @@ import { GroupByObject, GroupByResult, WeaviateGenericObject, WeaviateNonGeneric export type GenerativeGenericObject< T, + V, C extends GenerativeConfigRuntime | undefined -> = WeaviateGenericObject & { +> = WeaviateGenericObject & { /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this single object. */ generated?: string; /** Generative data returned from the LLM inference on this object. */ @@ -53,9 +54,13 @@ export type GenerativeNonGenericObject = T extends undefined - ? GenerativeNonGenericObject - : GenerativeGenericObject; +export type GenerativeObject = T extends undefined + ? V extends undefined + ? GenerativeNonGenericObject + : GenerativeGenericObject['properties'], V, C> + : V extends undefined + ? GenerativeGenericObject['vectors'], C> + : GenerativeGenericObject; export type GenerativeSingle = { debug?: GenerativeDebug; @@ -69,26 +74,29 @@ export type GenerativeGrouped = { }; /** The return of a query method in the `collection.generate` namespace. */ -export type GenerativeReturn = { +export type GenerativeReturn = { /** The objects that were found by the query. */ - objects: GenerativeObject[]; + objects: GenerativeObject[]; /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; generative?: GenerativeGrouped; }; -export type GenerativeGroupByResult = GroupByResult & { +export type GenerativeGroupByResult = GroupByResult< + T, + V +> & { /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; generative?: GenerativeSingle; }; /** The return of a query method in the `collection.generate` namespace where the `groupBy` argument was specified. */ -export type GenerativeGroupByReturn = { +export type GenerativeGroupByReturn = { /** The objects that were found by the query. */ - objects: GroupByObject[]; + objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; + groups: Record>; /** @deprecated (use `generative.text` instead) The LLM-generated output applicable to this query as a whole. */ generated?: string; generative?: GenerativeGrouped; @@ -201,9 +209,9 @@ export type GenerativeMetadata = : never : never; -export type GenerateReturn = - | Promise> - | Promise>; +export type GenerateReturn = + | Promise> + | Promise>; export type GenerativeAnthropicConfigRuntime = { baseURL?: string | undefined; diff --git a/src/collections/types/internal.ts b/src/collections/types/internal.ts index 001c24ab..0227d769 100644 --- a/src/collections/types/internal.ts +++ b/src/collections/types/internal.ts @@ -40,6 +40,10 @@ export type QueryReference = T extends undefined ? RefPropertyDefault : RefPr export type NonRefProperty = keyof T | QueryNested; export type NonPrimitiveProperty = RefProperty | QueryNested; +export type QueryVector = V extends undefined ? string : keyof V & string; + +export type IncludeVector = boolean | QueryVector[] | undefined; + export type IsEmptyType = keyof T extends never ? true : false; export type ReferenceInput = diff --git a/src/collections/types/query.ts b/src/collections/types/query.ts index 1e8b3c25..5482e37c 100644 --- a/src/collections/types/query.ts +++ b/src/collections/types/query.ts @@ -1,4 +1,5 @@ import { WeaviateField } from '../index.js'; +import { PrimitiveVectorType } from '../query/types.js'; import { CrossReferenceDefault } from '../references/index.js'; import { ExtractCrossReferenceType, @@ -26,7 +27,7 @@ export type QueryMetadata = 'all' | MetadataKeys | undefined; export type ReturnMetadata = Partial; -export type WeaviateGenericObject = { +export type WeaviateGenericObject = { /** The generic returned properties of the object derived from the type `T`. */ properties: ReturnProperties; /** The returned metadata of the object. */ @@ -36,7 +37,7 @@ export type WeaviateGenericObject = { /** The UUID of the object. */ uuid: string; /** The returned vectors of the object. */ - vectors: Vectors; + vectors: V; }; export type WeaviateNonGenericObject = { @@ -56,45 +57,60 @@ export type ReturnProperties = Pick>; export type ReturnReferences = Pick>; -export type Vectors = Record; +export interface Vectors { + [k: string]: PrimitiveVectorType; +} -export type ReturnVectors = V extends string[] - ? { [Key in V[number]]: number[] } - : Record; +export type ReturnVectors = V extends undefined + ? undefined + : I extends true + ? V + : I extends Array + ? Pick< + V, + { + [Key in keyof V]: Key extends U ? Key : never; + }[keyof V] + > + : never; /** An object belonging to a collection as returned by the methods in the `collection.query` namespace. * * Depending on the generic type `T`, the object will have subfields that map from `T`'s specific type definition. * If not, then the object will be non-generic and have a `properties` field that maps from a generic string to a `WeaviateField`. */ -export type WeaviateObject = T extends undefined // need this instead of Properties to avoid circular type reference - ? WeaviateNonGenericObject - : WeaviateGenericObject; +export type WeaviateObject = T extends undefined // need this instead of Properties to avoid circular type reference + ? V extends undefined + ? WeaviateNonGenericObject + : WeaviateGenericObject + : V extends undefined + ? WeaviateGenericObject + : WeaviateGenericObject; /** The return of a query method in the `collection.query` namespace. */ -export type WeaviateReturn = { +export type WeaviateReturn = { /** The objects that were found by the query. */ - objects: WeaviateObject[]; + objects: WeaviateObject[]; }; -export type GroupByObject = WeaviateObject & { +export type GroupByObject = WeaviateObject & { belongsToGroup: string; }; -export type GroupByResult = { +export type GroupByResult = { name: string; minDistance: number; maxDistance: number; numberOfObjects: number; - objects: WeaviateObject[]; + objects: WeaviateObject[]; }; /** The return of a query method in the `collection.query` namespace where the `groupBy` argument was specified. */ -export type GroupByReturn = { +export type GroupByReturn = { /** The objects that were found by the query. */ - objects: GroupByObject[]; + objects: GroupByObject[]; /** The groups that were created by the query. */ - groups: Record>; + groups: Record>; }; export type GroupByOptions = T extends undefined diff --git a/src/collections/vectors/journey.test.ts b/src/collections/vectors/journey.test.ts new file mode 100644 index 00000000..de2ab644 --- /dev/null +++ b/src/collections/vectors/journey.test.ts @@ -0,0 +1,189 @@ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +/* eslint-disable @typescript-eslint/no-non-null-asserted-optional-chain */ +import weaviate, { + VectorIndexConfigHNSW, + WeaviateClient, + WeaviateField, + WeaviateGenericObject, +} from '../../index.js'; +import { DbVersion } from '../../utils/dbVersion.js'; +import { Collection } from '../collection/index.js'; +import { MultiVectorType, SingleVectorType } from '../query/types.js'; + +const only = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`).isAtLeast(1, 29, 0) + ? describe + : describe.skip; + +only('Testing of the collection.query methods with a collection with multvectors', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionQueryWithMultiVectors'; + + let id1: string; + let id2: string; + + let singleVector: SingleVectorType; + let multiVector: MultiVectorType; + + type MyVectors = { + regular: SingleVectorType; + colbert: MultiVectorType; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await weaviate.connectToLocal(); + collection = client.collections.use(collectionName); + }); + + afterAll(() => client.collections.delete(collectionName)); + + it('should be able to create a collection including multivectors', async () => { + const { hnsw } = weaviate.configure.vectorIndex; + const { multiVector } = weaviate.configure.vectorIndex.multiVector; + collection = await client.collections.create({ + name: collectionName, + vectorizers: [ + weaviate.configure.vectorizer.none({ + name: 'regular', + }), + weaviate.configure.vectorizer.none({ + name: 'colbert', + vectorIndexConfig: hnsw({ + multiVector: multiVector(), + }), + }), + ], + }); + }); + + it('should be able to get the config of the created collection', async () => { + const config = await collection.config.get(); + expect(config.vectorizers.regular).toBeDefined(); + expect(config.vectorizers.colbert).toBeDefined(); + expect((config.vectorizers.regular.indexConfig as VectorIndexConfigHNSW).multiVector).toBeUndefined(); + expect((config.vectorizers.colbert.indexConfig as VectorIndexConfigHNSW).multiVector).toBeDefined(); + }); + + it('should be able to insert one object with multiple multivectors', async () => { + id1 = await collection.data.insert({ + vectors: { + regular: [1, 2, 3, 4], + colbert: [ + [1, 2], + [3, 4], + ], + }, + }); + }); + + it('should be able to get the inserted object with its vectors stated implicitly', async () => { + const obj = await collection.query.fetchObjectById(id1, { includeVector: true }); + const assert = (obj: any): obj is WeaviateGenericObject, MyVectors> => { + expect(obj).not.toBeNull(); + return true; + }; + if (assert(obj)) { + singleVector = obj.vectors.regular; + multiVector = obj.vectors.colbert; + expect(obj.uuid).toBe(id1); + expect(obj.vectors).toBeDefined(); + expect(obj.vectors.regular).toEqual([1, 2, 3, 4]); + expect(obj.vectors.colbert).toEqual([ + [1, 2], + [3, 4], + ]); + } + }); + + it('should be able to get the inserted object with its vectors stated explicitly', async () => { + const obj = await collection.query.fetchObjectById(id1, { includeVector: ['regular', 'colbert'] }); + const assert = (obj: any): obj is WeaviateGenericObject, MyVectors> => { + expect(obj).not.toBeNull(); + return true; + }; + if (assert(obj)) { + singleVector = obj.vectors.regular; + multiVector = obj.vectors.colbert; + expect(obj.uuid).toBe(id1); + expect(obj.vectors).toBeDefined(); + expect(obj.vectors.regular).toEqual([1, 2, 3, 4]); + expect(obj.vectors.colbert).toEqual([ + [1, 2], + [3, 4], + ]); + } + }); + + it('should be able to get the inserted object with one of its vectors', async () => { + const obj = await collection.query.fetchObjectById(id1, { includeVector: ['regular'] }); + singleVector = obj?.vectors.regular!; + expect(obj?.uuid).toBe(id1); + expect(obj?.vectors).toBeDefined(); + expect(obj?.vectors.regular).toEqual([1, 2, 3, 4]); + expect((obj?.vectors as MyVectors).colbert).toBeUndefined(); + }); + + it('should be able to query with hybrid for the inserted object over the single vector space', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: singleVector, + targetVector: ['regular'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with hybrid for the inserted object over the multi vector space', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: multiVector, + targetVector: ['colbert'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with hybrid for the inserted object over both spaces simultaneously', async () => { + const result = await collection.query.hybrid('', { + alpha: 1, + vector: { regular: singleVector, colbert: multiVector }, + targetVector: collection.multiTargetVector.sum(['regular', 'colbert']), + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over the single vector space', async () => { + const result = await collection.query.nearVector(singleVector, { + certainty: 0.5, + targetVector: ['regular'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over the multi vector space', async () => { + const result = await collection.query.nearVector(multiVector, { + certainty: 0.5, + targetVector: ['colbert'], + }); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); + + it('should be able to query with nearVector for the inserted object over both spaces simultaneously', async () => { + const result = await collection.query.nearVector( + { regular: singleVector, colbert: multiVector }, + { targetVector: collection.multiTargetVector.sum(['regular', 'colbert']) } + ); + expect(result.objects.length).toBe(1); + expect(result.objects[0].uuid).toBe(id1); + }); +}); diff --git a/src/collections/vectors/multiTargetVector.ts b/src/collections/vectors/multiTargetVector.ts index f9d2abdf..2ce41fbb 100644 --- a/src/collections/vectors/multiTargetVector.ts +++ b/src/collections/vectors/multiTargetVector.ts @@ -1,3 +1,5 @@ +import { TargetVector } from '../query/types.js'; + /** The allowed combination methods for multi-target vector joins */ export type MultiTargetVectorJoinCombination = | 'sum' @@ -7,55 +9,63 @@ export type MultiTargetVectorJoinCombination = | 'manual-weights'; /** Weights for each target vector in a multi-target vector join */ -export type MultiTargetVectorWeights = Record; +export type MultiTargetVectorWeights = Partial, number | number[]>>; /** A multi-target vector join used when specifying a vector-based query */ -export type MultiTargetVectorJoin = { +export type MultiTargetVectorJoin = { /** The combination method to use for the target vectors */ combination: MultiTargetVectorJoinCombination; /** The target vectors to combine */ - targetVectors: string[]; + targetVectors: TargetVector[]; /** The weights to use for each target vector */ - weights?: MultiTargetVectorWeights; + weights?: MultiTargetVectorWeights; }; -export default () => { +export default () => { return { - sum: (targetVectors: string[]): MultiTargetVectorJoin => { + sum: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'sum' as MultiTargetVectorJoinCombination, targetVectors }; }, - average: (targetVectors: string[]): MultiTargetVectorJoin => { + average: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'average' as MultiTargetVectorJoinCombination, targetVectors }; }, - minimum: (targetVectors: string[]): MultiTargetVectorJoin => { + minimum: []>(targetVectors: T): MultiTargetVectorJoin => { return { combination: 'minimum' as MultiTargetVectorJoinCombination, targetVectors }; }, - relativeScore: (weights: MultiTargetVectorWeights): MultiTargetVectorJoin => { + relativeScore: []>( + weights: MultiTargetVectorWeights + ): MultiTargetVectorJoin => { return { combination: 'relative-score' as MultiTargetVectorJoinCombination, - targetVectors: Object.keys(weights), + targetVectors: Object.keys(weights) as T, weights, }; }, - manualWeights: (weights: MultiTargetVectorWeights): MultiTargetVectorJoin => { + manualWeights: []>( + weights: MultiTargetVectorWeights + ): MultiTargetVectorJoin => { return { combination: 'manual-weights' as MultiTargetVectorJoinCombination, - targetVectors: Object.keys(weights), + targetVectors: Object.keys(weights) as T, weights, }; }, }; }; -export interface MultiTargetVector { +export interface MultiTargetVector { /** Create a multi-target vector join that sums the vector scores over the target vectors */ - sum: (targetVectors: string[]) => MultiTargetVectorJoin; + sum: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that averages the vector scores over the target vectors */ - average: (targetVectors: string[]) => MultiTargetVectorJoin; + average: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that takes the minimum vector score over the target vectors */ - minimum: (targetVectors: string[]) => MultiTargetVectorJoin; + minimum: []>(targetVectors: T) => MultiTargetVectorJoin; /** Create a multi-target vector join that uses relative weights for each target vector */ - relativeScore: (weights: MultiTargetVectorWeights) => MultiTargetVectorJoin; + relativeScore: []>( + weights: MultiTargetVectorWeights + ) => MultiTargetVectorJoin; /** Create a multi-target vector join that uses manual weights for each target vector */ - manualWeights: (weights: MultiTargetVectorWeights) => MultiTargetVectorJoin; + manualWeights: []>( + weights: MultiTargetVectorWeights + ) => MultiTargetVectorJoin; } diff --git a/src/index.ts b/src/index.ts index 775d5f1f..ad8dcba7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,7 +3,7 @@ import { Backup, backup } from './collections/backup/client.js'; import cluster, { Cluster } from './collections/cluster/index.js'; import { configGuards } from './collections/config/index.js'; import { configure, reconfigure } from './collections/configure/index.js'; -import collections, { Collections } from './collections/index.js'; +import collections, { Collections, queryFactory } from './collections/index.js'; import { AccessTokenCredentialsInput, ApiKey, @@ -256,6 +256,7 @@ const app = { configGuards, reconfigure, permissions, + query: queryFactory, }; export default app; diff --git a/src/openapi/schema.ts b/src/openapi/schema.ts index 56a18d78..9c633603 100644 --- a/src/openapi/schema.ts +++ b/src/openapi/schema.ts @@ -41,11 +41,31 @@ export interface paths { }; }; '/replication/replicate': { + /** Begins an asynchronous operation to move or copy a specific shard replica from its current node to a designated target node. The operation involves copying data, synchronizing, and potentially decommissioning the source replica. */ post: operations['replicate']; + delete: operations['deleteAllReplications']; + }; + '/replication/replicate/force-delete': { + /** USE AT OWN RISK! Synchronously force delete operations from the FSM. This will not perform any checks on which state the operation is in so may lead to data corruption or loss. It is recommended to first scale the number of replication engine workers to 0 before calling this endpoint to ensure no operations are in-flight. */ + post: operations['forceDeleteReplications']; }; '/replication/replicate/{id}': { - /** Returns the details of a replication operation for a given shard, identified by the provided replication operation id. */ + /** Fetches the current status and detailed information for a specific replication operation, identified by its unique ID. Optionally includes historical data of the operation's progress if requested. */ get: operations['replicationDetails']; + /** Removes a specific replication operation. If the operation is currently active, it will be cancelled and its resources cleaned up before the operation is deleted. */ + delete: operations['deleteReplication']; + }; + '/replication/replicate/list': { + /** Retrieves a list of currently registered replication operations, optionally filtered by collection, shard, or node ID. */ + get: operations['listReplication']; + }; + '/replication/replicate/{id}/cancel': { + /** Requests the cancellation of an active replication operation identified by its ID. The operation will be stopped, but its record will remain in the 'CANCELLED' state (can't be resumed) and will not be automatically deleted. */ + post: operations['cancelReplication']; + }; + '/replication/sharding-state': { + /** Fetches the current sharding state, including replica locations and statuses, for all collections or a specified collection. If a shard name is provided along with a collection, the state for that specific shard is returned. */ + get: operations['getCollectionShardingState']; }; '/users/own-info': { get: operations['getOwnInfo']; @@ -250,6 +270,9 @@ export interface paths { /** Returns node information for the nodes relevant to the collection. */ get: operations['nodes.get.class']; }; + '/tasks': { + get: operations['distributedTasks.get']; + }; '/classifications/': { /** Trigger a classification based on the specified params. Classifications will run in the background, use GET /classifications/ to retrieve the status of your classification. */ post: operations['classifications.post']; @@ -397,6 +420,19 @@ export interface definitions { */ collection?: string; }; + /** @description resources applicable for replicate actions */ + replicate?: { + /** + * @description string or regex. if a specific collection name, if left empty it will be ALL or * + * @default * + */ + collection?: string; + /** + * @description string or regex. if a specific shard name, if left empty it will be ALL or * + * @default * + */ + shard?: string; + }; /** * @description allowed actions in weaviate. * @enum {string} @@ -425,7 +461,11 @@ export interface definitions { | 'create_tenants' | 'read_tenants' | 'update_tenants' - | 'delete_tenants'; + | 'delete_tenants' + | 'create_replicate' + | 'read_replicate' + | 'update_replicate' + | 'delete_replicate'; }; /** @description list of roles */ RolesListResponse: definitions['Role'][]; @@ -678,57 +718,125 @@ export interface definitions { value?: { [key: string]: unknown }; merge?: definitions['Object']; }; - /** @description Request body to add a replica of given shard of a given collection */ + /** @description Specifies the parameters required to initiate a shard replica movement operation between two nodes for a given collection and shard. This request defines the source and destination node, the collection and type of transfer. */ ReplicationReplicateReplicaRequest: { - /** @description The node containing the replica */ + /** @description The name of the Weaviate node currently hosting the shard replica that needs to be moved or copied. */ sourceNodeName: string; - /** @description The node to add a copy of the replica on */ + /** @description The name of the Weaviate node where the new shard replica will be created as part of the movement or copy operation. */ destinationNodeName: string; - /** @description The collection name holding the shard */ + /** @description The unique identifier (name) of the collection to which the target shard belongs. */ collectionId: string; - /** @description The shard id holding the replica to be copied */ + /** @description The ID of the shard whose replica is to be moved or copied. */ shardId: string; + /** + * @description Specifies the type of replication operation to perform. 'COPY' creates a new replica on the destination node while keeping the source replica. 'MOVE' creates a new replica on the destination node and then removes the source replica upon successful completion. Defaults to 'COPY' if omitted. + * @default COPY + * @enum {string} + */ + transferType?: 'COPY' | 'MOVE'; + }; + /** @description Contains the unique identifier for a successfully initiated asynchronous replica movement operation. This ID can be used to track the progress of the operation. */ + ReplicationReplicateReplicaResponse: { + /** + * Format: uuid + * @description The unique identifier (ID) assigned to the registered replication operation. + */ + id: string; }; - /** @description Request body to disable (soft-delete) a replica of given shard of a given collection */ + /** @description Provides the detailed sharding state for one or more collections, including the distribution of shards and their replicas across the cluster nodes. */ + ReplicationShardingStateResponse: { + shardingState?: definitions['ReplicationShardingState']; + }; + /** @description Specifies the parameters required to mark a specific shard replica as inactive (soft-delete) on a particular node. This action typically prevents the replica from serving requests but does not immediately remove its data. */ ReplicationDisableReplicaRequest: { - /** @description The node containing the replica to be disabled */ + /** @description The name of the Weaviate node hosting the shard replica that is to be disabled. */ nodeName: string; - /** @description The collection name holding the replica to be disabled */ + /** @description The name of the collection to which the shard replica belongs. */ collectionId: string; - /** @description The shard id holding the replica to be disabled */ + /** @description The ID of the shard whose replica is to be disabled. */ shardId: string; }; - /** @description Request body to delete a replica of given shard of a given collection */ + /** @description Specifies the parameters required to permanently delete a specific shard replica from a particular node. This action will remove the replica's data from the node. */ ReplicationDeleteReplicaRequest: { - /** @description The node containing the replica to be deleted */ + /** @description The name of the Weaviate node from which the shard replica will be deleted. */ nodeName: string; - /** @description The collection name holding the replica to be delete */ + /** @description The name of the collection to which the shard replica belongs. */ collectionId: string; - /** @description The shard id holding the replica to be deleted */ + /** @description The ID of the shard whose replica is to be deleted. */ shardId: string; }; - /** @description The current status and details of a replication operation, including information about the resources involved in the replication process. */ + /** @description Represents a shard and lists the nodes that currently host its replicas. */ + ReplicationShardReplicas: { + shard?: string; + replicas?: string[]; + }; + /** @description Details the sharding layout for a specific collection, mapping each shard to its set of replicas across the cluster. */ + ReplicationShardingState: { + /** @description The name of the collection. */ + collection?: string; + /** @description An array detailing each shard within the collection and the nodes hosting its replicas. */ + shards?: definitions['ReplicationShardReplicas'][]; + }; + /** @description Represents the current or historical status of a shard replica involved in a replication operation, including its operational state and any associated errors. */ + ReplicationReplicateDetailsReplicaStatus: { + /** + * @description The current operational state of the replica during the replication process. + * @enum {string} + */ + state?: 'REGISTERED' | 'HYDRATING' | 'FINALIZING' | 'DEHYDRATING' | 'READY' | 'CANCELLED'; + /** @description A list of error messages encountered by this replica during the replication operation, if any. */ + errors?: string[]; + }; + /** @description Provides a comprehensive overview of a specific replication operation, detailing its unique ID, the involved collection, shard, source and target nodes, transfer type, current status, and optionally, its status history. */ ReplicationReplicateDetailsReplicaResponse: { - /** @description The unique id of the replication operation. */ + /** + * Format: uuid + * @description The unique identifier (ID) of this specific replication operation. + */ id: string; - /** @description The id of the shard to collect replication details for. */ + /** @description The identifier of the shard involved in this replication operation. */ shardId: string; - /** @description The name of the collection holding data being replicated. */ + /** @description The name of the collection to which the shard being replicated belongs. */ collection: string; - /** @description The id of the node where the source replica is allocated. */ + /** @description The identifier of the node from which the replica is being moved or copied (the source node). */ sourceNodeId: string; - /** @description The id of the node where the target replica is allocated. */ + /** @description The identifier of the node to which the replica is being moved or copied (the destination node). */ targetNodeId: string; /** - * @description The current status of the replication operation, indicating the replication phase the operation is in. + * @description Indicates whether the operation is a 'COPY' (source replica remains) or a 'MOVE' (source replica is removed after successful transfer). * @enum {string} */ - status: - | 'READY' - | 'INDEXING' - | 'REPLICATION_FINALIZING' - | 'REPLICATION_HYDRATING' - | 'REPLICATION_DEHYDRATING'; + transferType: 'COPY' | 'MOVE'; + /** @description An object detailing the current operational state of the replica movement and any errors encountered. */ + status: definitions['ReplicationReplicateDetailsReplicaStatus']; + /** @description An array detailing the historical sequence of statuses the replication operation has transitioned through, if requested and available. */ + statusHistory?: definitions['ReplicationReplicateDetailsReplicaStatus'][]; + }; + /** @description Specifies the parameters available when force deleting replication operations. */ + ReplicationReplicateForceDeleteRequest: { + /** + * Format: uuid + * @description The unique identifier (ID) of the replication operation to be forcefully deleted. + */ + id?: string; + /** @description The name of the collection to which the shard being replicated belongs. */ + collection?: string; + /** @description The identifier of the shard involved in the replication operations. */ + shard?: string; + /** @description The name of the target node where the replication operations are registered. */ + node?: string; + /** + * @description If true, the operation will not actually delete anything but will return the expected outcome of the deletion. + * @default false + */ + dryRun?: boolean; + }; + /** @description Provides the UUIDs that were successfully force deleted as part of the replication operation. If dryRun is true, this will return the expected outcome without actually deleting anything. */ + ReplicationReplicateForceDeleteResponse: { + /** @description The unique identifiers (IDs) of the replication operations that were forcefully deleted. */ + deleted?: string[]; + /** @description Indicates whether the operation was a dry run (true) or an actual deletion (false). */ + dryRun?: boolean; }; /** @description A single peer in the network. */ PeerUpdate: { @@ -1088,6 +1196,33 @@ export interface definitions { vectorQueueLength?: number; /** @description The load status of the shard. */ loaded?: boolean; + /** @description The status of the async replication. */ + asyncReplicationStatus?: definitions['AsyncReplicationStatus'][]; + /** + * Format: int64 + * @description Number of replicas for the shard. + */ + numberOfReplicas?: unknown; + /** + * Format: int64 + * @description Minimum number of replicas for the shard. + */ + replicationFactor?: unknown; + }; + /** @description The status of the async replication. */ + AsyncReplicationStatus: { + /** + * Format: uint64 + * @description The number of objects propagated in the most recent iteration. + */ + objectsPropagated?: number; + /** + * Format: int64 + * @description The start time of the most recent iteration. + */ + startDiffTimeUnixMillis?: number; + /** @description The target node of the replication, if set, otherwise empty. */ + targetNode?: string; }; /** @description The definition of a backup node status response body */ NodeStatus: { @@ -1114,6 +1249,33 @@ export interface definitions { NodesStatusResponse: { nodes?: definitions['NodeStatus'][]; }; + /** @description Distributed task metadata. */ + DistributedTask: { + /** @description The ID of the task. */ + id?: string; + /** @description The version of the task. */ + version?: number; + /** @description The status of the task. */ + status?: string; + /** + * Format: date-time + * @description The time when the task was created. + */ + startedAt?: string; + /** + * Format: date-time + * @description The time when the task was finished. + */ + finishedAt?: string; + /** @description The nodes that finished the task. */ + finishedNodes?: string[]; + /** @description The high level reason why the task failed. */ + error?: string; + /** @description The payload of the task. */ + payload?: { [key: string]: unknown }; + }; + /** @description Active distributed tasks by namespace. */ + DistributedTasks: { [key: string]: definitions['DistributedTask'][] }; /** @description The definition of Raft statistics. */ RaftStatistics: { appliedIndex?: string; @@ -1670,11 +1832,6 @@ export interface definitions { | 'FREEZING' | 'UNFREEZING'; }; - /** @description attributes representing a single tenant response within weaviate */ - TenantResponse: definitions['Tenant'] & { - /** @description The list of nodes that owns that tenant data. */ - belongsToNodes?: string[]; - }; } export interface parameters { @@ -1740,15 +1897,46 @@ export interface operations { 503: unknown; }; }; + /** Begins an asynchronous operation to move or copy a specific shard replica from its current node to a designated target node. The operation involves copying data, synchronizing, and potentially decommissioning the source replica. */ replicate: { parameters: { body: { body: definitions['ReplicationReplicateReplicaRequest']; }; }; + responses: { + /** Replication operation registered successfully. ID of the operation is returned. */ + 200: { + schema: definitions['ReplicationReplicateReplicaResponse']; + }; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; + }; + }; + deleteAllReplications: { responses: { /** Replication operation registered successfully */ - 200: unknown; + 204: never; /** Malformed request. */ 400: { schema: definitions['ErrorResponse']; @@ -1767,22 +1955,140 @@ export interface operations { 500: { schema: definitions['ErrorResponse']; }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; }; }; - /** Returns the details of a replication operation for a given shard, identified by the provided replication operation id. */ + /** USE AT OWN RISK! Synchronously force delete operations from the FSM. This will not perform any checks on which state the operation is in so may lead to data corruption or loss. It is recommended to first scale the number of replication engine workers to 0 before calling this endpoint to ensure no operations are in-flight. */ + forceDeleteReplications: { + parameters: { + body: { + body?: definitions['ReplicationReplicateForceDeleteRequest']; + }; + }; + responses: { + /** Replication operations force deleted successfully. */ + 200: { + schema: definitions['ReplicationReplicateForceDeleteResponse']; + }; + /** Malformed request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; + /** Fetches the current status and detailed information for a specific replication operation, identified by its unique ID. Optionally includes historical data of the operation's progress if requested. */ replicationDetails: { parameters: { path: { - /** The replication operation id to get details for. */ + /** The ID of the replication operation to get details for. */ id: string; }; + query: { + /** Whether to include the history of the replication operation. */ + includeHistory?: boolean; + }; }; responses: { /** The details of the replication operation. */ 200: { schema: definitions['ReplicationReplicateDetailsReplicaResponse']; }; - /** Malformed request. */ + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden. */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Shard replica operation not found. */ + 404: unknown; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; + }; + }; + /** Removes a specific replication operation. If the operation is currently active, it will be cancelled and its resources cleaned up before the operation is deleted. */ + deleteReplication: { + parameters: { + path: { + /** The ID of the replication operation to delete. */ + id: string; + }; + }; + responses: { + /** Successfully deleted. */ + 204: never; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden. */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Shard replica operation not found. */ + 404: unknown; + /** The operation is not in a deletable state, e.g. it is a MOVE op in the DEHYDRATING state. */ + 409: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; + }; + }; + /** Retrieves a list of currently registered replication operations, optionally filtered by collection, shard, or node ID. */ + listReplication: { + parameters: { + query: { + /** The ID of the target node to get details for. */ + nodeId?: string; + /** The name of the collection to get details for. */ + collection?: string; + /** The shard to get details for. */ + shard?: string; + /** Whether to include the history of the replication operation. */ + includeHistory?: boolean; + }; + }; + responses: { + /** The details of the replication operations. */ + 200: { + schema: definitions['ReplicationReplicateDetailsReplicaResponse'][]; + }; + /** Bad request. */ 400: { schema: definitions['ErrorResponse']; }; @@ -1792,12 +2098,94 @@ export interface operations { 403: { schema: definitions['ErrorResponse']; }; - /** Shard replica operation not found */ + /** Shard replica operation not found. */ + 404: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; + }; + }; + /** Requests the cancellation of an active replication operation identified by its ID. The operation will be stopped, but its record will remain in the 'CANCELLED' state (can't be resumed) and will not be automatically deleted. */ + cancelReplication: { + parameters: { + path: { + /** The ID of the replication operation to cancel. */ + id: string; + }; + }; + responses: { + /** Successfully cancelled. */ + 204: never; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Shard replica operation not found. */ 404: unknown; + /** The operation is not in a cancellable state, e.g. it is READY or is a MOVE op in the DEHYDRATING state. */ + 409: { + schema: definitions['ErrorResponse']; + }; + /** Request body is well-formed (i.e., syntactically correct), but semantically erroneous. */ + 422: { + schema: definitions['ErrorResponse']; + }; /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ 500: { schema: definitions['ErrorResponse']; }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; + }; + }; + /** Fetches the current sharding state, including replica locations and statuses, for all collections or a specified collection. If a shard name is provided along with a collection, the state for that specific shard is returned. */ + getCollectionShardingState: { + parameters: { + query: { + /** The collection name to get the sharding state for. */ + collection?: string; + /** The shard to get the sharding state for. */ + shard?: string; + }; + }; + responses: { + /** Successfully retrieved sharding state. */ + 200: { + schema: definitions['ReplicationShardingStateResponse']; + }; + /** Bad request. */ + 400: { + schema: definitions['ErrorResponse']; + }; + /** Unauthorized or invalid credentials. */ + 401: unknown; + /** Forbidden */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** Collection or shard not found. */ + 404: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; }; }; getOwnInfo: { @@ -1812,6 +2200,10 @@ export interface operations { 500: { schema: definitions['ErrorResponse']; }; + /** Replica movement operations are disabled. */ + 501: { + schema: definitions['ErrorResponse']; + }; }; }; listAllUsers: { @@ -3858,7 +4250,7 @@ export interface operations { responses: { /** load the tenant given the specified class */ 200: { - schema: definitions['TenantResponse']; + schema: definitions['Tenant']; }; /** Unauthorized or invalid credentials. */ 401: unknown; @@ -4186,6 +4578,7 @@ export interface operations { className: string; }; query: { + shardName?: string; /** Controls the verbosity of the output, possible values are: "minimal", "verbose". Defaults to "minimal". */ output?: parameters['CommonOutputVerbosityParameterQuery']; }; @@ -4215,6 +4608,22 @@ export interface operations { }; }; }; + 'distributedTasks.get': { + responses: { + /** Distributed tasks successfully returned */ + 200: { + schema: definitions['DistributedTasks']; + }; + /** Unauthorized or invalid credentials. */ + 403: { + schema: definitions['ErrorResponse']; + }; + /** An error has occurred while trying to fulfill the request. Most likely the ErrorResponse will contain more information about the error. */ + 500: { + schema: definitions['ErrorResponse']; + }; + }; + }; /** Trigger a classification based on the specified params. Classifications will run in the background, use GET /classifications/ to retrieve the status of your classification. */ 'classifications.post': { parameters: { diff --git a/src/roles/integration.test.ts b/src/roles/integration.test.ts index 52540d63..40eafd2e 100644 --- a/src/roles/integration.test.ts +++ b/src/roles/integration.test.ts @@ -279,11 +279,7 @@ const testCases: TestCase[] = [ }, ]; -requireAtLeast( - 1, - 29, - 0 -)('Integration testing of the roles namespace', () => { +requireAtLeast(1, 29, 0)(describe)('Integration testing of the roles namespace', () => { let client: WeaviateClient; beforeAll(async () => { @@ -317,11 +313,7 @@ requireAtLeast( expect(exists).toBeFalsy(); }); - requireAtLeast( - 1, - 30, - 0 - )('namespaced users', () => { + requireAtLeast(1, 30, 0)(describe)('namespaced users', () => { it('retrieves assigned users with/without namespace', async () => { await client.roles.create('landlord', { collection: 'Buildings', diff --git a/src/users/integration.test.ts b/src/users/integration.test.ts index d148542a..1b4e28bb 100644 --- a/src/users/integration.test.ts +++ b/src/users/integration.test.ts @@ -3,11 +3,7 @@ import { requireAtLeast } from '../../test/version.js'; import { WeaviateUserTypeDB } from '../openapi/types.js'; import { GetUserOptions, UserDB } from './types.js'; -requireAtLeast( - 1, - 29, - 0 -)('Integration testing of the users namespace', () => { +requireAtLeast(1, 29, 0)(describe)('Integration testing of the users namespace', () => { const makeClient = (key: string) => weaviate.connectToLocal({ port: 8091, @@ -62,11 +58,7 @@ requireAtLeast( expect(roles.test).toBeUndefined(); }); - requireAtLeast( - 1, - 30, - 0 - )('dynamic user management', () => { + requireAtLeast(1, 30, 0)(describe)('dynamic user management', () => { /** List dynamic DB users. */ const listDBUsers = (c: WeaviateClient, opts?: GetUserOptions) => c.users.db.listAll(opts).then((all) => all.filter((u) => u.userType == 'db_user')); @@ -172,11 +164,7 @@ requireAtLeast( expect(roles.Permissioner.nodesPermissions).toHaveLength(1); }); - requireAtLeast( - 1, - 30, - 1 - )('additional DUM features', () => { + requireAtLeast(1, 30, 1)(describe)('additional DUM features', () => { it('should be able to fetch additional user info', async () => { const admin = await makeClient('admin-key'); const timKey = await admin.users.db.create('timely-tim'); diff --git a/src/utils/dbVersion.ts b/src/utils/dbVersion.ts index a945e414..d48901bc 100644 --- a/src/utils/dbVersion.ts +++ b/src/utils/dbVersion.ts @@ -227,6 +227,15 @@ export class DbVersionSupport { }); }; + supportsVectorsFieldInGRPC = () => { + return this.dbVersionProvider.getVersion().then((version) => { + return { + version: version, + supports: version.isAtLeast(1, 29, 0), + message: undefined, + }; + }); + }; supportsSingleGrouped = () => this.dbVersionProvider.getVersion().then((version) => ({ version, diff --git a/src/utils/yield.ts b/src/utils/yield.ts new file mode 100644 index 00000000..cf3a911c --- /dev/null +++ b/src/utils/yield.ts @@ -0,0 +1 @@ +export const yieldToEventLoop = () => new Promise((resolve) => setTimeout(resolve, 0)); diff --git a/test/version.ts b/test/version.ts index b34118ef..14cc1076 100644 --- a/test/version.ts +++ b/test/version.ts @@ -3,5 +3,7 @@ import { DbVersion } from '../src/utils/dbVersion'; const version = DbVersion.fromString(`v${process.env.WEAVIATE_VERSION!}`); /** Run the suite / test only for Weaviate version above this. */ -export const requireAtLeast = (...semver: [...Parameters]) => - version.isAtLeast(...semver) ? describe : describe.skip; +export const requireAtLeast = + (...semver: [...Parameters]) => + (type: jest.Describe | jest.It) => + version.isAtLeast(...semver) ? type : type.skip;