Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(firestore): add support for vector query #13559

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,34 @@ public void aggregateQuery(
});
}

@Override
public void findNearest(@NonNull GeneratedAndroidFirebaseFirestore.FirestorePigeonFirebaseApp app, @NonNull String path, @NonNull Boolean isCollectionGroup, @NonNull GeneratedAndroidFirebaseFirestore.PigeonQueryParameters parameters, @NonNull List<Double> queryVector, @NonNull GeneratedAndroidFirebaseFirestore.VectorSource source, @NonNull Long limit, @NonNull GeneratedAndroidFirebaseFirestore.VectorQueryOptions queryOptions, @NonNull GeneratedAndroidFirebaseFirestore.DistanceMeasure distanceMeasure, @NonNull GeneratedAndroidFirebaseFirestore.Result<GeneratedAndroidFirebaseFirestore.PigeonQuerySnapshot> result) {
Query query = PigeonParser.parseQuery(getFirestoreFromPigeon(app), path, isCollectionGroup, parameters);

if (query == null) {
result.error(
new GeneratedAndroidFirebaseFirestore.FlutterError(
"invalid_query",
"An error occurred while parsing query arguments, see native logs for more information. Please report this issue.",
null));
return;
}

cachedThreadPool.execute(
() -> {
try {
QuerySnapshot querySnapshot = Tasks.await(query.findNearest(queryVector, PigeonParser.parseVectorSource(source), limit, PigeonParser.parseVectorQueryOptions(queryOptions), PigeonParser.parseDistanceMeasure(distanceMeasure)));

result.success(
PigeonParser.toPigeonQuerySnapshot(
querySnapshot,
DocumentSnapshot.ServerTimestampBehavior.NONE));
} catch (Exception e) {
ExceptionConverter.sendErrorToFlutter(result, e);
}
});
}

@Override
public void writeBatchCommit(
@NonNull GeneratedAndroidFirebaseFirestore.FirestorePigeonFirebaseApp app,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,38 @@ private Source(final int index) {
}
}

/** An enumeration of firestore source types. */
public enum VectorSource {
/**
* Causes Firestore to avoid the cache, generating an error if the server cannot be reached.
* Note that the cache will still be updated if the server request succeeds. Also note that
* latency-compensation still takes effect, so any pending write operations will be visible in
* the returned data (merged into the server-provided data).
*/
SERVER(0);

final int index;

private VectorSource(final int index) {
this.index = index;
}
}

public enum DistanceMeasure {
/** The cosine similarity measure. */
COSINE(0),
/** The euclidean distance measure. */
EUCLIDEAN(1),
/** The dot product distance measure. */
DOT_PRODUCT(2);

final int index;

private DistanceMeasure(final int index) {
this.index = index;
}
}

/**
* The listener retrieves data and listens to updates from the local Firestore cache only. If the
* cache is empty, an empty snapshot will be returned. Snapshot events will be triggered on cache
Expand Down Expand Up @@ -856,6 +888,79 @@ public ArrayList<Object> toList() {
}
}

/** Generated class from Pigeon that represents data sent in messages. */
public static final class VectorQueryOptions {
private @NonNull String distanceResultField;

public @NonNull String getDistanceResultField() {
return distanceResultField;
}

public void setDistanceResultField(@NonNull String setterArg) {
if (setterArg == null) {
throw new IllegalStateException("Nonnull field \"distanceResultField\" is null.");
}
this.distanceResultField = setterArg;
}

private @NonNull Double distanceThreshold;

public @NonNull Double getDistanceThreshold() {
return distanceThreshold;
}

public void setDistanceThreshold(@NonNull Double setterArg) {
if (setterArg == null) {
throw new IllegalStateException("Nonnull field \"distanceThreshold\" is null.");
}
this.distanceThreshold = setterArg;
}

/** Constructor is non-public to enforce null safety; use Builder. */
VectorQueryOptions() {}

public static final class Builder {

private @Nullable String distanceResultField;

public @NonNull Builder setDistanceResultField(@NonNull String setterArg) {
this.distanceResultField = setterArg;
return this;
}

private @Nullable Double distanceThreshold;

public @NonNull Builder setDistanceThreshold(@NonNull Double setterArg) {
this.distanceThreshold = setterArg;
return this;
}

public @NonNull VectorQueryOptions build() {
VectorQueryOptions pigeonReturn = new VectorQueryOptions();
pigeonReturn.setDistanceResultField(distanceResultField);
pigeonReturn.setDistanceThreshold(distanceThreshold);
return pigeonReturn;
}
}

@NonNull
public ArrayList<Object> toList() {
ArrayList<Object> toListResult = new ArrayList<Object>(2);
toListResult.add(distanceResultField);
toListResult.add(distanceThreshold);
return toListResult;
}

static @NonNull VectorQueryOptions fromList(@NonNull ArrayList<Object> list) {
VectorQueryOptions pigeonResult = new VectorQueryOptions();
Object distanceResultField = list.get(0);
pigeonResult.setDistanceResultField((String) distanceResultField);
Object distanceThreshold = list.get(1);
pigeonResult.setDistanceThreshold((Double) distanceThreshold);
return pigeonResult;
}
}

/** Generated class from Pigeon that represents data sent in messages. */
public static final class PigeonGetOptions {
private @NonNull Source source;
Expand Down Expand Up @@ -1667,6 +1772,8 @@ protected Object readValueOfType(byte type, @NonNull ByteBuffer buffer) {
return PigeonSnapshotMetadata.fromList((ArrayList<Object>) readValue(buffer));
case (byte) 140:
return PigeonTransactionCommand.fromList((ArrayList<Object>) readValue(buffer));
case (byte) 141:
return VectorQueryOptions.fromList((ArrayList<Object>) readValue(buffer));
default:
return super.readValueOfType(type, buffer);
}
Expand Down Expand Up @@ -1713,6 +1820,9 @@ protected void writeValue(@NonNull ByteArrayOutputStream stream, Object value) {
} else if (value instanceof PigeonTransactionCommand) {
stream.write(140);
writeValue(stream, ((PigeonTransactionCommand) value).toList());
} else if (value instanceof VectorQueryOptions) {
stream.write(141);
writeValue(stream, ((VectorQueryOptions) value).toList());
} else {
super.writeValue(stream, value);
}
Expand Down Expand Up @@ -1809,6 +1919,18 @@ void aggregateQuery(
@NonNull Boolean isCollectionGroup,
@NonNull Result<List<AggregateQueryResponse>> result);

void findNearest(
@NonNull FirestorePigeonFirebaseApp app,
@NonNull String path,
@NonNull Boolean isCollectionGroup,
@NonNull PigeonQueryParameters parameters,
@NonNull List<Double> queryVector,
@NonNull VectorSource source,
@NonNull Long limit,
@NonNull VectorQueryOptions queryOptions,
@NonNull DistanceMeasure distanceMeasure,
@NonNull Result<PigeonQuerySnapshot> result);

void writeBatchCommit(
@NonNull FirestorePigeonFirebaseApp app,
@NonNull List<PigeonTransactionCommand> writes,
Expand Down Expand Up @@ -2478,6 +2600,55 @@ public void error(Throwable error) {
channel.setMessageHandler(null);
}
}
{
BasicMessageChannel<Object> channel =
new BasicMessageChannel<>(
binaryMessenger,
"dev.flutter.pigeon.cloud_firestore_platform_interface.FirebaseFirestoreHostApi.findNearest",
getCodec());
if (api != null) {
channel.setMessageHandler(
(message, reply) -> {
ArrayList<Object> wrapped = new ArrayList<Object>();
ArrayList<Object> args = (ArrayList<Object>) message;
FirestorePigeonFirebaseApp appArg = (FirestorePigeonFirebaseApp) args.get(0);
String pathArg = (String) args.get(1);
Boolean isCollectionGroupArg = (Boolean) args.get(2);
PigeonQueryParameters parametersArg = (PigeonQueryParameters) args.get(3);
List<Double> queryVectorArg = (List<Double>) args.get(4);
VectorSource sourceArg = VectorSource.values()[(int) args.get(5)];
Number limitArg = (Number) args.get(6);
VectorQueryOptions queryOptionsArg = (VectorQueryOptions) args.get(7);
DistanceMeasure distanceMeasureArg = DistanceMeasure.values()[(int) args.get(8)];
Result<PigeonQuerySnapshot> resultCallback =
new Result<PigeonQuerySnapshot>() {
public void success(PigeonQuerySnapshot result) {
wrapped.add(0, result);
reply.reply(wrapped);
}

public void error(Throwable error) {
ArrayList<Object> wrappedError = wrapError(error);
reply.reply(wrappedError);
}
};

api.findNearest(
appArg,
pathArg,
isCollectionGroupArg,
parametersArg,
queryVectorArg,
sourceArg,
(limitArg == null) ? null : limitArg.longValue(),
queryOptionsArg,
distanceMeasureArg,
resultCallback);
});
} else {
channel.setMessageHandler(null);
}
}
{
BasicMessageChannel<Object> channel =
new BasicMessageChannel<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ - (instancetype)initWithValue:(Source)value {
}
@end

/// An enumeration of firestore source types.
@implementation VectorSourceBox
- (instancetype)initWithValue:(VectorSource)value {
self = [super init];
if (self) {
_value = value;
}
return self;
}
@end

@implementation DistanceMeasureBox
- (instancetype)initWithValue:(DistanceMeasure)value {
self = [super init];
if (self) {
_value = value;
}
return self;
}
@end

/// The listener retrieves data and listens to updates from the local Firestore cache only.
/// If the cache is empty, an empty snapshot will be returned.
/// Snapshot events will be triggered on cache updates, like local mutations or load bundles.
Expand Down Expand Up @@ -168,6 +189,12 @@ + (nullable PigeonQuerySnapshot *)nullableFromList:(NSArray *)list;
- (NSArray *)toList;
@end

@interface VectorQueryOptions ()
+ (VectorQueryOptions *)fromList:(NSArray *)list;
+ (nullable VectorQueryOptions *)nullableFromList:(NSArray *)list;
- (NSArray *)toList;
@end

@interface PigeonGetOptions ()
+ (PigeonGetOptions *)fromList:(NSArray *)list;
+ (nullable PigeonGetOptions *)nullableFromList:(NSArray *)list;
Expand Down Expand Up @@ -410,6 +437,33 @@ - (NSArray *)toList {
}
@end

@implementation VectorQueryOptions
+ (instancetype)makeWithDistanceResultField:(NSString *)distanceResultField
distanceThreshold:(NSNumber *)distanceThreshold {
VectorQueryOptions *pigeonResult = [[VectorQueryOptions alloc] init];
pigeonResult.distanceResultField = distanceResultField;
pigeonResult.distanceThreshold = distanceThreshold;
return pigeonResult;
}
+ (VectorQueryOptions *)fromList:(NSArray *)list {
VectorQueryOptions *pigeonResult = [[VectorQueryOptions alloc] init];
pigeonResult.distanceResultField = GetNullableObjectAtIndex(list, 0);
NSAssert(pigeonResult.distanceResultField != nil, @"");
pigeonResult.distanceThreshold = GetNullableObjectAtIndex(list, 1);
NSAssert(pigeonResult.distanceThreshold != nil, @"");
return pigeonResult;
}
+ (nullable VectorQueryOptions *)nullableFromList:(NSArray *)list {
return (list) ? [VectorQueryOptions fromList:list] : nil;
}
- (NSArray *)toList {
return @[
(self.distanceResultField ?: [NSNull null]),
(self.distanceThreshold ?: [NSNull null]),
];
}
@end

@implementation PigeonGetOptions
+ (instancetype)makeWithSource:(Source)source
serverTimestampBehavior:(ServerTimestampBehavior)serverTimestampBehavior {
Expand Down Expand Up @@ -680,6 +734,8 @@ - (nullable id)readValueOfType:(UInt8)type {
return [PigeonSnapshotMetadata fromList:[self readValue]];
case 140:
return [PigeonTransactionCommand fromList:[self readValue]];
case 141:
return [VectorQueryOptions fromList:[self readValue]];
default:
return [super readValueOfType:type];
}
Expand Down Expand Up @@ -729,6 +785,9 @@ - (void)writeValue:(id)value {
} else if ([value isKindOfClass:[PigeonTransactionCommand class]]) {
[self writeByte:140];
[self writeValue:[value toList]];
} else if ([value isKindOfClass:[VectorQueryOptions class]]) {
[self writeByte:141];
[self writeValue:[value toList]];
} else {
[super writeValue:value];
}
Expand Down Expand Up @@ -1255,6 +1314,50 @@ void FirebaseFirestoreHostApiSetup(id<FlutterBinaryMessenger> binaryMessenger,
[channel setMessageHandler:nil];
}
}
{
FlutterBasicMessageChannel *channel = [[FlutterBasicMessageChannel alloc]
initWithName:@"dev.flutter.pigeon.cloud_firestore_platform_interface."
@"FirebaseFirestoreHostApi.findNearest"
binaryMessenger:binaryMessenger
codec:FirebaseFirestoreHostApiGetCodec()];
if (api) {
NSCAssert([api respondsToSelector:@selector
(findNearestApp:
path:isCollectionGroup:parameters:queryVector:source:limit
:queryOptions:distanceMeasure:completion:)],
@"FirebaseFirestoreHostApi api (%@) doesn't respond to "
@"@selector(findNearestApp:path:isCollectionGroup:parameters:queryVector:source:"
@"limit:queryOptions:distanceMeasure:completion:)",
api);
[channel setMessageHandler:^(id _Nullable message, FlutterReply callback) {
NSArray *args = message;
FirestorePigeonFirebaseApp *arg_app = GetNullableObjectAtIndex(args, 0);
NSString *arg_path = GetNullableObjectAtIndex(args, 1);
NSNumber *arg_isCollectionGroup = GetNullableObjectAtIndex(args, 2);
PigeonQueryParameters *arg_parameters = GetNullableObjectAtIndex(args, 3);
NSArray<NSNumber *> *arg_queryVector = GetNullableObjectAtIndex(args, 4);
VectorSource arg_source = [GetNullableObjectAtIndex(args, 5) integerValue];
NSNumber *arg_limit = GetNullableObjectAtIndex(args, 6);
VectorQueryOptions *arg_queryOptions = GetNullableObjectAtIndex(args, 7);
DistanceMeasure arg_distanceMeasure = [GetNullableObjectAtIndex(args, 8) integerValue];
[api findNearestApp:arg_app
path:arg_path
isCollectionGroup:arg_isCollectionGroup
parameters:arg_parameters
queryVector:arg_queryVector
source:arg_source
limit:arg_limit
queryOptions:arg_queryOptions
distanceMeasure:arg_distanceMeasure
completion:^(PigeonQuerySnapshot *_Nullable output,
FlutterError *_Nullable error) {
callback(wrapResult(output, error));
}];
}];
} else {
[channel setMessageHandler:nil];
}
}
{
FlutterBasicMessageChannel *channel = [[FlutterBasicMessageChannel alloc]
initWithName:@"dev.flutter.pigeon.cloud_firestore_platform_interface."
Expand Down
Loading
Loading