diff --git a/.editorconfig b/.editorconfig index 482cea58e85..d9d36e2054d 100644 --- a/.editorconfig +++ b/.editorconfig @@ -69,7 +69,7 @@ dotnet_naming_rule.static_fields_should_have_prefix.style = static_prefix_sty dotnet_naming_symbols.static_fields.applicable_kinds = field dotnet_naming_symbols.static_fields.required_modifiers = static dotnet_naming_symbols.static_fields.applicable_accessibilities = private, internal, private_protected -dotnet_naming_style.static_prefix_style.required_prefix = s_ +dotnet_naming_style.static_prefix_style.required_prefix = __ dotnet_naming_style.static_prefix_style.capitalization = camel_case # internal and private fields should be _camelCase diff --git a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs index fd5998c93b3..199203d345e 100644 --- a/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs +++ b/src/MongoDB.Bson/Serialization/IBsonSerializerExtensions.cs @@ -50,6 +50,48 @@ public static TValue Deserialize(this IBsonSerializer serializer return serializer.Deserialize(context, args); } + /// + /// Gets the serializer for a base type starting from a serializer for a derived type. + /// + /// The serializer for the derived type. + /// The base type. + /// The serializer for the base type. + public static IBsonSerializer GetBaseTypeSerializer(this IBsonSerializer derivedTypeSerializer, Type baseType) + { + if (derivedTypeSerializer.ValueType == baseType) + { + return derivedTypeSerializer; + } + + if (!baseType.IsAssignableFrom(derivedTypeSerializer.ValueType)) + { + throw new ArgumentException($"{baseType} is not assignable from {derivedTypeSerializer.ValueType}."); + } + + return BsonSerializer.LookupSerializer(baseType); // TODO: should be able to ask a serializer for the base type serializer + } + + /// + /// Gets the serializer for a derived type starting from a serializer for a base type. + /// + /// The serializer for the base type. + /// The derived type. + /// The serializer for the derived type. + public static IBsonSerializer GetDerivedTypeSerializer(this IBsonSerializer baseTypeSerializer, Type derivedType) + { + if (baseTypeSerializer.ValueType == derivedType) + { + return baseTypeSerializer; + } + + if (!baseTypeSerializer.ValueType.IsAssignableFrom(derivedType)) + { + throw new ArgumentException($"{baseTypeSerializer.ValueType} is not assignable from {derivedType}."); + } + + return BsonSerializer.LookupSerializer(derivedType); // TODO: should be able to ask a serializer for the derived type serializer + } + /// /// Gets the discriminator convention for a serializer. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs index f10cb541d16..e90210fbc14 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/ArraySerializer.cs @@ -13,10 +13,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; namespace MongoDB.Bson.Serialization.Serializers { + /// + /// A static factory class for ArraySerializers. + /// + public static class ArraySerializer + { + /// + /// Creates an ArraySerializer. + /// + /// The item serializer. + /// An ArraySerializer. + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var itemType = itemSerializer.ValueType; + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } + } + /// /// Represents a serializer for one-dimensional arrays. /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs index a1688526364..787720f34cf 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/CharSerializer.cs @@ -22,6 +22,15 @@ namespace MongoDB.Bson.Serialization.Serializers /// public sealed class CharSerializer : StructSerializerBase, IRepresentationConfigurable { + #region static + private static readonly CharSerializer __instance = new(); + + /// + /// Returns the default instance of CharSerializer. + /// + public static CharSerializer Instance => __instance; + #endregion + // private fields private readonly BsonType _representation; diff --git a/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs b/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs index 3f6ff642bdd..2535cbed080 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DictionarySerializerBase.cs @@ -499,20 +499,15 @@ obj is DictionarySerializerBase other && /// public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) { - if (_dictionaryRepresentation is DictionaryRepresentation.ArrayOfArrays or DictionaryRepresentation.ArrayOfDocuments) - { - var representation = _dictionaryRepresentation == DictionaryRepresentation.ArrayOfArrays - ? BsonType.Array - : BsonType.Document; - var keySerializer = _lazyKeySerializer.Value; - var valueSerializer = _lazyValueSerializer.Value; - var keyValuePairSerializer = new KeyValuePairSerializer(representation, keySerializer, valueSerializer); - serializationInfo = new BsonSerializationInfo(null, keyValuePairSerializer, keyValuePairSerializer.ValueType); - return true; - } - - serializationInfo = null; - return false; + var representation = _dictionaryRepresentation == DictionaryRepresentation.ArrayOfArrays + ? BsonType.Array + : BsonType.Document; + var keySerializer = _lazyKeySerializer.Value; + var valueSerializer = _lazyValueSerializer.Value; + var keyValuePairSerializer = new KeyValuePairSerializer(representation, keySerializer, valueSerializer); + + serializationInfo = new BsonSerializationInfo(null, keyValuePairSerializer, keyValuePairSerializer.ValueType); + return true; } /// diff --git a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs index 8740bdd3a9b..423b9500bed 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/NullableSerializer.cs @@ -33,6 +33,73 @@ public interface INullableSerializer /// public static class NullableSerializer { + private readonly static IBsonSerializer __nullableBooleanInstance = new NullableSerializer(BooleanSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimalInstance = new NullableSerializer(DecimalSerializer.Instance); + private readonly static IBsonSerializer __nullableDecimal128Instance = new NullableSerializer(Decimal128Serializer.Instance); + private readonly static IBsonSerializer __nullableDoubleInstance = new NullableSerializer(DoubleSerializer.Instance); + private readonly static IBsonSerializer __nullableInt32Instance = new NullableSerializer(Int32Serializer.Instance); + private readonly static IBsonSerializer __nullableInt64Instance = new NullableSerializer(Int64Serializer.Instance); + private readonly static IBsonSerializer __nullableLocalDateTimeInstance = new NullableSerializer(DateTimeSerializer.LocalInstance); + private readonly static IBsonSerializer __nullableObjectIdInstance = new NullableSerializer(ObjectIdSerializer.Instance); + private readonly static IBsonSerializer __nullableSingleInstance = new NullableSerializer(SingleSerializer.Instance); + private readonly static IBsonSerializer __nullableStandardGuidInstance = new NullableSerializer(GuidSerializer.StandardInstance); + private readonly static IBsonSerializer __nullableUtcDateTimeInstance = new NullableSerializer(DateTimeSerializer.UtcInstance); + + /// + /// Gets a serializer for nullable bools. + /// + public static IBsonSerializer NullableBooleanInstance => __nullableBooleanInstance; + + /// + /// Gets a serializer for nullable decimals. + /// + public static IBsonSerializer NullableDecimalInstance => __nullableDecimalInstance; + + /// + /// Gets a serializer for nullable Decimal128s. + /// + public static IBsonSerializer NullableDecimal128Instance => __nullableDecimal128Instance; + + /// + /// Gets a serializer for nullable doubles. + /// + public static IBsonSerializer NullableDoubleInstance => __nullableDoubleInstance; + + /// + /// Gets a serializer for nullable ints. + /// + public static IBsonSerializer NullableInt32Instance => __nullableInt32Instance; + + /// + /// Gets a serializer for nullable longs. + /// + public static IBsonSerializer NullableInt64Instance => __nullableInt64Instance; + + /// + /// Gets a serializer for local DateTime. + /// + public static IBsonSerializer NullableLocalDateTimeInstance => __nullableLocalDateTimeInstance; + + /// + /// Gets a serializer for nullable floats. + /// + public static IBsonSerializer NullableSingleInstance => __nullableSingleInstance; + + /// + /// Gets a serializer for nullable ObjectIds. + /// + public static IBsonSerializer NullableObjectIdInstance => __nullableObjectIdInstance; + + /// + /// Gets a serializer for nullable Guids with standard representation. + /// + public static IBsonSerializer NullableStandardGuidInstance => __nullableStandardGuidInstance; + + /// + /// Gets a serializer for UTC DateTime. + /// + public static IBsonSerializer NullableUtcDateTimeInstance => __nullableUtcDateTimeInstance; + /// /// Creates a NullableSerializer. /// diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs index db7618ce677..e80842589d8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/ExtensionMethods/ExpressionExtensions.cs @@ -60,5 +60,17 @@ public static TValue GetConstantValue(this Expression expression, Expres var message = $"Expression must be a constant: {expression} in {containingExpression}."; throw new ExpressionNotSupportedException(message); } + + public static bool IsConvert(this Expression expression, out Expression operand) + { + if (expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression) + { + operand = unaryExpression.Operand; + return true; + } + + operand = null; + return false; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 40e41bbd51c..4b4b5906ca3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,8 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)LinqExpressionPreprocessor.Preprocess(_output); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedOutput.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedOutput, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +107,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +152,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +191,8 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); - var context = TranslationContext.Create(translationOptions); + var parameter = partiallyEvaluatedGroupBy.Parameters.Single(); + var context = TranslationContext.Create(partiallyEvaluatedGroupBy, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs new file mode 100644 index 00000000000..de23e6e9d5e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/BsonTypeExtensions.cs @@ -0,0 +1,24 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class BsonTypeExtensions +{ + public static bool IsNumeric(this BsonType bsonType) + => bsonType is BsonType.Decimal128 or BsonType.Double or BsonType.Int32 or BsonType.Int64; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs index 2b5c4a3a012..a2eed8cafd8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs @@ -35,7 +35,8 @@ public static LambdaExpression UnquoteLambdaIfQueryableMethod(MethodInfo method, Ensure.IsNotNull(method, nameof(method)); Ensure.IsNotNull(expression, nameof(expression)); - if (method.DeclaringType == typeof(Queryable)) + var declaringType = method.DeclaringType; + if (declaringType == typeof(Queryable) || declaringType == typeof(MongoQueryable)) { return UnquoteLambda(expression); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs new file mode 100644 index 00000000000..ffa90c1ab7a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/IBsonSerializerExtensions.cs @@ -0,0 +1,165 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +internal static class IBsonSerializerExtensions +{ + public static bool CanBeAssignedTo(this IBsonSerializer sourceSerializer, IBsonSerializer targetSerializer) + { + if (sourceSerializer.Equals(targetSerializer)) + { + return true; + } + + if (sourceSerializer.ValueType.IsNumeric() && + targetSerializer.ValueType.IsNumeric() && + sourceSerializer.HasNumericRepresentation() && + targetSerializer.HasNumericRepresentation()) + { + return true; + } + + if (targetSerializer.ValueType.IsAssignableFrom(sourceSerializer.ValueType)) + { + return true; + } + + return false; + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer) + => ArraySerializerHelper.GetItemSerializer(serializer); + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, int index) + { + if (serializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + return polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static IBsonSerializer GetItemSerializer(this IBsonSerializer serializer, Expression indexExpression, Expression containingExpression) + { + if (serializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + var index = indexExpression.GetConstantValue(containingExpression); + return polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + return serializer.GetItemSerializer(); + } + } + + public static IReadOnlyList GetMatchingMemberSerializationInfosForConstructorParameters( + this IBsonSerializer serializer, + Expression expression, + ConstructorInfo constructorInfo) + { + if (serializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer type {serializer.GetType().Name} does not implement IBsonDocumentSerializer"); + } + + var matchingMemberSerializationInfos = new List(); + foreach (var constructorParameter in constructorInfo.GetParameters()) + { + var matchingMemberSerializationInfo = GetMatchingMemberSerializationInfo(expression, documentSerializer, constructorParameter.Name); + matchingMemberSerializationInfos.Add(matchingMemberSerializationInfo); + } + + return matchingMemberSerializationInfos; + + static BsonSerializationInfo GetMatchingMemberSerializationInfo( + Expression expression, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + var possibleMatchingMembers = documentSerializer.ValueType.GetMembers().Where(m => m.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase)).ToArray(); + if (possibleMatchingMembers.Length == 0) + { + throw new ExpressionNotSupportedException(expression, because: $"no matching member found for constructor parameter: {constructorParameterName}"); + } + if (possibleMatchingMembers.Length > 1) + { + throw new ExpressionNotSupportedException(expression, because: $"multiple possible matching members found for constructor parameter: {constructorParameterName}"); + } + var matchingMemberName = possibleMatchingMembers[0].Name; + + if (!documentSerializer.TryGetMemberSerializationInfo(matchingMemberName, out var matchingMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer of type {documentSerializer.GetType().Name} did not provide serialization info for member {matchingMemberName}"); + } + + return matchingMemberSerializationInfo; + } + } + + public static bool HasNumericRepresentation(this IBsonSerializer serializer) + { + return + serializer is IHasRepresentationSerializer hasRepresentationSerializer && + hasRepresentationSerializer.Representation.IsNumeric(); + } + + public static bool IsKeyValuePairSerializer( + this IBsonSerializer serializer, + out string keyElementName, + out string valueElementName, + out IBsonSerializer keySerializer, + out IBsonSerializer valueSerializer) + { + // TODO: add properties to IKeyValuePairSerializer to let us extract the needed information + // note: we can only verify the existence of "Key" and "Value" properties, but can't verify there are no others + if (serializer.ValueType is var valueType && + valueType.IsConstructedGenericType && + valueType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + serializer is IBsonDocumentSerializer documentSerializer && + documentSerializer.TryGetMemberSerializationInfo("Key", out var keySerializationInfo) && + documentSerializer.TryGetMemberSerializationInfo("Value", out var valueSerializationInfo)) + { + keyElementName = keySerializationInfo.ElementName; + valueElementName = valueSerializationInfo.ElementName; + keySerializer = keySerializationInfo.Serializer; + valueSerializer = valueSerializationInfo.Serializer; + return true; + } + + keyElementName = null; + valueElementName = null; + keySerializer = null; + valueSerializer = null; + return false; + } + + public static IBsonSerializer GetValueSerializerIfWrapped(this IBsonSerializer serializer) + { + return serializer is IWrappedValueSerializer wrappedValueSerializer ? wrappedValueSerializer.ValueSerializer.GetValueSerializerIfWrapped() : serializer; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs index f73c074c835..93c7ce3dba6 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs @@ -16,6 +16,7 @@ using System; using System.Linq; using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { @@ -125,31 +126,9 @@ public static bool IsOneOf(this MethodInfo method, MethodInfo comparand1, Method return method.Is(comparand1) || method.Is(comparand2) || method.Is(comparand3) || method.Is(comparand4); } - public static bool IsOneOf(this MethodInfo method, params MethodInfo[] comparands) - { - for (var i = 0; i < comparands.Length; i++) - { - if (method.Is(comparands[i])) - { - return true; - } - } - - return false; - } + public static bool IsOneOf(this MethodInfo method, IReadOnlyMethodInfoSet set) => set.Contains(method); - public static bool IsOneOf(this MethodInfo method, params MethodInfo[][] comparands) - { - for (var i = 0; i < comparands.Length; i++) - { - if (method.IsOneOf(comparands[i])) - { - return true; - } - } - - return false; - } + public static bool IsOneOf(this MethodInfo method, IReadOnlyMethodInfoSet set1, IReadOnlyMethodInfoSet set2) => set1.Contains(method) || set2.Contains(method); public static bool IsStaticCompareMethod(this MethodInfo method) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs index 0b34b0bd7cf..c0c4a473eb2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs @@ -97,7 +97,7 @@ public static BsonType GetRepresentation(IBsonSerializer serializer) return GetRepresentation(downcastingSerializer.DerivedSerializer); } - if (serializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (serializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { return GetRepresentation(enumUnderlyingTypeSerializer.EnumSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index 636c616deb3..aadee088c26 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -72,7 +72,7 @@ public static Type GetIEnumerableGenericInterface(this Type enumerableType) throw new InvalidOperationException($"Could not find IEnumerable interface of type: {enumerableType}."); } - public static bool Implements(this Type type, Type @interface) + public static bool ImplementsInterface(this Type type, Type @interface) { if (type == @interface) { @@ -102,6 +102,7 @@ public static bool Implements(this Type type, Type @interface) public static bool ImplementsDictionaryInterface(this Type type, out Type keyType, out Type valueType) { + // note: returns true for IReadOnlyDictionary also if (TryGetGenericInterface(type, __dictionaryInterfaceDefinitions, out var dictionaryInterface)) { var genericArguments = dictionaryInterface.GetGenericArguments(); @@ -146,6 +147,30 @@ public static bool ImplementsIList(this Type type, out Type itemType) return false; } + public static bool ImplementsIOrderedEnumerable(this Type type, out Type itemType) + { + if (TryGetIOrderedEnumerableGenericInterface(type, out var iOrderedEnumerableType)) + { + itemType = iOrderedEnumerableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + + public static bool ImplementsIOrderedQueryable(this Type type, out Type itemType) + { + if (TryGetIOrderedQueryableGenericInterface(type, out var iorderedQueryableType)) + { + itemType = iorderedQueryableType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + public static bool ImplementsIQueryable(this Type type, out Type itemType) { if (TryGetIQueryableGenericInterface(type, out var iqueryableType)) @@ -158,6 +183,25 @@ public static bool ImplementsIQueryable(this Type type, out Type itemType) return false; } + public static bool ImplementsIQueryableOf(this Type type, Type itemType) + { + return + ImplementsIEnumerable(type, out var actualItemType) && + actualItemType == itemType; + } + + public static bool ImplementsISet(this Type type, out Type itemType) + { + if (TryGetISetGenericInterface(type, out var isetType)) + { + itemType = isetType.GetGenericArguments()[0]; + return true; + } + + itemType = null; + return false; + } + public static bool Is(this Type type, Type comparand) { if (type == comparand) @@ -197,11 +241,14 @@ public static bool IsArray(this Type type, out Type itemType) return false; } + public static bool IsBoolean(this Type type) + { + return type == typeof(bool); + } + public static bool IsBooleanOrNullableBoolean(this Type type) { - return - type == typeof(bool) || - type.IsNullable(out var valueType) && valueType == typeof(bool); + return IsBoolean(type) || type.IsNullable(out var valueType) && IsBoolean(valueType); } public static bool IsConvertibleToEnum(this Type type) @@ -294,23 +341,18 @@ public static bool IsNullableOf(this Type type, Type valueType) return type.IsNullable(out var nullableValueType) && nullableValueType == valueType; } - public static bool IsReadOnlySpanOf(this Type type, Type itemType) - { - return - type.IsGenericType && - type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && - type.GetGenericArguments()[0] == itemType; - } - public static bool IsNumeric(this Type type) { + // note: treating more types as numeric would require careful analysis of impact on callers of this method return - type == typeof(int) || - type == typeof(long) || + type == typeof(char) || // TODO: should we really treat char as numeric? + type == typeof(decimal) || + type == typeof(Decimal128) || type == typeof(double) || type == typeof(float) || - type == typeof(decimal) || - type == typeof(Decimal128); + type == typeof(int) || + type == typeof(long) || + type == typeof(short); } public static bool IsNumericOrNullableNumeric(this Type type) @@ -320,6 +362,14 @@ public static bool IsNumericOrNullableNumeric(this Type type) type.IsNullable(out var valueType) && valueType.IsNumeric(); } + public static bool IsReadOnlySpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); @@ -337,7 +387,7 @@ public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInter { return type.IsSubclassOf(baseTypeOrInterface) || - type.Implements(baseTypeOrInterface); + type.ImplementsInterface(baseTypeOrInterface); } public static bool IsTuple(this Type type) @@ -386,9 +436,18 @@ public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ie public static bool TryGetIListGenericInterface(this Type type, out Type ilistGenericInterface) => TryGetGenericInterface(type, typeof(IList<>), out ilistGenericInterface); + public static bool TryGetIOrderedEnumerableGenericInterface(this Type type, out Type iorderedEnumerableGenericInterface) + => TryGetGenericInterface(type, typeof(IOrderedEnumerable<>), out iorderedEnumerableGenericInterface); + + public static bool TryGetIOrderedQueryableGenericInterface(this Type type, out Type iorderedQueryableGenericInterface) + => TryGetGenericInterface(type, typeof(IOrderedQueryable<>), out iorderedQueryableGenericInterface); + public static bool TryGetIQueryableGenericInterface(this Type type, out Type iqueryableGenericInterface) => TryGetGenericInterface(type, typeof(IQueryable<>), out iqueryableGenericInterface); + public static bool TryGetISetGenericInterface(this Type type, out Type isetGenericInterface) + => TryGetGenericInterface(type, typeof(ISet<>), out isetGenericInterface); + private static TValue GetDefaultValueGeneric() { return default(TValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs index fe96bacae36..a868f3a8508 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs @@ -41,7 +41,7 @@ internal class MongoQuery : MongoQuery, IOrderedQue public MongoQuery(MongoQueryProvider provider) { _provider = provider; - _expression = Expression.Constant(this); + _expression = Expression.Constant(this, typeof(IQueryable<>).MakeGenericType(typeof(TDocument))); } public MongoQuery(MongoQueryProvider provider, Expression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs index fd12fccd4c9..1e4e3b582c4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/BsonDocumentMethod.cs @@ -22,14 +22,20 @@ internal static class BsonDocumentMethod { // private static fields private static readonly MethodInfo __addWithNameAndValue; + private static readonly MethodInfo __getItemWithIndex; + private static readonly MethodInfo __getItemWithName; // static constructor static BsonDocumentMethod() { __addWithNameAndValue = ReflectionInfo.Method((BsonDocument document, string name, BsonValue value) => document.Add(name, value)); + __getItemWithIndex = ReflectionInfo.Method((BsonDocument document, int index) => document[index]); + __getItemWithName = ReflectionInfo.Method((BsonDocument document, string name) => document[name]); } // public static properties public static MethodInfo AddWithNameAndValue => __addWithNameAndValue; + public static MethodInfo GetItemWithIndex => __getItemWithIndex; + public static MethodInfo GetItemWithName => __getItemWithName; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs index 4e677ffe5f3..981fbe14c01 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DateTimeMethod.cs @@ -61,9 +61,22 @@ internal static class DateTimeMethod private static readonly MethodInfo __week; private static readonly MethodInfo __weekWithTimezone; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __addOrSubtractOverloads; + private static readonly IReadOnlyMethodInfoSet __addOrSubtractWithTimeSpanOverloads; + private static readonly IReadOnlyMethodInfoSet __addOrSubtractWithTimezoneOverloads; + private static readonly IReadOnlyMethodInfoSet __addOrSubtractWithUnitOverloads; + private static readonly IReadOnlyMethodInfoSet __subtractReturningDateTimeOverloads; + private static readonly IReadOnlyMethodInfoSet __subtractReturningInt64Overloads; + private static readonly IReadOnlyMethodInfoSet __subtractReturningTimeSpanWithMillisecondsUnitsOverloads; + private static readonly IReadOnlyMethodInfoSet __subtractWithDateTimeOverloads; + private static readonly IReadOnlyMethodInfoSet __subtractWithTimezoneOverloads; + private static readonly IReadOnlyMethodInfoSet __subtractWithUnitOverloads; + // static constructor static DateTimeMethod() { + // initialize methods before sets of methods __add = ReflectionInfo.Method((DateTime @this, TimeSpan value) => @this.Add(value)); __addDays = ReflectionInfo.Method((DateTime @this, double value) => @this.AddDays(value)); __addDaysWithTimezone = ReflectionInfo.Method((DateTime @this, double value, string timezone) => @this.AddDays(value, timezone)); @@ -103,6 +116,111 @@ static DateTimeMethod() __truncateWithBinSizeAndTimezone = ReflectionInfo.Method((DateTime @this, DateTimeUnit unit, long binSize, string timezone) => @this.Truncate(unit, binSize, timezone)); __week = ReflectionInfo.Method((DateTime @this) => @this.Week()); __weekWithTimezone = ReflectionInfo.Method((DateTime @this, string timezone) => @this.Week(timezone)); + + // initialize sets of methods after methods + __addOrSubtractOverloads = MethodInfoSet.Create( + [ + __add, + __addDays, + __addDaysWithTimezone, + __addHours, + __addHoursWithTimezone, + __addMilliseconds, + __addMillisecondsWithTimezone, + __addMinutes, + __addMinutesWithTimezone, + __addMonths, + __addMonthsWithTimezone, + __addQuarters, + __addQuartersWithTimezone, + __addSeconds, + __addSecondsWithTimezone, + __addTicks, + __addWeeks, + __addWeeksWithTimezone, + __addWithTimezone, + __addWithUnit, + __addWithUnitAndTimezone, + __addYears, + __addYearsWithTimezone, + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]); + + __addOrSubtractWithTimeSpanOverloads = MethodInfoSet.Create( + [ + __add, + __addWithTimezone, + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone + ]); + + __addOrSubtractWithTimezoneOverloads = MethodInfoSet.Create( + [ + __addDaysWithTimezone, + __addHoursWithTimezone, + __addMillisecondsWithTimezone, + __addMinutesWithTimezone, + __addMonthsWithTimezone, + __addQuartersWithTimezone, + __addSecondsWithTimezone, + __addWeeksWithTimezone, + __addWithTimezone, + __addWithUnitAndTimezone, + __addYearsWithTimezone, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnitAndTimezone + ]); + + __addOrSubtractWithUnitOverloads = MethodInfoSet.Create( + [ + __addWithUnit, + __addWithUnitAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]); + + __subtractReturningDateTimeOverloads = MethodInfoSet.Create( + [ + __subtractWithTimeSpan, + __subtractWithTimeSpanAndTimezone, + __subtractWithUnit, + __subtractWithUnitAndTimezone + ]); + + __subtractReturningInt64Overloads = MethodInfoSet.Create( + [ + __subtractWithDateTimeAndUnit, + __subtractWithDateTimeAndUnitAndTimezone + ]); + + __subtractReturningTimeSpanWithMillisecondsUnitsOverloads = MethodInfoSet.Create( + [ + __subtractWithDateTime, + __subtractWithDateTimeAndTimezone + ]); + + __subtractWithDateTimeOverloads = MethodInfoSet.Create( + [ + __subtractWithDateTime, + __subtractWithDateTimeAndTimezone, + __subtractWithDateTimeAndUnit, + __subtractWithDateTimeAndUnitAndTimezone + ]); + + __subtractWithTimezoneOverloads = MethodInfoSet.Create( + [ + __subtractWithDateTimeAndTimezone, + __subtractWithDateTimeAndUnitAndTimezone + ]); + + __subtractWithUnitOverloads = MethodInfoSet.Create( + [ + __subtractWithDateTimeAndUnit, + __subtractWithDateTimeAndUnitAndTimezone + ]); } // public properties @@ -145,5 +263,17 @@ static DateTimeMethod() public static MethodInfo TruncateWithBinSizeAndTimezone => __truncateWithBinSizeAndTimezone; public static MethodInfo Week => __week; public static MethodInfo WeekWithTimezone => __weekWithTimezone; + + // sets of methods + public static IReadOnlyMethodInfoSet AddOrSubtractOverloads => __addOrSubtractOverloads; + public static IReadOnlyMethodInfoSet AddOrSubtractWithTimeSpanOverloads => __addOrSubtractWithTimeSpanOverloads; + public static IReadOnlyMethodInfoSet AddOrSubtractWithTimezoneOverloads => __addOrSubtractWithTimezoneOverloads; + public static IReadOnlyMethodInfoSet AddOrSubtractWithUnitOverloads => __addOrSubtractWithUnitOverloads; + public static IReadOnlyMethodInfoSet SubtractReturningDateTimeOverloads => __subtractReturningDateTimeOverloads; + public static IReadOnlyMethodInfoSet SubtractReturningInt64Overloads => __subtractReturningInt64Overloads; + public static IReadOnlyMethodInfoSet SubtractReturningTimeSpanWithMillisecondsUnitsOverloads => __subtractReturningTimeSpanWithMillisecondsUnitsOverloads; + public static IReadOnlyMethodInfoSet SubtractWithDateTimeOverloads => __subtractWithDateTimeOverloads; + public static IReadOnlyMethodInfoSet SubtractWithTimezoneOverloads => __subtractWithTimezoneOverloads; + public static IReadOnlyMethodInfoSet SubtractWithUnitOverloads => __subtractWithUnitOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs index dd245eb6c3c..665a4ac8548 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs @@ -21,6 +21,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection internal static class DictionaryMethod { // public static methods + public static bool IsContainsKeyMethod(MethodInfo method) + { + return + !method.IsStatic && + method.Name == "ContainsKey" && + method.DeclaringType.ImplementsDictionaryInterface(out var keyType, out _) && + method.GetParameters() is var parameters && + parameters.Length == 1 && + parameters[0].ParameterType == keyType && + method.ReturnType == typeof(bool); + } + public static bool IsGetItemWithKeyMethod(MethodInfo method) { return diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index a10f2a67531..c97a186424d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -30,6 +30,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -74,7 +75,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __firstOrDefault; private static readonly MethodInfo __firstOrDefaultWithPredicate; private static readonly MethodInfo __firstWithPredicate; - private static readonly MethodInfo __groupBy; + private static readonly MethodInfo __groupByWithKeySelector; private static readonly MethodInfo __groupByWithKeySelectorAndElementSelector; private static readonly MethodInfo __groupByWithKeySelectorAndResultSelector; private static readonly MethodInfo __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -144,8 +145,9 @@ internal static class EnumerableMethod private static readonly MethodInfo __range; private static readonly MethodInfo __repeat; private static readonly MethodInfo __reverse; + private static readonly MethodInfo __reverseWithArray; // will be null on target frameworks that don't have this method private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -192,13 +194,29 @@ internal static class EnumerableMethod private static readonly MethodInfo __whereWithPredicateTakingIndex; private static readonly MethodInfo __zip; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __pickOverloads; + private static readonly IReadOnlyMethodInfoSet __pickOverloadsThatCanOnlyBeUsedAsGroupByAccumulators; + private static readonly IReadOnlyMethodInfoSet __pickWithComputedNOverloads; + private static readonly IReadOnlyMethodInfoSet __pickWithNOverloads; + private static readonly IReadOnlyMethodInfoSet __pickWithSortByOverloads; + private static readonly IReadOnlyMethodInfoSet __reverseOverloads; + // static constructor static EnumerableMethod() { + // initialize methods before sets of methods +#if NET10_OR_GREATER + __reverseWithArray = ReflectionInfo.Method(array source) => source.Reverse()); +#else + __reverseWithArray = GetReverseWithArrayMethodInfo(); // support users running net10 even though we don't target net10 yet +#endif + __aggregateWithFunc = ReflectionInfo.Method((IEnumerable source, Func func) => source.Aggregate(func)); __aggregateWithSeedAndFunc = ReflectionInfo.Method((IEnumerable source, object seed, Func func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IEnumerable source, object seed, Func func, Func resultSelector) => source.Aggregate(seed, func, resultSelector)); __all = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IEnumerable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IEnumerable source, object element) => source.Append(element)); @@ -243,7 +261,7 @@ static EnumerableMethod() __firstOrDefault = ReflectionInfo.Method((IEnumerable source) => source.FirstOrDefault()); __firstOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.FirstOrDefault(predicate)); __firstWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.First(predicate)); - __groupBy = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); + __groupByWithKeySelector = ReflectionInfo.Method((IEnumerable source, Func keySelector) => source.GroupBy(keySelector)); __groupByWithKeySelectorAndElementSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector) => source.GroupBy(keySelector, elementSelector)); __groupByWithKeySelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func resultSelector) => source.GroupBy(keySelector, resultSelector)); __groupByWithKeySelectorElementSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func keySelector, Func elementSelector, Func, object> resultSelector) => source.GroupBy(keySelector, elementSelector, resultSelector)); @@ -314,7 +332,7 @@ static EnumerableMethod() __repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count)); __reverse = ReflectionInfo.Method((IEnumerable source) => source.Reverse()); __select = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IEnumerable source, Func> collectionSelector, Func resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); @@ -360,6 +378,75 @@ static EnumerableMethod() __where = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Where(predicate)); __whereWithPredicateTakingIndex = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Where(predicate)); __zip = ReflectionInfo.Method((IEnumerable first, IEnumerable second, Func resultSelector) => first.Zip(second, resultSelector)); + + // initialize sets of methods after methods + __pickOverloads = MethodInfoSet.Create( + [ + __bottom, + __bottomN, + __bottomNWithComputedN, + __firstN, + __firstNWithComputedN, + __lastN, + __lastNWithComputedN, + __maxN, + __maxNWithComputedN, + __minN, + __minNWithComputedN, + __top, + __topN, + __topNWithComputedN + ]); + + __pickOverloadsThatCanOnlyBeUsedAsGroupByAccumulators = MethodInfoSet.Create( + [ + __bottom, + __bottomN, + __bottomNWithComputedN, + __firstNWithComputedN, + __lastNWithComputedN, + __maxNWithComputedN, + __minNWithComputedN, + __top, + __topN, + __topNWithComputedN + ]); + + __pickWithComputedNOverloads = MethodInfoSet.Create( + [ + __bottomNWithComputedN, + __firstNWithComputedN, + __lastNWithComputedN, + __maxNWithComputedN, + __minNWithComputedN, + __topNWithComputedN + ]); + + __pickWithNOverloads = MethodInfoSet.Create( + [ + __bottomN, + __firstN, + __lastN, + __maxN, + __minN, + __topN + ]); + + __pickWithSortByOverloads = MethodInfoSet.Create( + [ + __bottom, + __bottomN, + __bottomNWithComputedN, + __top, + __topN, + __topNWithComputedN + ]); + + __reverseOverloads = MethodInfoSet.Create( + [ + __reverse, + __reverseWithArray + ]); } // public properties @@ -367,6 +454,7 @@ static EnumerableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -411,7 +499,7 @@ static EnumerableMethod() public static MethodInfo FirstOrDefault => __firstOrDefault; public static MethodInfo FirstOrDefaultWithPredicate => __firstOrDefaultWithPredicate; public static MethodInfo FirstWithPredicate => __firstWithPredicate; - public static MethodInfo GroupBy => __groupBy; + public static MethodInfo GroupByWithKeySelector => __groupByWithKeySelector; public static MethodInfo GroupByWithKeySelectorAndElementSelector => __groupByWithKeySelectorAndElementSelector; public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; @@ -481,8 +569,9 @@ static EnumerableMethod() public static MethodInfo Range => __range; public static MethodInfo Repeat => __repeat; public static MethodInfo Reverse => __reverse; + public static MethodInfo ReverseWithArray => __reverseWithArray; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; @@ -529,6 +618,14 @@ static EnumerableMethod() public static MethodInfo WhereWithPredicateTakingIndex => __whereWithPredicateTakingIndex; public static MethodInfo Zip => __zip; + // sets of methods + public static IReadOnlyMethodInfoSet PickOverloads => __pickOverloads; + public static IReadOnlyMethodInfoSet PickOverloadsThatCanOnlyBeUsedAsGroupByAccumulators => __pickOverloadsThatCanOnlyBeUsedAsGroupByAccumulators; + public static IReadOnlyMethodInfoSet PickWithComputedNOverloads => __pickWithComputedNOverloads; + public static IReadOnlyMethodInfoSet PickWithNOverloads => __pickWithNOverloads; + public static IReadOnlyMethodInfoSet PickWithSortByOverloads => __pickWithSortByOverloads; + public static IReadOnlyMethodInfoSet ReverseOverloads => __reverseOverloads; + // public methods public static bool IsContainsMethod(MethodCallExpression methodCallExpression, out Expression sourceExpression, out Expression valueExpression) { @@ -613,5 +710,28 @@ public static MethodInfo MakeWhere(Type tsource) { return __where.MakeGenericMethod(tsource); } + +#if !NET10_OR_GREATER + private static MethodInfo GetReverseWithArrayMethodInfo() + { + // returns null on target frameworks that don't have this method + return + typeof(Enumerable) + .GetMethods() + .SingleOrDefault(m => + m.IsPublic && + m.IsStatic && + m.Name == "Reverse" && + m.IsGenericMethodDefinition && + m.GetGenericArguments() is var genericArguments && + genericArguments.Length == 1 && + genericArguments[0] is var tsource && + m.ReturnType == typeof(IEnumerable<>).MakeGenericType(tsource) && + m.GetParameters() is var parameters && + parameters.Length == 1 && + parameters[0] is var sourceParameter && + sourceParameter.ParameterType == tsource.MakeArrayType()); + } +#endif } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs new file mode 100644 index 00000000000..1b8611b9d58 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs @@ -0,0 +1,955 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class EnumerableOrQueryableMethod +{ + // methods (in this file matching Enumerable and Queryable methods are treated as if they were one method) + private static readonly IReadOnlyMethodInfoSet __aggregateWithFunc; + private static readonly IReadOnlyMethodInfoSet __aggregateWithSeedAndFunc; + private static readonly IReadOnlyMethodInfoSet __aggregateWithSeedFuncAndResultSelector; + private static readonly IReadOnlyMethodInfoSet __all; + private static readonly IReadOnlyMethodInfoSet __any; + private static readonly IReadOnlyMethodInfoSet __anyWithPredicate; + private static readonly IReadOnlyMethodInfoSet __append; + private static readonly IReadOnlyMethodInfoSet __concat; + private static readonly IReadOnlyMethodInfoSet __count; + private static readonly IReadOnlyMethodInfoSet __countWithPredicate; + private static readonly IReadOnlyMethodInfoSet __defaultIfEmpty; + private static readonly IReadOnlyMethodInfoSet __defaultIfEmptyWithDefaultValue; + private static readonly IReadOnlyMethodInfoSet __distinct; + private static readonly IReadOnlyMethodInfoSet __elementAt; + private static readonly IReadOnlyMethodInfoSet __elementAtOrDefault; + private static readonly IReadOnlyMethodInfoSet __except; + private static readonly IReadOnlyMethodInfoSet __first; + private static readonly IReadOnlyMethodInfoSet __firstOrDefault; + private static readonly IReadOnlyMethodInfoSet __firstWithPredicate; + private static readonly IReadOnlyMethodInfoSet __firstOrDefaultWithPredicate; + private static readonly IReadOnlyMethodInfoSet __groupByWithKeySelector; + private static readonly IReadOnlyMethodInfoSet __groupByWithKeySelectorAndElementSelector; + private static readonly IReadOnlyMethodInfoSet __groupByWithKeySelectorAndResultSelector; + private static readonly IReadOnlyMethodInfoSet __groupByWithKeySelectorElementSelectorAndResultSelector; + private static readonly IReadOnlyMethodInfoSet __intersect; + private static readonly IReadOnlyMethodInfoSet __last; + private static readonly IReadOnlyMethodInfoSet __lastOrDefault; + private static readonly IReadOnlyMethodInfoSet __lastWithPredicate; + private static readonly IReadOnlyMethodInfoSet __lastOrDefaultWithPredicate; + private static readonly IReadOnlyMethodInfoSet __longCount; + private static readonly IReadOnlyMethodInfoSet __longCountWithPredicate; + private static readonly IReadOnlyMethodInfoSet __ofType; + private static readonly IReadOnlyMethodInfoSet __orderBy; + private static readonly IReadOnlyMethodInfoSet __orderByDescending; + private static readonly IReadOnlyMethodInfoSet __prepend; + private static readonly IReadOnlyMethodInfoSet __reverse; + private static readonly IReadOnlyMethodInfoSet __select; + private static readonly IReadOnlyMethodInfoSet __selectManyWithCollectionSelectorAndResultSelector; + private static readonly IReadOnlyMethodInfoSet __selectManyWithSelector; + private static readonly IReadOnlyMethodInfoSet __single; + private static readonly IReadOnlyMethodInfoSet __singleOrDefault; + private static readonly IReadOnlyMethodInfoSet __singleWithPredicate; + private static readonly IReadOnlyMethodInfoSet __singleOrDefaultWithPredicate; + private static readonly IReadOnlyMethodInfoSet __skip; + private static readonly IReadOnlyMethodInfoSet __skipWhile; + private static readonly IReadOnlyMethodInfoSet __take; + private static readonly IReadOnlyMethodInfoSet __takeWhile; + private static readonly IReadOnlyMethodInfoSet __thenBy; + private static readonly IReadOnlyMethodInfoSet __thenByDescending; + private static readonly IReadOnlyMethodInfoSet __union; + private static readonly IReadOnlyMethodInfoSet __where; + private static readonly IReadOnlyMethodInfoSet __zip; + + // sets of methods + private static readonly IReadOnlyMethodInfoSet __aggregateOverloads; + private static readonly IReadOnlyMethodInfoSet __aggregateWithSeedOverloads; + private static readonly IReadOnlyMethodInfoSet __anyOverloads; + private static readonly IReadOnlyMethodInfoSet __appendOrPrepend; + private static readonly IReadOnlyMethodInfoSet __averageOverloads; + private static readonly IReadOnlyMethodInfoSet __averageWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __countOverloads; + private static readonly IReadOnlyMethodInfoSet __countWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __defaultIfEmptyOverloads; + private static readonly IReadOnlyMethodInfoSet __elementAtOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOrLastWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOrLastOrSingleOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOrLastOrSingleWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOrDefaultOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOrLastOverloads; + private static readonly IReadOnlyMethodInfoSet __firstWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __groupByOverloads; + private static readonly IReadOnlyMethodInfoSet __lastOverloads; + private static readonly IReadOnlyMethodInfoSet __lastOrDefaultOverloads; + private static readonly IReadOnlyMethodInfoSet __lastWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __maxOrMinOverloads; + private static readonly IReadOnlyMethodInfoSet __maxOrMinWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __maxOverloads; + private static readonly IReadOnlyMethodInfoSet __maxWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __minOverloads; + private static readonly IReadOnlyMethodInfoSet __minWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __orderByOrThenByOverloads; + private static readonly IReadOnlyMethodInfoSet __orderByOverloads; + private static readonly IReadOnlyMethodInfoSet __reverseOverloads; + private static readonly IReadOnlyMethodInfoSet __selectManyOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOverloads; + private static readonly IReadOnlyMethodInfoSet __singleWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __skipOrTakeOverloads; + private static readonly IReadOnlyMethodInfoSet __skipOverloads; + private static readonly IReadOnlyMethodInfoSet __skipWhileOrTakeWhile; + private static readonly IReadOnlyMethodInfoSet __sumOverloads; + private static readonly IReadOnlyMethodInfoSet __sumWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __takeOverloads; + private static readonly IReadOnlyMethodInfoSet __thenByOverloads; + + static EnumerableOrQueryableMethod() + { + // initialize methods before sets of methods (in this file matching Enumerable and Queryable methods are treated as if they were one method) + __aggregateWithFunc = MethodInfoSet.Create( + [ + EnumerableMethod.AggregateWithFunc, + QueryableMethod.AggregateWithFunc + ]); + + __aggregateWithSeedAndFunc = MethodInfoSet.Create( + [ + EnumerableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedAndFunc + ]); + + __aggregateWithSeedFuncAndResultSelector = MethodInfoSet.Create( + [ + EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithSeedFuncAndResultSelector + ]); + + __all = MethodInfoSet.Create( + [ + EnumerableMethod.All, + QueryableMethod.All + ]); + + __any = MethodInfoSet.Create( + [ + EnumerableMethod.Any, + QueryableMethod.Any, + ]); + + __anyWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.AnyWithPredicate, + QueryableMethod.AnyWithPredicate + ]); + + __append = MethodInfoSet.Create( + [ + EnumerableMethod.Append, + QueryableMethod.Append + ]); + + __concat = MethodInfoSet.Create( + [ + EnumerableMethod.Concat, + QueryableMethod.Concat + ]); + + __count = MethodInfoSet.Create( + [ + EnumerableMethod.Count, + QueryableMethod.Count + ]); + + __countWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.CountWithPredicate, + QueryableMethod.CountWithPredicate + ]); + + __defaultIfEmpty = MethodInfoSet.Create( + [ + EnumerableMethod.DefaultIfEmpty, + QueryableMethod.DefaultIfEmpty + ]); + + __defaultIfEmptyWithDefaultValue = MethodInfoSet.Create( + [ + EnumerableMethod.DefaultIfEmptyWithDefaultValue, + QueryableMethod.DefaultIfEmptyWithDefaultValue, + ]); + + __distinct = MethodInfoSet.Create( + [ + EnumerableMethod.Distinct, + QueryableMethod.Distinct + ]); + + __elementAt = MethodInfoSet.Create( + [ + EnumerableMethod.ElementAt, + QueryableMethod.ElementAt + ]); + + __elementAtOrDefault = MethodInfoSet.Create( + [ + EnumerableMethod.ElementAtOrDefault, + QueryableMethod.ElementAtOrDefault + ]); + + __except = MethodInfoSet.Create( + [ + EnumerableMethod.Except, + QueryableMethod.Except + ]); + + __first = MethodInfoSet.Create( + [ + EnumerableMethod.First, + QueryableMethod.First + ]); + + __firstOrDefault = MethodInfoSet.Create( + [ + EnumerableMethod.FirstOrDefault, + QueryableMethod.FirstOrDefault + ]); + + __firstOrDefaultWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.FirstOrDefaultWithPredicate, + QueryableMethod.FirstOrDefaultWithPredicate + ]); + + __firstWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.FirstWithPredicate, + QueryableMethod.FirstWithPredicate + ]); + + __groupByWithKeySelector = MethodInfoSet.Create( + [ + EnumerableMethod.GroupByWithKeySelector, + QueryableMethod.GroupByWithKeySelector + ]); + + __groupByWithKeySelectorAndElementSelector = MethodInfoSet.Create( + [ + EnumerableMethod.GroupByWithKeySelectorAndElementSelector, + QueryableMethod.GroupByWithKeySelectorAndElementSelector + ]); + + __groupByWithKeySelectorAndResultSelector = MethodInfoSet.Create( + [ + EnumerableMethod.GroupByWithKeySelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorAndResultSelector + ]); + + __groupByWithKeySelectorElementSelectorAndResultSelector = MethodInfoSet.Create( + [ + EnumerableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector, + QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector + ]); + + __intersect = MethodInfoSet.Create( + [ + EnumerableMethod.Intersect, + QueryableMethod.Intersect + ]); + + __last = MethodInfoSet.Create( + [ + EnumerableMethod.Last, + QueryableMethod.Last + ]); + + __lastOrDefault = MethodInfoSet.Create( + [ + EnumerableMethod.LastOrDefault, + QueryableMethod.LastOrDefault + ]); + + __lastOrDefaultWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.LastOrDefaultWithPredicate, + QueryableMethod.LastOrDefaultWithPredicate + ]); + + __lastWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.LastWithPredicate, + QueryableMethod.LastWithPredicate + ]); + + __longCount = MethodInfoSet.Create( + [ + EnumerableMethod.LongCount, + QueryableMethod.LongCount + ]); + + __longCountWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.LongCountWithPredicate, + QueryableMethod.LongCountWithPredicate + ]); + + __ofType = MethodInfoSet.Create( + [ + EnumerableMethod.OfType, + QueryableMethod.OfType + ]); + + __orderBy = MethodInfoSet.Create( + [ + EnumerableMethod.OrderBy, + QueryableMethod.OrderBy + ]); + + __orderByDescending = MethodInfoSet.Create( + [ + EnumerableMethod.OrderByDescending, + QueryableMethod.OrderByDescending + ]); + + __prepend = MethodInfoSet.Create( + [ + EnumerableMethod.Prepend, + QueryableMethod.Prepend + ]); + + __reverse = MethodInfoSet.Create( + [ + EnumerableMethod.Reverse, + QueryableMethod.Reverse + ]); + + __select = MethodInfoSet.Create( + [ + EnumerableMethod.Select, + QueryableMethod.Select + ]); + + __selectManyWithCollectionSelectorAndResultSelector = MethodInfoSet.Create( + [ + EnumerableMethod.SelectManyWithCollectionSelectorAndResultSelector, + QueryableMethod.SelectManyWithCollectionSelectorAndResultSelector + ]); + + __selectManyWithSelector = MethodInfoSet.Create( + [ + EnumerableMethod.SelectManyWithSelector, + QueryableMethod.SelectManyWithSelector + ]); + + __single = MethodInfoSet.Create( + [ + EnumerableMethod.Single, + QueryableMethod.Single + ]); + + __singleOrDefault = MethodInfoSet.Create( + [ + EnumerableMethod.SingleOrDefault, + QueryableMethod.SingleOrDefault + ]); + + __singleOrDefaultWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.SingleOrDefaultWithPredicate, + QueryableMethod.SingleOrDefaultWithPredicate + ]); + + __singleWithPredicate = MethodInfoSet.Create( + [ + EnumerableMethod.SingleWithPredicate, + QueryableMethod.SingleWithPredicate + ]); + + __skip = MethodInfoSet.Create( + [ + EnumerableMethod.Skip, + QueryableMethod.Skip + ]); + + __skipWhile = MethodInfoSet.Create( + [ + EnumerableMethod.SkipWhile, + QueryableMethod.SkipWhile + ]); + + __take = MethodInfoSet.Create( + [ + EnumerableMethod.Take, + QueryableMethod.Take + ]); + + __takeWhile = MethodInfoSet.Create( + [ + EnumerableMethod.TakeWhile, + QueryableMethod.TakeWhile + ]); + + __thenBy = MethodInfoSet.Create( + [ + EnumerableMethod.ThenBy, + QueryableMethod.ThenBy + ]); + + __thenByDescending = MethodInfoSet.Create( + [ + EnumerableMethod.ThenByDescending, + QueryableMethod.ThenByDescending + ]); + + __union = MethodInfoSet.Create( + [ + EnumerableMethod.Union, + QueryableMethod.Union, + ]); + + __where = MethodInfoSet.Create( + [ + EnumerableMethod.Where, + QueryableMethod.Where, + ]); + + __zip = MethodInfoSet.Create( + [ + EnumerableMethod.Zip, + QueryableMethod.Zip + ]); + + // initialize sets of methods after methods + __aggregateOverloads = MethodInfoSet.Create( + [ + __aggregateWithFunc, + __aggregateWithSeedAndFunc, + __aggregateWithSeedFuncAndResultSelector + ]); + + __aggregateWithSeedOverloads = MethodInfoSet.Create( + [ + __aggregateWithSeedAndFunc, + __aggregateWithSeedFuncAndResultSelector + ]); + + __anyOverloads = MethodInfoSet.Create( + [ + __any, + __anyWithPredicate + ]); + + __appendOrPrepend = MethodInfoSet.Create( + [ + __append, + __prepend + ]); + + __averageOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.AverageDecimal, + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDouble, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimal, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDouble, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingle, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingle, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimal, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDouble, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimal, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDouble, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingle, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingle, + QueryableMethod.AverageSingleWithSelector + ]); + + __averageWithSelectorOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.AverageDecimalWithSelector, + EnumerableMethod.AverageDoubleWithSelector, + EnumerableMethod.AverageInt32WithSelector, + EnumerableMethod.AverageInt64WithSelector, + EnumerableMethod.AverageNullableDecimalWithSelector, + EnumerableMethod.AverageNullableDoubleWithSelector, + EnumerableMethod.AverageNullableInt32WithSelector, + EnumerableMethod.AverageNullableInt64WithSelector, + EnumerableMethod.AverageNullableSingleWithSelector, + EnumerableMethod.AverageSingleWithSelector, + QueryableMethod.AverageDecimalWithSelector, + QueryableMethod.AverageDoubleWithSelector, + QueryableMethod.AverageInt32WithSelector, + QueryableMethod.AverageInt64WithSelector, + QueryableMethod.AverageNullableDecimalWithSelector, + QueryableMethod.AverageNullableDoubleWithSelector, + QueryableMethod.AverageNullableInt32WithSelector, + QueryableMethod.AverageNullableInt64WithSelector, + QueryableMethod.AverageNullableSingleWithSelector, + QueryableMethod.AverageSingleWithSelector, + ]); + + __countOverloads = MethodInfoSet.Create( + [ + __count, + __countWithPredicate, + __longCount, // it's conventiont to treat LongCount as if it was an overload of Count + __longCountWithPredicate + ]); + + __countWithPredicateOverloads = MethodInfoSet.Create( + [ + __countWithPredicate, + __longCountWithPredicate // it's conventiont to treat LongCount as if it was an overload of Count + ]); + + __defaultIfEmptyOverloads = MethodInfoSet.Create( + [ + __defaultIfEmpty, + __defaultIfEmptyWithDefaultValue + ]); + + __elementAtOverloads = MethodInfoSet.Create( + [ + __elementAt, + __elementAtOrDefault // it's conventiont to treat ElementAtOrDefault as if it was an overload of ElementAt + ]); + + __firstOverloads = MethodInfoSet.Create( + [ + __first, + __firstOrDefault, // it's convenient to treat FirstOrDefault as if it was an overload + __firstOrDefaultWithPredicate, + __firstWithPredicate + ]); + + __firstOrDefaultOverloads = MethodInfoSet.Create( + [ + __firstOrDefault, + __firstOrDefaultWithPredicate + ]); + + __firstWithPredicateOverloads = MethodInfoSet.Create( + [ + __firstOrDefaultWithPredicate, + __firstWithPredicate + ]); + + __groupByOverloads = MethodInfoSet.Create( + [ + __groupByWithKeySelector, + __groupByWithKeySelectorAndElementSelector, + __groupByWithKeySelectorAndResultSelector, + __groupByWithKeySelectorElementSelectorAndResultSelector + ]); + + __lastOverloads = MethodInfoSet.Create( + [ + __last, + __lastOrDefault, // it's convenient to treat LastOrDefault as if it was an overload + __lastOrDefaultWithPredicate, + __lastWithPredicate + ]); + + __lastOrDefaultOverloads = MethodInfoSet.Create( + [ + __lastOrDefault, + __lastOrDefaultWithPredicate + ]); + + __lastWithPredicateOverloads = MethodInfoSet.Create( + [ + __lastOrDefaultWithPredicate, + __lastWithPredicate + ]); + + __maxOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.Max, + EnumerableMethod.MaxDecimal, + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDouble, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimal, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDouble, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingle, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingle, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + QueryableMethod.Max, + QueryableMethod.MaxWithSelector, + ]); + + __maxWithSelectorOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.MaxDecimalWithSelector, + EnumerableMethod.MaxDoubleWithSelector, + EnumerableMethod.MaxInt32WithSelector, + EnumerableMethod.MaxInt64WithSelector, + EnumerableMethod.MaxNullableDecimalWithSelector, + EnumerableMethod.MaxNullableDoubleWithSelector, + EnumerableMethod.MaxNullableInt32WithSelector, + EnumerableMethod.MaxNullableInt64WithSelector, + EnumerableMethod.MaxNullableSingleWithSelector, + EnumerableMethod.MaxSingleWithSelector, + EnumerableMethod.MaxWithSelector, + QueryableMethod.MaxWithSelector + ]); + + __minOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.Min, + EnumerableMethod.MinDecimal, + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDouble, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimal, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDouble, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingle, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingle, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + QueryableMethod.Min, + QueryableMethod.MinWithSelector, + ]); + + __minWithSelectorOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.MinDecimalWithSelector, + EnumerableMethod.MinDoubleWithSelector, + EnumerableMethod.MinInt32WithSelector, + EnumerableMethod.MinInt64WithSelector, + EnumerableMethod.MinNullableDecimalWithSelector, + EnumerableMethod.MinNullableDoubleWithSelector, + EnumerableMethod.MinNullableInt32WithSelector, + EnumerableMethod.MinNullableInt64WithSelector, + EnumerableMethod.MinNullableSingleWithSelector, + EnumerableMethod.MinSingleWithSelector, + EnumerableMethod.MinWithSelector, + QueryableMethod.MinWithSelector + ]); + + __orderByOverloads = MethodInfoSet.Create( + [ + __orderBy, + __orderByDescending + ]); + + __reverseOverloads = MethodInfoSet.Create( + [ + __reverse, + [EnumerableMethod.ReverseWithArray] + ]); + + __selectManyOverloads = MethodInfoSet.Create( + [ + __selectManyWithSelector, + __selectManyWithCollectionSelectorAndResultSelector + ]); + + __singleOverloads = MethodInfoSet.Create( + [ + __single, + __singleOrDefault, // it's convenient to treat SingleOrDefault as if it was an overload + __singleOrDefaultWithPredicate, + __singleWithPredicate + ]); + + __singleWithPredicateOverloads = MethodInfoSet.Create( + [ + __singleOrDefaultWithPredicate, + __singleWithPredicate + ]); + + __skipOverloads = MethodInfoSet.Create( + [ + __skip, + __skipWhile, // it's convenient to treat SkipWhile as if it was an overload + [MongoQueryableMethod.SkipWithLong] // it's convenient to group our custom Skip method with the EnumerableOrQueryable Skip methods + ]); + + __skipWhileOrTakeWhile = MethodInfoSet.Create( + [ + __skipWhile, + __takeWhile + ]); + + __sumOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.SumDecimal, + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDouble, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimal, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDouble, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingle, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingle, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimal, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDouble, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimal, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDouble, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingle, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingle, + QueryableMethod.SumSingleWithSelector + ]); + + __sumWithSelectorOverloads = MethodInfoSet.Create( + [ + EnumerableMethod.SumDecimalWithSelector, + EnumerableMethod.SumDoubleWithSelector, + EnumerableMethod.SumInt32WithSelector, + EnumerableMethod.SumInt64WithSelector, + EnumerableMethod.SumNullableDecimalWithSelector, + EnumerableMethod.SumNullableDoubleWithSelector, + EnumerableMethod.SumNullableInt32WithSelector, + EnumerableMethod.SumNullableInt64WithSelector, + EnumerableMethod.SumNullableSingleWithSelector, + EnumerableMethod.SumSingleWithSelector, + QueryableMethod.SumDecimalWithSelector, + QueryableMethod.SumDoubleWithSelector, + QueryableMethod.SumInt32WithSelector, + QueryableMethod.SumInt64WithSelector, + QueryableMethod.SumNullableDecimalWithSelector, + QueryableMethod.SumNullableDoubleWithSelector, + QueryableMethod.SumNullableInt32WithSelector, + QueryableMethod.SumNullableInt64WithSelector, + QueryableMethod.SumNullableSingleWithSelector, + QueryableMethod.SumSingleWithSelector, + ]); + + __takeOverloads = MethodInfoSet.Create( + [ + __take, + __takeWhile, // it's convenient to treat TakeWhile as if it was an overload of Take + [MongoQueryableMethod.TakeWithLong] // it's convenient to group our custom Take method with the EnumerableOrQueryable Take methods + ]); + + __thenByOverloads = MethodInfoSet.Create( + [ + __thenBy, + __thenByDescending + ]); + + // initialize sets that depend on other sets last + __firstOrLastOrSingleOverloads = MethodInfoSet.Create( + [ + __firstOverloads, + __lastOverloads, + __singleOverloads + ]); + + __firstOrLastOrSingleWithPredicateOverloads = MethodInfoSet.Create( + [ + __firstWithPredicateOverloads, + __lastWithPredicateOverloads, + __singleWithPredicateOverloads + ]); + + __firstOrLastOverloads = MethodInfoSet.Create( + [ + __firstOverloads, + __lastOverloads + ]); + + __firstOrLastWithPredicateOverloads = MethodInfoSet.Create( + [ + __firstWithPredicateOverloads, + __lastWithPredicateOverloads + ]); + + __maxOrMinOverloads = MethodInfoSet.Create( + [ + __maxOverloads, + __minOverloads + ]); + + __maxOrMinWithSelectorOverloads = MethodInfoSet.Create( + [ + __maxWithSelectorOverloads, + __minWithSelectorOverloads + ]); + + __orderByOrThenByOverloads = MethodInfoSet.Create( + [ + __orderByOverloads, + __thenByOverloads + ]); + + __skipOrTakeOverloads = MethodInfoSet.Create( + [ + __skipOverloads, + __takeOverloads + ]); + } + + // methods + public static IReadOnlyMethodInfoSet AggregateWithFunc => __aggregateWithFunc; + public static IReadOnlyMethodInfoSet AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; + public static IReadOnlyMethodInfoSet AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; + public static IReadOnlyMethodInfoSet All => __all; + public static IReadOnlyMethodInfoSet Any => __any; + public static IReadOnlyMethodInfoSet AnyWithPredicate => __anyWithPredicate; + public static IReadOnlyMethodInfoSet Append => __append; + public static IReadOnlyMethodInfoSet Concat => __concat; + public static IReadOnlyMethodInfoSet Count => __count; + public static IReadOnlyMethodInfoSet CountWithPredicate => __countWithPredicate; + public static IReadOnlyMethodInfoSet DefaultIfEmpty => __defaultIfEmpty; + public static IReadOnlyMethodInfoSet DefaultIfEmptyWithDefaultValue => __defaultIfEmptyWithDefaultValue; + public static IReadOnlyMethodInfoSet Distinct => __distinct; + public static IReadOnlyMethodInfoSet ElementAt => __elementAt; + public static IReadOnlyMethodInfoSet ElementAtOrDefault => __elementAtOrDefault; + public static IReadOnlyMethodInfoSet Except => __except; + public static IReadOnlyMethodInfoSet First => __first; + public static IReadOnlyMethodInfoSet FirstOrDefault => __firstOrDefault; + public static IReadOnlyMethodInfoSet FirstOrDefaultWithPredicate => __firstOrDefaultWithPredicate; + public static IReadOnlyMethodInfoSet FirstWithPredicate => __firstWithPredicate; + public static IReadOnlyMethodInfoSet GroupByWithKeySelector => __groupByWithKeySelector; + public static IReadOnlyMethodInfoSet GroupByWithKeySelectorAndElementSelector => __groupByWithKeySelectorAndElementSelector; + public static IReadOnlyMethodInfoSet GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; + public static IReadOnlyMethodInfoSet GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; + public static IReadOnlyMethodInfoSet Intersect => __intersect; + public static IReadOnlyMethodInfoSet Last => __last; + public static IReadOnlyMethodInfoSet LastOrDefault => __lastOrDefault; + public static IReadOnlyMethodInfoSet LastOrDefaultWithPredicate => __lastOrDefaultWithPredicate; + public static IReadOnlyMethodInfoSet LastWithPredicate => __lastWithPredicate; + public static IReadOnlyMethodInfoSet LongCount => __longCount; + public static IReadOnlyMethodInfoSet LongCountWithPredicate => __longCountWithPredicate; + public static IReadOnlyMethodInfoSet OfType => __ofType; + public static IReadOnlyMethodInfoSet Reverse => __reverse; + public static IReadOnlyMethodInfoSet Select => __select; + public static IReadOnlyMethodInfoSet SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; + public static IReadOnlyMethodInfoSet SelectManyWithSelector => __selectManyWithSelector; + public static IReadOnlyMethodInfoSet Single => __single; + public static IReadOnlyMethodInfoSet SingleOrDefault => __singleOrDefault; + public static IReadOnlyMethodInfoSet SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate; + public static IReadOnlyMethodInfoSet SingleWithPredicate => __singleWithPredicate; + public static IReadOnlyMethodInfoSet Skip => __skip; + public static IReadOnlyMethodInfoSet SkipWhile => __skipWhile; + public static IReadOnlyMethodInfoSet Take => __take; + public static IReadOnlyMethodInfoSet TakeWhile => __takeWhile; + public static IReadOnlyMethodInfoSet Union => __union; + public static IReadOnlyMethodInfoSet Where => __where; + public static IReadOnlyMethodInfoSet Zip => __zip; + + // sets of methods + public static IReadOnlyMethodInfoSet AggregateOverloads => __aggregateOverloads; + public static IReadOnlyMethodInfoSet AggregateWithSeedOverloads => __aggregateWithSeedOverloads; + public static IReadOnlyMethodInfoSet AnyOverloads => __anyOverloads; + public static IReadOnlyMethodInfoSet AppendOrPrepend => __appendOrPrepend; + public static IReadOnlyMethodInfoSet AverageOverloads => __averageOverloads; + public static IReadOnlyMethodInfoSet AverageWithSelectorOverloads => __averageWithSelectorOverloads; + public static IReadOnlyMethodInfoSet CountOverloads => __countOverloads; + public static IReadOnlyMethodInfoSet CountWithPredicateOverloads => __countWithPredicateOverloads; + public static IReadOnlyMethodInfoSet DefaultIfEmptyOverloads => __defaultIfEmptyOverloads; + public static IReadOnlyMethodInfoSet ElementAtOverloads => __elementAtOverloads; + public static IReadOnlyMethodInfoSet FirstOverloads => __firstOverloads; + public static IReadOnlyMethodInfoSet FirstOrDefaultOverloads => __firstOrDefaultOverloads; + public static IReadOnlyMethodInfoSet FirstWithPredicateOverloads => __firstWithPredicateOverloads; + public static IReadOnlyMethodInfoSet GroupByOverloads => __groupByOverloads; + public static IReadOnlyMethodInfoSet FirstOrLastOverloads => __firstOrLastOverloads; + public static IReadOnlyMethodInfoSet FirstOrLastWithPredicateOverloads => __firstOrLastWithPredicateOverloads; + public static IReadOnlyMethodInfoSet FirstOrLastOrSingleOverloads => __firstOrLastOrSingleOverloads; + public static IReadOnlyMethodInfoSet FirstOrLastOrSingleWithPredicateOverloads => __firstOrLastOrSingleWithPredicateOverloads; + public static IReadOnlyMethodInfoSet LastOverloads => __lastOverloads; + public static IReadOnlyMethodInfoSet LastOrDefaultOverloads => __lastOrDefaultOverloads; + public static IReadOnlyMethodInfoSet LastWithPredicateOverloads => __lastWithPredicateOverloads; + public static IReadOnlyMethodInfoSet MaxOverloads => __maxOverloads; + public static IReadOnlyMethodInfoSet MaxWithSelectorOverloads => __maxWithSelectorOverloads; + public static IReadOnlyMethodInfoSet MaxOrMinOverloads => __maxOrMinOverloads; + public static IReadOnlyMethodInfoSet MaxOrMinWithSelectorOverloads => __maxOrMinWithSelectorOverloads; + public static IReadOnlyMethodInfoSet MinOverloads => __minOverloads; + public static IReadOnlyMethodInfoSet MinWithSelectorOverloads => __minWithSelectorOverloads; + public static IReadOnlyMethodInfoSet OrderByOrThenByOverloads => __orderByOrThenByOverloads; + public static IReadOnlyMethodInfoSet OrderByOverloads => __orderByOverloads; + public static IReadOnlyMethodInfoSet ReverseOverloads => __reverseOverloads; + public static IReadOnlyMethodInfoSet SelectManyOverloads => __selectManyOverloads; + public static IReadOnlyMethodInfoSet SingleOverloads => __singleOverloads; + public static IReadOnlyMethodInfoSet SingleWithPredicateOverloads => __singleWithPredicateOverloads; + public static IReadOnlyMethodInfoSet SkipOrTakeOverloads => __skipOrTakeOverloads; + public static IReadOnlyMethodInfoSet SkipOverloads => __skipOverloads; + public static IReadOnlyMethodInfoSet SkipWhileOrTakeWhile => __skipWhileOrTakeWhile; + public static IReadOnlyMethodInfoSet SumOverloads => __sumOverloads; + public static IReadOnlyMethodInfoSet SumWithSelectorOverloads => __sumWithSelectorOverloads; + public static IReadOnlyMethodInfoSet TakeOverloads => __takeOverloads; + public static IReadOnlyMethodInfoSet ThenByOverloads => __thenByOverloads; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs index 6e929c18d3c..845b0702a23 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableProperty.cs @@ -38,9 +38,9 @@ expression.Member is PropertyInfo propertyInfo && static bool ImplementsCollectionInterface(Type type) => - type.Implements(typeof(ICollection)) || - type.Implements(typeof(ICollection<>)) || - type.Implements(typeof(IReadOnlyCollection<>)); + type.ImplementsInterface(typeof(ICollection)) || + type.ImplementsInterface(typeof(ICollection<>)) || + type.ImplementsInterface(typeof(IReadOnlyCollection<>)); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs new file mode 100644 index 00000000000..3b062091f53 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/HashSetConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class HashSetConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(HashSet<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs new file mode 100644 index 00000000000..38898ffd9ef --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ISetMethod.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class ISetMethod + { + // public static methods + public static bool IsSetEqualsMethod(MethodInfo method) + { + // many types implement a SetEquals method but the declaringType should always implement ISet + var declaringType = method.DeclaringType; + return + declaringType.ImplementsISet(out var itemType) && + method.IsPublic && + !method.IsStatic && + method.ReturnType == typeof(bool) && + method.Name == "SetEquals" && + method.GetParameters() is var parameters && + parameters.Length == 1 && + parameters[0] is var otherParameter && + otherParameter.ParameterType == typeof(IEnumerable<>).MakeGenericType(itemType); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs new file mode 100644 index 00000000000..5ffa9e126f2 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/KeyValuePairConstructor.cs @@ -0,0 +1,44 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class KeyValuePairConstructor + { + public static bool IsWithKeyAndValueConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) && + declaringType.GetGenericArguments() is var typeParameters && + typeParameters[0] is var keyType && + typeParameters[1] is var valueType && + parameters.Length == 2 && + parameters[0].ParameterType == keyType && + parameters[1].ParameterType == valueType; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs new file mode 100644 index 00000000000..21c731c7ceb --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ListConstructor.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class ListConstructor + { + public static bool IsWithCollectionConstructor(ConstructorInfo constructor) + { + if (constructor != null) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(List<>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var itemType) && + itemType == declaringType.GenericTypeArguments[0]; + } + + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs index 28623c07a47..3d83130e1fd 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs @@ -58,9 +58,23 @@ internal static class MathMethod private static readonly MethodInfo __truncateDecimal; private static readonly MethodInfo __truncateDouble; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __absOverloads; + private static readonly IReadOnlyMethodInfoSet __logOverloads; + private static readonly IReadOnlyMethodInfoSet __roundOverloads; + private static readonly IReadOnlyMethodInfoSet __roundWithPlaceOverloads; + private static readonly IReadOnlyMethodInfoSet __trigonometricMethods; + // static constructor static MathMethod() { + // initialize methods before sets of methods +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + __acosh = ReflectionInfo.Method((double d) => Math.Acosh(d)); + __asinh = ReflectionInfo.Method((double d) => Math.Asinh(d)); + __atanh = ReflectionInfo.Method((double d) => Math.Atanh(d)); +#endif + __absDecimal = ReflectionInfo.Method((decimal value) => Math.Abs(value)); __absDouble = ReflectionInfo.Method((double value) => Math.Abs(value)); __absInt16 = ReflectionInfo.Method((short value) => Math.Abs(value)); @@ -69,18 +83,9 @@ static MathMethod() __absSByte = ReflectionInfo.Method((sbyte value) => Math.Abs(value)); __absSingle = ReflectionInfo.Method((float value) => Math.Abs(value)); __acos = ReflectionInfo.Method((double d) => Math.Acos(d)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __acosh = ReflectionInfo.Method((double d) => Math.Acosh(d)); -#endif __asin = ReflectionInfo.Method((double d) => Math.Asin(d)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __asinh = ReflectionInfo.Method((double d) => Math.Asinh(d)); -#endif __atan = ReflectionInfo.Method((double d) => Math.Atan(d)); __atan2 = ReflectionInfo.Method((double x, double y) => Math.Atan2(x, y)); -#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER - __atanh = ReflectionInfo.Method((double d) => Math.Atanh(d)); -#endif __ceilingWithDecimal = ReflectionInfo.Method((decimal d) => Math.Ceiling(d)); __ceilingWithDouble = ReflectionInfo.Method((double a) => Math.Ceiling(a)); __cos = ReflectionInfo.Method((double d) => Math.Cos(d)); @@ -103,6 +108,56 @@ static MathMethod() __tanh = ReflectionInfo.Method((double a) => Math.Tanh(a)); __truncateDecimal = ReflectionInfo.Method((decimal d) => Math.Truncate(d)); __truncateDouble = ReflectionInfo.Method((double d) => Math.Truncate(d)); + + // initialize sets of methods after methods + __absOverloads = MethodInfoSet.Create( + [ + __absDecimal, + __absDouble, + __absInt16, + __absInt32, + __absInt64, + __absSByte, + __absSingle + ]); + + __logOverloads = MethodInfoSet.Create( + [ + __log, + __log10, // it's convenient to treat Log10 as if it was an overload + __logWithNewBase + ]); + + __roundOverloads = MethodInfoSet.Create( + [ + __roundWithDecimal, + __roundWithDecimalAndDecimals, + __roundWithDouble, + __roundWithDoubleAndDigits + ]); + + __roundWithPlaceOverloads = MethodInfoSet.Create( + [ + __roundWithDecimalAndDecimals, + __roundWithDoubleAndDigits + ]); + + __trigonometricMethods = MethodInfoSet.Create( + [ + __acos, + __acosh, + __asin, + __asinh, + __atan, + __atanh, + __atan2, + __cos, + __cosh, + __sin, + __sinh, + __tan, + __tanh + ]); } // public properties @@ -142,5 +197,12 @@ static MathMethod() public static MethodInfo Tanh => __tanh; public static MethodInfo TruncateDecimal => __truncateDecimal; public static MethodInfo TruncateDouble => __truncateDouble; + + // sets of methods + public static IReadOnlyMethodInfoSet AbsOverloads => __absOverloads; + public static IReadOnlyMethodInfoSet LogOverloads => __logOverloads; + public static IReadOnlyMethodInfoSet RoundOverloads => __roundOverloads; + public static IReadOnlyMethodInfoSet RoundWithPlaceOverloads => __roundWithPlaceOverloads; + public static IReadOnlyMethodInfoSet TrigonometricMethods => __trigonometricMethods; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MethodInfoSet.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MethodInfoSet.cs new file mode 100644 index 00000000000..143eb85cf1c --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MethodInfoSet.cs @@ -0,0 +1,97 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal interface IReadOnlyMethodInfoSet : IEnumerable +{ + bool Contains(MethodInfo method); +} + +internal abstract class MethodInfoSet : IReadOnlyMethodInfoSet +{ + public static IReadOnlyMethodInfoSet Create(IEnumerable methods) + { + var hashSet = new HashSet(); + hashSet.UnionWith(methods.Where(m => m != null)); + return Create(hashSet); + } + + public static IReadOnlyMethodInfoSet Create(IEnumerable> methodSets) + { + var hashSet = new HashSet(); + + foreach (var methodSet in methodSets) + { + hashSet.UnionWith(methodSet.Where(m => m != null)); + } + + return Create(hashSet); + } + + private static IReadOnlyMethodInfoSet Create(HashSet hashSet) + { + return hashSet.Count <= 4 ? new ArrayBasedMethodInfoSet(hashSet.ToArray()) : new HashSetBasedMethodInfoSet(hashSet); + } + + public abstract bool Contains(MethodInfo method); + + public abstract IEnumerator GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} + +internal sealed class ArrayBasedMethodInfoSet : MethodInfoSet +{ + private readonly MethodInfo[] _methods; + + public ArrayBasedMethodInfoSet(MethodInfo[] methods) + { + _methods = methods; + } + + public override bool Contains(MethodInfo method) + { + return method.IsGenericMethod && !method.ContainsGenericParameters ? + _methods.Contains(method.GetGenericMethodDefinition()) : + _methods.Contains(method); + } + + public override IEnumerator GetEnumerator() => ((IEnumerable)_methods).GetEnumerator(); +} + +internal sealed class HashSetBasedMethodInfoSet : MethodInfoSet +{ + private readonly HashSet _methods; + + public HashSetBasedMethodInfoSet(HashSet methods) + { + _methods = new HashSet(methods); + } + + public override bool Contains(MethodInfo method) + { + return method.IsGenericMethod && !method.ContainsGenericParameters ? + _methods.Contains(method.GetGenericMethodDefinition()) : + _methods.Contains(method); + } + + public override IEnumerator GetEnumerator() => _methods.GetEnumerator(); +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs index c10550024c3..f8ad54596a6 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoEnumerableMethod.cs @@ -65,11 +65,60 @@ internal static class MongoEnumerableMethod private static readonly MethodInfo __percentileNullableSingleWithSelector; private static readonly MethodInfo __percentileSingle; private static readonly MethodInfo __percentileSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDecimal; + private static readonly MethodInfo __standardDeviationPopulationDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationDouble; + private static readonly MethodInfo __standardDeviationPopulationDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt32; + private static readonly MethodInfo __standardDeviationPopulationInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationInt64; + private static readonly MethodInfo __standardDeviationPopulationInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimal; + private static readonly MethodInfo __standardDeviationPopulationNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableDouble; + private static readonly MethodInfo __standardDeviationPopulationNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32; + private static readonly MethodInfo __standardDeviationPopulationNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64; + private static readonly MethodInfo __standardDeviationPopulationNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationPopulationNullableSingle; + private static readonly MethodInfo __standardDeviationPopulationNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationPopulationSingle; + private static readonly MethodInfo __standardDeviationPopulationSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleDecimal; + private static readonly MethodInfo __standardDeviationSampleDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleDouble; + private static readonly MethodInfo __standardDeviationSampleDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleInt32; + private static readonly MethodInfo __standardDeviationSampleInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleInt64; + private static readonly MethodInfo __standardDeviationSampleInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDecimal; + private static readonly MethodInfo __standardDeviationSampleNullableDecimalWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableDouble; + private static readonly MethodInfo __standardDeviationSampleNullableDoubleWithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt32; + private static readonly MethodInfo __standardDeviationSampleNullableInt32WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableInt64; + private static readonly MethodInfo __standardDeviationSampleNullableInt64WithSelector; + private static readonly MethodInfo __standardDeviationSampleNullableSingle; + private static readonly MethodInfo __standardDeviationSampleNullableSingleWithSelector; + private static readonly MethodInfo __standardDeviationSampleSingle; + private static readonly MethodInfo __standardDeviationSampleSingleWithSelector; private static readonly MethodInfo __whereWithLimit; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __medianOverloads; + private static readonly IReadOnlyMethodInfoSet __medianWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __percentileOverloads; + private static readonly IReadOnlyMethodInfoSet __percentileWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __standardDeviationOverloads; + private static readonly IReadOnlyMethodInfoSet __standardDeviationWithSelectorOverloads; + // static constructor static MongoEnumerableMethod() { + // initialize methods before sets of methods __allElements = ReflectionInfo.Method((IEnumerable source) => source.AllElements()); __allMatchingElements = ReflectionInfo.Method((IEnumerable source, string identifier) => source.AllMatchingElements(identifier)); __firstMatchingElement = ReflectionInfo.Method((IEnumerable source) => source.FirstMatchingElement()); @@ -113,7 +162,192 @@ static MongoEnumerableMethod() __percentileNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); __percentileSingle = ReflectionInfo.Method((IEnumerable source, IEnumerable percentiles) => source.Percentile(percentiles)); __percentileSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector, IEnumerable percentiles) => source.Percentile(selector, percentiles)); + __standardDeviationPopulationDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationPopulationSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationPopulation()); + __standardDeviationPopulationSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationPopulation(selector)); + __standardDeviationSampleDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDecimal = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDecimalWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableDouble = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableDoubleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt32 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt32WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableInt64 = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableInt64WithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleNullableSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleNullableSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); + __standardDeviationSampleSingle = ReflectionInfo.Method((IEnumerable source) => source.StandardDeviationSample()); + __standardDeviationSampleSingleWithSelector = ReflectionInfo.Method((IEnumerable source, Func selector) => source.StandardDeviationSample(selector)); __whereWithLimit = ReflectionInfo.Method((IEnumerable source, Func predicate, int limit) => source.Where(predicate, limit)); + + // initialize sets of methods after methods + __medianOverloads = MethodInfoSet.Create( + [ + __medianDecimal, + __medianDecimalWithSelector, + __medianDouble, + __medianDoubleWithSelector, + __medianInt32, + __medianInt32WithSelector, + __medianInt64, + __medianInt64WithSelector, + __medianNullableDecimal, + __medianNullableDecimalWithSelector, + __medianNullableDouble, + __medianNullableDoubleWithSelector, + __medianNullableInt32, + __medianNullableInt32WithSelector, + __medianNullableInt64, + __medianNullableInt64WithSelector, + __medianNullableSingle, + __medianNullableSingleWithSelector, + __medianSingle, + __medianSingleWithSelector + ]); + + __medianWithSelectorOverloads = MethodInfoSet.Create( + [ + __medianDecimalWithSelector, + __medianDoubleWithSelector, + __medianInt32WithSelector, + __medianInt64WithSelector, + __medianNullableDecimalWithSelector, + __medianNullableDoubleWithSelector, + __medianNullableInt32WithSelector, + __medianNullableInt64WithSelector, + __medianNullableSingleWithSelector, + __medianSingleWithSelector + ]); + + __percentileOverloads = MethodInfoSet.Create( + [ + __percentileDecimal, + __percentileDecimalWithSelector, + __percentileDouble, + __percentileDoubleWithSelector, + __percentileInt32, + __percentileInt32WithSelector, + __percentileInt64, + __percentileInt64WithSelector, + __percentileNullableDecimal, + __percentileNullableDecimalWithSelector, + __percentileNullableDouble, + __percentileNullableDoubleWithSelector, + __percentileNullableInt32, + __percentileNullableInt32WithSelector, + __percentileNullableInt64, + __percentileNullableInt64WithSelector, + __percentileNullableSingle, + __percentileNullableSingleWithSelector, + __percentileSingle, + __percentileSingleWithSelector + ]); + + __percentileWithSelectorOverloads = MethodInfoSet.Create( + [ + __percentileDecimalWithSelector, + __percentileDoubleWithSelector, + __percentileInt32WithSelector, + __percentileInt64WithSelector, + __percentileNullableDecimalWithSelector, + __percentileNullableDoubleWithSelector, + __percentileNullableInt32WithSelector, + __percentileNullableInt64WithSelector, + __percentileNullableSingleWithSelector, + __percentileSingleWithSelector + ]); + + __standardDeviationOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationDecimal, + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDouble, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationInt32, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt64, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationNullableDecimal, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDouble, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableInt32, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt64, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableSingle, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationSingle, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationSampleDecimal, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDouble, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleInt32, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt64, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleNullableDecimal, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDouble, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableInt32, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt64, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableSingle, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleSingle, + __standardDeviationSampleSingleWithSelector, + ]); + + __standardDeviationWithSelectorOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleSingleWithSelector, + ]); } // public properties @@ -160,6 +394,54 @@ static MongoEnumerableMethod() public static MethodInfo PercentileNullableSingleWithSelector => __percentileNullableSingleWithSelector; public static MethodInfo PercentileSingle => __percentileSingle; public static MethodInfo PercentileSingleWithSelector => __percentileSingleWithSelector; + public static MethodInfo StandardDeviationPopulationDecimal => __standardDeviationPopulationDecimal; + public static MethodInfo StandardDeviationPopulationDecimalWithSelector => __standardDeviationPopulationDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationDouble => __standardDeviationPopulationDouble; + public static MethodInfo StandardDeviationPopulationDoubleWithSelector => __standardDeviationPopulationDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationInt32 => __standardDeviationPopulationInt32; + public static MethodInfo StandardDeviationPopulationInt32WithSelector => __standardDeviationPopulationInt32WithSelector; + public static MethodInfo StandardDeviationPopulationInt64 => __standardDeviationPopulationInt64; + public static MethodInfo StandardDeviationPopulationInt64WithSelector => __standardDeviationPopulationInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableDecimal => __standardDeviationPopulationNullableDecimal; + public static MethodInfo StandardDeviationPopulationNullableDecimalWithSelector => __standardDeviationPopulationNullableDecimalWithSelector; + public static MethodInfo StandardDeviationPopulationNullableDouble => __standardDeviationPopulationNullableDouble; + public static MethodInfo StandardDeviationPopulationNullableDoubleWithSelector => __standardDeviationPopulationNullableDoubleWithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt32 => __standardDeviationPopulationNullableInt32; + public static MethodInfo StandardDeviationPopulationNullableInt32WithSelector => __standardDeviationPopulationNullableInt32WithSelector; + public static MethodInfo StandardDeviationPopulationNullableInt64 => __standardDeviationPopulationNullableInt64; + public static MethodInfo StandardDeviationPopulationNullableInt64WithSelector => __standardDeviationPopulationNullableInt64WithSelector; + public static MethodInfo StandardDeviationPopulationNullableSingle => __standardDeviationPopulationNullableSingle; + public static MethodInfo StandardDeviationPopulationNullableSingleWithSelector => __standardDeviationPopulationNullableSingleWithSelector; + public static MethodInfo StandardDeviationPopulationSingle => __standardDeviationPopulationSingle; + public static MethodInfo StandardDeviationPopulationSingleWithSelector => __standardDeviationPopulationSingleWithSelector; + public static MethodInfo StandardDeviationSampleDecimal => __standardDeviationSampleDecimal; + public static MethodInfo StandardDeviationSampleDecimalWithSelector => __standardDeviationSampleDecimalWithSelector; + public static MethodInfo StandardDeviationSampleDouble => __standardDeviationSampleDouble; + public static MethodInfo StandardDeviationSampleDoubleWithSelector => __standardDeviationSampleDoubleWithSelector; + public static MethodInfo StandardDeviationSampleInt32 => __standardDeviationSampleInt32; + public static MethodInfo StandardDeviationSampleInt32WithSelector => __standardDeviationSampleInt32WithSelector; + public static MethodInfo StandardDeviationSampleInt64 => __standardDeviationSampleInt64; + public static MethodInfo StandardDeviationSampleInt64WithSelector => __standardDeviationSampleInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableDecimal => __standardDeviationSampleNullableDecimal; + public static MethodInfo StandardDeviationSampleNullableDecimalWithSelector => __standardDeviationSampleNullableDecimalWithSelector; + public static MethodInfo StandardDeviationSampleNullableDouble => __standardDeviationSampleNullableDouble; + public static MethodInfo StandardDeviationSampleNullableDoubleWithSelector => __standardDeviationSampleNullableDoubleWithSelector; + public static MethodInfo StandardDeviationSampleNullableInt32 => __standardDeviationSampleNullableInt32; + public static MethodInfo StandardDeviationSampleNullableInt32WithSelector => __standardDeviationSampleNullableInt32WithSelector; + public static MethodInfo StandardDeviationSampleNullableInt64 => __standardDeviationSampleNullableInt64; + public static MethodInfo StandardDeviationSampleNullableInt64WithSelector => __standardDeviationSampleNullableInt64WithSelector; + public static MethodInfo StandardDeviationSampleNullableSingle => __standardDeviationSampleNullableSingle; + public static MethodInfo StandardDeviationSampleNullableSingleWithSelector => __standardDeviationSampleNullableSingleWithSelector; + public static MethodInfo StandardDeviationSampleSingle => __standardDeviationSampleSingle; + public static MethodInfo StandardDeviationSampleSingleWithSelector => __standardDeviationSampleSingleWithSelector; public static MethodInfo WhereWithLimit => __whereWithLimit; + + // sets of methods + public static IReadOnlyMethodInfoSet MedianOverloads => __medianOverloads; + public static IReadOnlyMethodInfoSet MedianWithSelectorOverloads => __medianWithSelectorOverloads; + public static IReadOnlyMethodInfoSet PercentileOverloads => __percentileOverloads; + public static IReadOnlyMethodInfoSet PercentileWithSelectorOverloads => __percentileWithSelectorOverloads; + public static IReadOnlyMethodInfoSet StandardDeviationOverloads => __standardDeviationOverloads; + public static IReadOnlyMethodInfoSet StandardDeviationWithSelectorOverloads => __standardDeviationWithSelectorOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs index 245c04c6733..1cfede97f61 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs @@ -179,9 +179,36 @@ internal static class MongoQueryableMethod private static readonly MethodInfo __sumSingleWithSelectorAsync; private static readonly MethodInfo __takeWithLong; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __averageOverloads; + private static readonly IReadOnlyMethodInfoSet __averageWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __countOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOverloads; + private static readonly IReadOnlyMethodInfoSet __firstWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __longCountOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupWithDocumentsOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupWithDocumentsAndPipelineOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupWithFromOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupWithFromAndPipelineOverloads; + private static readonly IReadOnlyMethodInfoSet __lookupWithLocalFieldAndForeignFieldOverloads; + private static readonly IReadOnlyMethodInfoSet __maxOverloads; + private static readonly IReadOnlyMethodInfoSet __minOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOrDefaultOverloads; + private static readonly IReadOnlyMethodInfoSet __singleWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __skipOrTakeWithLong; + private static readonly IReadOnlyMethodInfoSet __standardDeviationOverloads; + private static readonly IReadOnlyMethodInfoSet __standardDeviationNullableOverloads; + private static readonly IReadOnlyMethodInfoSet __standardDeviationPopulationOverloads; + private static readonly IReadOnlyMethodInfoSet __standardDeviationWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __sumOverloads; + private static readonly IReadOnlyMethodInfoSet __sumWithSelectorOverloads; + // static constructor static MongoQueryableMethod() { + // initialize methods before sets of methods __anyAsync = ReflectionInfo.Method((IQueryable source, CancellationToken cancellationToken) => source.AnyAsync(cancellationToken)); __anyWithPredicateAsync = ReflectionInfo.Method((IQueryable source, Expression> predicate, CancellationToken cancellationToken) => source.AnyAsync(predicate, cancellationToken)); __appendStage = ReflectionInfo.Method((IQueryable source, PipelineStageDefinition stage, IBsonSerializer resultSerializer) => source.AppendStage(stage, resultSerializer)); @@ -334,7 +361,408 @@ static MongoQueryableMethod() __sumSingleAsync = ReflectionInfo.Method((IQueryable source, CancellationToken cancellationToken) => source.SumAsync(cancellationToken)); __sumSingleWithSelectorAsync = ReflectionInfo.Method((IQueryable source, Expression> selector, CancellationToken cancellationToken) => source.SumAsync(selector, cancellationToken)); __takeWithLong = ReflectionInfo.Method((IQueryable source, long count) => source.Take(count)); - } + + // initialize sets of methods after methods + __averageOverloads = MethodInfoSet.Create( + [ + __averageDecimalAsync, + __averageDecimalWithSelectorAsync, + __averageDoubleAsync, + __averageDoubleWithSelectorAsync, + __averageInt32Async, + __averageInt32WithSelectorAsync, + __averageInt64Async, + __averageInt64WithSelectorAsync, + __averageNullableDecimalAsync, + __averageNullableDecimalWithSelectorAsync, + __averageNullableDoubleAsync, + __averageNullableDoubleWithSelectorAsync, + __averageNullableInt32Async, + __averageNullableInt32WithSelectorAsync, + __averageNullableInt64Async, + __averageNullableInt64WithSelectorAsync, + __averageNullableSingleAsync, + __averageNullableSingleWithSelectorAsync, + __averageSingleAsync, + __averageSingleWithSelectorAsync + ]); + + __averageWithSelectorOverloads = MethodInfoSet.Create( + [ + __averageDecimalWithSelectorAsync, + __averageDoubleWithSelectorAsync, + __averageInt32WithSelectorAsync, + __averageInt64WithSelectorAsync, + __averageNullableDecimalWithSelectorAsync, + __averageNullableDoubleWithSelectorAsync, + __averageNullableInt32WithSelectorAsync, + __averageNullableInt64WithSelectorAsync, + __averageNullableSingleWithSelectorAsync, + __averageSingleWithSelectorAsync + ]); + + __countOverloads = MethodInfoSet.Create( + [ + __countAsync, + __countWithPredicateAsync + ]); + + __firstOverloads = MethodInfoSet.Create( + [ + __firstAsync, + __firstOrDefaultAsync, + __firstOrDefaultWithPredicateAsync, + __firstWithPredicateAsync + ]); + + __firstWithPredicateOverloads = MethodInfoSet.Create( + [ + __firstOrDefaultWithPredicateAsync, + __firstWithPredicateAsync + ]); + + __longCountOverloads = MethodInfoSet.Create( + [ + __longCountAsync, + __longCountWithPredicateAsync + ]); + + __lookupOverloads = MethodInfoSet.Create( + [ + __lookupWithDocumentsAndLocalFieldAndForeignField, + __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithDocumentsAndPipeline, + __lookupWithFromAndLocalFieldAndForeignField, + __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithFromAndPipeline + ]); + + __lookupWithDocumentsOverloads = MethodInfoSet.Create( + [ + __lookupWithDocumentsAndLocalFieldAndForeignField, + __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithDocumentsAndPipeline + ]); + + __lookupWithDocumentsAndPipelineOverloads = MethodInfoSet.Create( + [ + __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithDocumentsAndPipeline + ]); + + __lookupWithFromOverloads = MethodInfoSet.Create( + [ + __lookupWithFromAndLocalFieldAndForeignField, + __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithFromAndPipeline + ]); + + __lookupWithFromAndPipelineOverloads = MethodInfoSet.Create( + [ + __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithFromAndPipeline + ]); + + __lookupWithLocalFieldAndForeignFieldOverloads = MethodInfoSet.Create( + [ + __lookupWithDocumentsAndLocalFieldAndForeignField, + __lookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, + __lookupWithFromAndLocalFieldAndForeignField, + __lookupWithFromAndLocalFieldAndForeignFieldAndPipeline + ]); + + __maxOverloads = MethodInfoSet.Create( + [ + __maxAsync, + __maxWithSelectorAsync + ]); + + __minOverloads = MethodInfoSet.Create( + [ + __minAsync, + __minWithSelectorAsync + ]); + + __singleOverloads = MethodInfoSet.Create( + [ + __singleAsync, + __singleOrDefaultAsync, + __singleOrDefaultWithPredicateAsync, + __singleWithPredicateAsync + ]); + + __singleOrDefaultOverloads = MethodInfoSet.Create( + [ + __singleOrDefaultAsync, + __singleOrDefaultWithPredicateAsync + ]); + + __singleWithPredicateOverloads = MethodInfoSet.Create( + [ + __singleOrDefaultWithPredicateAsync, + __singleWithPredicateAsync + ]); + + __skipOrTakeWithLong = MethodInfoSet.Create( + [ + __skipWithLong, + __takeWithLong + ]); + + __standardDeviationOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationDecimal, + __standardDeviationPopulationDecimalAsync, + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDecimalWithSelectorAsync, + __standardDeviationPopulationDouble, + __standardDeviationPopulationDoubleAsync, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationDoubleWithSelectorAsync, + __standardDeviationPopulationInt32, + __standardDeviationPopulationInt32Async, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt32WithSelectorAsync, + __standardDeviationPopulationInt64, + __standardDeviationPopulationInt64Async, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationInt64WithSelectorAsync, + __standardDeviationPopulationNullableDecimal, + __standardDeviationPopulationNullableDecimalAsync, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDecimalWithSelectorAsync, + __standardDeviationPopulationNullableDouble, + __standardDeviationPopulationNullableDoubleAsync, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableDoubleWithSelectorAsync, + __standardDeviationPopulationNullableInt32, + __standardDeviationPopulationNullableInt32Async, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt32WithSelectorAsync, + __standardDeviationPopulationNullableInt64, + __standardDeviationPopulationNullableInt64Async, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableInt64WithSelectorAsync, + __standardDeviationPopulationNullableSingle, + __standardDeviationPopulationNullableSingleAsync, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationNullableSingleWithSelectorAsync, + __standardDeviationPopulationSingle, + __standardDeviationPopulationSingleAsync, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationPopulationSingleWithSelectorAsync, + __standardDeviationSampleDecimal, + __standardDeviationSampleDecimalAsync, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDecimalWithSelectorAsync, + __standardDeviationSampleDouble, + __standardDeviationSampleDoubleAsync, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleDoubleWithSelectorAsync, + __standardDeviationSampleInt32, + __standardDeviationSampleInt32Async, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt32WithSelectorAsync, + __standardDeviationSampleInt64, + __standardDeviationSampleInt64Async, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleInt64WithSelectorAsync, + __standardDeviationSampleNullableDecimal, + __standardDeviationSampleNullableDecimalAsync, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDecimalWithSelectorAsync, + __standardDeviationSampleNullableDouble, + __standardDeviationSampleNullableDoubleAsync, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableDoubleWithSelectorAsync, + __standardDeviationSampleNullableInt32, + __standardDeviationSampleNullableInt32Async, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt32WithSelectorAsync, + __standardDeviationSampleNullableInt64, + __standardDeviationSampleNullableInt64Async, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableInt64WithSelectorAsync, + __standardDeviationSampleNullableSingle, + __standardDeviationSampleNullableSingleAsync, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleNullableSingleWithSelectorAsync, + __standardDeviationSampleSingle, + __standardDeviationSampleSingleAsync, + __standardDeviationSampleSingleWithSelector, + __standardDeviationSampleSingleWithSelectorAsync + ]); + + __standardDeviationNullableOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationNullableDecimal, + __standardDeviationPopulationNullableDecimalAsync, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDecimalWithSelectorAsync, + __standardDeviationPopulationNullableDouble, + __standardDeviationPopulationNullableDoubleAsync, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableDoubleWithSelectorAsync, + __standardDeviationPopulationNullableInt32, + __standardDeviationPopulationNullableInt32Async, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt32WithSelectorAsync, + __standardDeviationPopulationNullableInt64, + __standardDeviationPopulationNullableInt64Async, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableInt64WithSelectorAsync, + __standardDeviationPopulationNullableSingle, + __standardDeviationPopulationNullableSingleAsync, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationNullableSingleWithSelectorAsync, + __standardDeviationSampleNullableDecimal, + __standardDeviationSampleNullableDecimalAsync, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDecimalWithSelectorAsync, + __standardDeviationSampleNullableDouble, + __standardDeviationSampleNullableDoubleAsync, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableDoubleWithSelectorAsync, + __standardDeviationSampleNullableInt32, + __standardDeviationSampleNullableInt32Async, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt32WithSelectorAsync, + __standardDeviationSampleNullableInt64, + __standardDeviationSampleNullableInt64Async, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableInt64WithSelectorAsync, + __standardDeviationSampleNullableSingle, + __standardDeviationSampleNullableSingleAsync, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleNullableSingleWithSelectorAsync + ]); + + __standardDeviationPopulationOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationDecimal, + __standardDeviationPopulationDecimalAsync, + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDecimalWithSelectorAsync, + __standardDeviationPopulationDouble, + __standardDeviationPopulationDoubleAsync, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationDoubleWithSelectorAsync, + __standardDeviationPopulationInt32, + __standardDeviationPopulationInt32Async, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt32WithSelectorAsync, + __standardDeviationPopulationInt64, + __standardDeviationPopulationInt64Async, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationInt64WithSelectorAsync, + __standardDeviationPopulationNullableDecimal, + __standardDeviationPopulationNullableDecimalAsync, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDecimalWithSelectorAsync, + __standardDeviationPopulationNullableDouble, + __standardDeviationPopulationNullableDoubleAsync, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableDoubleWithSelectorAsync, + __standardDeviationPopulationNullableInt32, + __standardDeviationPopulationNullableInt32Async, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt32WithSelectorAsync, + __standardDeviationPopulationNullableInt64, + __standardDeviationPopulationNullableInt64Async, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableInt64WithSelectorAsync, + __standardDeviationPopulationNullableSingle, + __standardDeviationPopulationNullableSingleAsync, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationNullableSingleWithSelectorAsync, + __standardDeviationPopulationSingle, + __standardDeviationPopulationSingleAsync, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationPopulationSingleWithSelectorAsync + ]); + + __standardDeviationWithSelectorOverloads = MethodInfoSet.Create( + [ + __standardDeviationPopulationDecimalWithSelector, + __standardDeviationPopulationDecimalWithSelectorAsync, + __standardDeviationPopulationDoubleWithSelector, + __standardDeviationPopulationDoubleWithSelectorAsync, + __standardDeviationPopulationInt32WithSelector, + __standardDeviationPopulationInt32WithSelectorAsync, + __standardDeviationPopulationInt64WithSelector, + __standardDeviationPopulationInt64WithSelectorAsync, + __standardDeviationPopulationNullableDecimalWithSelector, + __standardDeviationPopulationNullableDecimalWithSelectorAsync, + __standardDeviationPopulationNullableDoubleWithSelector, + __standardDeviationPopulationNullableDoubleWithSelectorAsync, + __standardDeviationPopulationNullableInt32WithSelector, + __standardDeviationPopulationNullableInt32WithSelectorAsync, + __standardDeviationPopulationNullableInt64WithSelector, + __standardDeviationPopulationNullableInt64WithSelectorAsync, + __standardDeviationPopulationNullableSingleWithSelector, + __standardDeviationPopulationNullableSingleWithSelectorAsync, + __standardDeviationPopulationSingleWithSelector, + __standardDeviationPopulationSingleWithSelectorAsync, + __standardDeviationSampleDecimalWithSelector, + __standardDeviationSampleDecimalWithSelectorAsync, + __standardDeviationSampleDoubleWithSelector, + __standardDeviationSampleDoubleWithSelectorAsync, + __standardDeviationSampleInt32WithSelector, + __standardDeviationSampleInt32WithSelectorAsync, + __standardDeviationSampleInt64WithSelector, + __standardDeviationSampleInt64WithSelectorAsync, + __standardDeviationSampleNullableDecimalWithSelector, + __standardDeviationSampleNullableDecimalWithSelectorAsync, + __standardDeviationSampleNullableDoubleWithSelector, + __standardDeviationSampleNullableDoubleWithSelectorAsync, + __standardDeviationSampleNullableInt32WithSelector, + __standardDeviationSampleNullableInt32WithSelectorAsync, + __standardDeviationSampleNullableInt64WithSelector, + __standardDeviationSampleNullableInt64WithSelectorAsync, + __standardDeviationSampleNullableSingleWithSelector, + __standardDeviationSampleNullableSingleWithSelectorAsync, + __standardDeviationSampleSingleWithSelector, + __standardDeviationSampleSingleWithSelectorAsync + ]); + + __sumOverloads = MethodInfoSet.Create( + [ + __sumDecimalAsync, + __sumDecimalWithSelectorAsync, + __sumDoubleAsync, + __sumDoubleWithSelectorAsync, + __sumInt32Async, + __sumInt32WithSelectorAsync, + __sumInt64Async, + __sumInt64WithSelectorAsync, + __sumNullableDecimalAsync, + __sumNullableDecimalWithSelectorAsync, + __sumNullableDoubleAsync, + __sumNullableDoubleWithSelectorAsync, + __sumNullableInt32Async, + __sumNullableInt32WithSelectorAsync, + __sumNullableInt64Async, + __sumNullableInt64WithSelectorAsync, + __sumNullableSingleAsync, + __sumNullableSingleWithSelectorAsync, + __sumSingleAsync, + __sumSingleWithSelectorAsync + ]); + + __sumWithSelectorOverloads = MethodInfoSet.Create( + [ + __sumDecimalWithSelectorAsync, + __sumDoubleWithSelectorAsync, + __sumInt32WithSelectorAsync, + __sumInt64WithSelectorAsync, + __sumNullableDecimalWithSelectorAsync, + __sumNullableDoubleWithSelectorAsync, + __sumNullableInt32WithSelectorAsync, + __sumNullableInt64WithSelectorAsync, + __sumNullableSingleWithSelectorAsync, + __sumSingleWithSelectorAsync + ]); + } // public properties public static MethodInfo AnyAsync => __anyAsync; @@ -489,5 +917,31 @@ static MongoQueryableMethod() public static MethodInfo SumSingleAsync => __sumSingleAsync; public static MethodInfo SumSingleWithSelectorAsync => __sumSingleWithSelectorAsync; public static MethodInfo TakeWithLong => __takeWithLong; + + // sets of methods + public static IReadOnlyMethodInfoSet AverageOverloads => __averageOverloads; + public static IReadOnlyMethodInfoSet AverageWithSelectorOverloads => __averageWithSelectorOverloads; + public static IReadOnlyMethodInfoSet CountOverloads => __countOverloads; + public static IReadOnlyMethodInfoSet FirstOverloads => __firstOverloads; + public static IReadOnlyMethodInfoSet FirstWithPredicateOverloads => __firstWithPredicateOverloads; + public static IReadOnlyMethodInfoSet LongCountOverloads => __longCountOverloads; + public static IReadOnlyMethodInfoSet LookupOverloads => __lookupOverloads; + public static IReadOnlyMethodInfoSet LookupWithDocumentsOverloads => __lookupWithDocumentsOverloads; + public static IReadOnlyMethodInfoSet LookupWithDocumentsAndPipelineOverloads => __lookupWithDocumentsAndPipelineOverloads; + public static IReadOnlyMethodInfoSet LookupWithFromOverloads => __lookupWithFromOverloads; + public static IReadOnlyMethodInfoSet LookupWithFromAndPipelineOverloads => __lookupWithFromAndPipelineOverloads; + public static IReadOnlyMethodInfoSet LookupWithLocalFieldAndForeignFieldOverloads => __lookupWithLocalFieldAndForeignFieldOverloads; + public static IReadOnlyMethodInfoSet MaxOverloads => __maxOverloads; + public static IReadOnlyMethodInfoSet MinOverloads => __minOverloads; + public static IReadOnlyMethodInfoSet SingleOverloads => __singleOverloads; + public static IReadOnlyMethodInfoSet SingleOrDefaultOverloads => __singleOrDefaultOverloads; + public static IReadOnlyMethodInfoSet SingleWithPredicateOverloads => __singleWithPredicateOverloads; + public static IReadOnlyMethodInfoSet SkipOrTakeWithLong => __skipOrTakeWithLong; + public static IReadOnlyMethodInfoSet StandardDeviationOverloads => __standardDeviationOverloads; + public static IReadOnlyMethodInfoSet StandardDeviationNullableOverloads => __standardDeviationNullableOverloads; + public static IReadOnlyMethodInfoSet StandardDeviationPopulationOverloads => __standardDeviationPopulationOverloads; + public static IReadOnlyMethodInfoSet StandardDeviationWithSelectorOverloads => __standardDeviationWithSelectorOverloads; + public static IReadOnlyMethodInfoSet SumOverloads => __sumOverloads; + public static IReadOnlyMethodInfoSet SumWithSelectorOverloads => __sumWithSelectorOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs index 4b82e4a545c..2def335459e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MqlMethod.cs @@ -36,9 +36,15 @@ internal static class MqlMethod private static readonly MethodInfo __isNullOrMissing; private static readonly MethodInfo __sigmoid; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __dateFromStringOverloads; + private static readonly IReadOnlyMethodInfoSet __dateFromStringWithFormatOverloads; + private static readonly IReadOnlyMethodInfoSet __dateFromStringWithTimezoneOverloads; + // static constructor static MqlMethod() { + // initialize methods before sets of methods __constantWithRepresentation = ReflectionInfo.Method((object value, BsonType representation) => Mql.Constant(value, representation)); __constantWithSerializer = ReflectionInfo.Method((object value, IBsonSerializer serializer) => Mql.Constant(value, serializer)); __convert = ReflectionInfo.Method((object value, ConvertOptions options) => Mql.Convert(value, options)); @@ -51,6 +57,28 @@ static MqlMethod() __isMissing = ReflectionInfo.Method((object field) => Mql.IsMissing(field)); __isNullOrMissing = ReflectionInfo.Method((object field) => Mql.IsNullOrMissing(field)); __sigmoid = ReflectionInfo.Method((double value) => Mql.Sigmoid(value)); + + // initialize sets of methods after methods + __dateFromStringOverloads = MethodInfoSet.Create( + [ + __dateFromString, + __dateFromStringWithFormat, + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]); + + __dateFromStringWithFormatOverloads = MethodInfoSet.Create( + [ + __dateFromStringWithFormat, + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]); + + __dateFromStringWithTimezoneOverloads = MethodInfoSet.Create( + [ + __dateFromStringWithFormatAndTimezone, + __dateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull + ]); } // public properties @@ -66,5 +94,10 @@ static MqlMethod() public static MethodInfo IsMissing => __isMissing; public static MethodInfo IsNullOrMissing => __isNullOrMissing; public static MethodInfo Sigmoid => __sigmoid; + + // sets of methods + public static IReadOnlyMethodInfoSet DateFromStringOverloads => __dateFromStringOverloads; + public static IReadOnlyMethodInfoSet DateFromStringWithFormatOverloads => __dateFromStringWithFormatOverloads; + public static IReadOnlyMethodInfoSet DateFromStringWithTimezoneOverloads => __dateFromStringWithTimezoneOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs index 17896da1313..8b36a92623e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs @@ -28,6 +28,7 @@ internal static class QueryableMethod private static readonly MethodInfo __aggregateWithSeedAndFunc; private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector; private static readonly MethodInfo __all; + private static readonly MethodInfo __allWithPredicate; private static readonly MethodInfo __any; private static readonly MethodInfo __anyWithPredicate; private static readonly MethodInfo __append; @@ -90,7 +91,7 @@ internal static class QueryableMethod private static readonly MethodInfo __prepend; private static readonly MethodInfo __reverse; private static readonly MethodInfo __select; - private static readonly MethodInfo __selectMany; + private static readonly MethodInfo __selectManyWithSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorAndResultSelector; private static readonly MethodInfo __selectManyWithCollectionSelectorTakingIndexAndResultSelector; private static readonly MethodInfo __selectManyWithSelectorTakingIndex; @@ -131,13 +132,36 @@ internal static class QueryableMethod private static readonly MethodInfo __whereWithPredicateTakingIndex; private static readonly MethodInfo __zip; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __averageOverloads; + private static readonly IReadOnlyMethodInfoSet __averageWithSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __countOverloads; + private static readonly IReadOnlyMethodInfoSet __firstOverloads; + private static readonly IReadOnlyMethodInfoSet __firstWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __groupByOverloads; + private static readonly IReadOnlyMethodInfoSet __groupByWithElementSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __groupByWithResultSelectorOverloads; + private static readonly IReadOnlyMethodInfoSet __lastOverloads; + private static readonly IReadOnlyMethodInfoSet __lastWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __longCountOverloads; + private static readonly IReadOnlyMethodInfoSet __maxOverloads; + private static readonly IReadOnlyMethodInfoSet __minOverloads; + private static readonly IReadOnlyMethodInfoSet __selectManyOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOrDefaultOverloads; + private static readonly IReadOnlyMethodInfoSet __singleWithPredicateOverloads; + private static readonly IReadOnlyMethodInfoSet __sumOverloads; + private static readonly IReadOnlyMethodInfoSet __sumWithSelectorOverloads; + // static constructor static QueryableMethod() { + // initialize methods before sets of methods __aggregateWithFunc = ReflectionInfo.Method((IQueryable source, Expression> func) => source.Aggregate(func)); __aggregateWithSeedAndFunc = ReflectionInfo.Method((IQueryable source, object seed, Expression> func) => source.Aggregate(seed, func)); __aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IQueryable source, object seed, Expression> func, Expression> selector) => source.Aggregate(seed, func, selector)); __all = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); + __allWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.All(predicate)); __any = ReflectionInfo.Method((IQueryable source) => source.Any()); __anyWithPredicate = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.Any(predicate)); __append = ReflectionInfo.Method((IQueryable source, object element) => source.Append(element)); @@ -200,7 +224,7 @@ static QueryableMethod() __prepend = ReflectionInfo.Method((IQueryable source, object element) => source.Prepend(element)); __reverse = ReflectionInfo.Method((IQueryable source) => source.Reverse()); __select = ReflectionInfo.Method((IQueryable source, Expression> selector) => source.Select(selector)); - __selectMany = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); + __selectManyWithSelector = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); __selectManyWithCollectionSelectorAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithCollectionSelectorTakingIndexAndResultSelector = ReflectionInfo.Method((IQueryable source, Expression>> collectionSelector, Expression> resultSelector) => source.SelectMany(collectionSelector, resultSelector)); __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IQueryable source, Expression>> selector) => source.SelectMany(selector)); @@ -240,6 +264,181 @@ static QueryableMethod() __where = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.Where(predicate)); __whereWithPredicateTakingIndex = ReflectionInfo.Method((IQueryable source, Expression> predicate) => source.Where(predicate)); __zip = ReflectionInfo.Method((IQueryable source1, IEnumerable source2, Expression> resultSelector) => source1.Zip(source2, resultSelector)); + + // initialize sets of methods after methods + __averageOverloads = MethodInfoSet.Create( + [ + __averageDecimal, + __averageDecimalWithSelector, + __averageDouble, + __averageDoubleWithSelector, + __averageInt32, + __averageInt32WithSelector, + __averageInt64, + __averageInt64WithSelector, + __averageNullableDecimal, + __averageNullableDecimalWithSelector, + __averageNullableDouble, + __averageNullableDoubleWithSelector, + __averageNullableInt32, + __averageNullableInt32WithSelector, + __averageNullableInt64, + __averageNullableInt64WithSelector, + __averageNullableSingle, + __averageNullableSingleWithSelector, + __averageSingle, + __averageSingleWithSelector + ]); + + __averageWithSelectorOverloads = MethodInfoSet.Create( + [ + __averageDecimalWithSelector, + __averageDoubleWithSelector, + __averageInt32WithSelector, + __averageInt64WithSelector, + __averageNullableDecimalWithSelector, + __averageNullableDoubleWithSelector, + __averageNullableInt32WithSelector, + __averageNullableInt64WithSelector, + __averageNullableSingleWithSelector, + __averageSingleWithSelector + ]); + + __countOverloads = MethodInfoSet.Create( + [ + __count, + __countWithPredicate + ]); + + __firstOverloads = MethodInfoSet.Create( + [ + __first, + __firstOrDefault, + __firstOrDefaultWithPredicate, + __firstWithPredicate + ]); + + __firstWithPredicateOverloads = MethodInfoSet.Create( + [ + __firstOrDefaultWithPredicate, + __firstWithPredicate + ]); + + __groupByOverloads = MethodInfoSet.Create( + [ + __groupByWithKeySelector, + __groupByWithKeySelectorAndElementSelector, + __groupByWithKeySelectorAndResultSelector, + __groupByWithKeySelectorElementSelectorAndResultSelector + ]); + + __groupByWithElementSelectorOverloads = MethodInfoSet.Create( + [ + __groupByWithKeySelectorAndElementSelector, + __groupByWithKeySelectorElementSelectorAndResultSelector + ]); + + __groupByWithResultSelectorOverloads = MethodInfoSet.Create( + [ + __groupByWithKeySelectorAndResultSelector, + __groupByWithKeySelectorElementSelectorAndResultSelector + ]); + + __lastOverloads = MethodInfoSet.Create( + [ + __last, + __lastOrDefault, + __lastOrDefaultWithPredicate, + __lastWithPredicate + ]); + + __lastWithPredicateOverloads = MethodInfoSet.Create( + [ + __lastOrDefaultWithPredicate, + __lastWithPredicate + ]); + + __longCountOverloads = MethodInfoSet.Create( + [ + __longCount, + __longCountWithPredicate + ]); + + __maxOverloads = MethodInfoSet.Create( + [ + __max, + __maxWithSelector + ]); + + __minOverloads = MethodInfoSet.Create( + [ + __min, + __minWithSelector + ]); + + __selectManyOverloads = MethodInfoSet.Create( + [ + __selectManyWithSelector, + __selectManyWithCollectionSelectorAndResultSelector + ]); + + __singleOverloads = MethodInfoSet.Create( + [ + __single, + __singleOrDefault, + __singleOrDefaultWithPredicate, + __singleWithPredicate + ]); + + __singleOrDefaultOverloads = MethodInfoSet.Create( + [ + __singleOrDefault, + __singleOrDefaultWithPredicate + ]); + + __singleWithPredicateOverloads = MethodInfoSet.Create( + [ + __singleOrDefaultWithPredicate, + __singleWithPredicate + ]); + + __sumOverloads = MethodInfoSet.Create( + [ + __sumDecimal, + __sumDecimalWithSelector, + __sumDouble, + __sumDoubleWithSelector, + __sumInt32, + __sumInt32WithSelector, + __sumInt64, + __sumInt64WithSelector, + __sumNullableDecimal, + __sumNullableDecimalWithSelector, + __sumNullableDouble, + __sumNullableDoubleWithSelector, + __sumNullableInt32, + __sumNullableInt32WithSelector, + __sumNullableInt64, + __sumNullableInt64WithSelector, + __sumNullableSingle, + __sumNullableSingleWithSelector, + __sumSingle, + __sumSingleWithSelector + ]); + + __sumWithSelectorOverloads = MethodInfoSet.Create( + [ + __sumDecimalWithSelector, + __sumDoubleWithSelector, + __sumInt32WithSelector, + __sumInt64WithSelector, + __sumNullableDecimalWithSelector, + __sumNullableDoubleWithSelector, + __sumNullableInt32WithSelector, + __sumNullableInt64WithSelector, + __sumNullableSingleWithSelector, + __sumSingleWithSelector, + ]); } // public properties @@ -247,6 +446,7 @@ static QueryableMethod() public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc; public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector; public static MethodInfo All => __all; + public static MethodInfo AllWithPredicate => __allWithPredicate; public static MethodInfo Any => __any; public static MethodInfo AnyWithPredicate => __anyWithPredicate; public static MethodInfo Append => __append; @@ -291,7 +491,7 @@ static QueryableMethod() public static MethodInfo GroupByWithKeySelectorAndResultSelector => __groupByWithKeySelectorAndResultSelector; public static MethodInfo GroupByWithKeySelectorElementSelectorAndResultSelector => __groupByWithKeySelectorElementSelectorAndResultSelector; public static MethodInfo GroupJoin => __groupJoin; - public static MethodInfo Interset => __intersect; + public static MethodInfo Intersect => __intersect; public static MethodInfo Join => __join; public static MethodInfo Last => __last; public static MethodInfo LastOrDefault => __lastOrDefault; @@ -309,7 +509,7 @@ static QueryableMethod() public static MethodInfo Prepend => __prepend; public static MethodInfo Reverse => __reverse; public static MethodInfo Select => __select; - public static MethodInfo SelectMany => __selectMany; + public static MethodInfo SelectManyWithSelector => __selectManyWithSelector; public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector; public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; @@ -350,6 +550,27 @@ static QueryableMethod() public static MethodInfo WhereWithPredicateTakingIndex => __whereWithPredicateTakingIndex; public static MethodInfo Zip => __zip; + // sets of methods + public static IReadOnlyMethodInfoSet AverageOverloads => __averageOverloads; + public static IReadOnlyMethodInfoSet AverageWithSelectorOverloads => __averageWithSelectorOverloads; + public static IReadOnlyMethodInfoSet CountOverloads => __countOverloads; + public static IReadOnlyMethodInfoSet FirstOverloads => __firstOverloads; + public static IReadOnlyMethodInfoSet FirstWithPredicateOverloads => __firstWithPredicateOverloads; + public static IReadOnlyMethodInfoSet GroupByOverloads => __groupByOverloads; + public static IReadOnlyMethodInfoSet GroupByWithElementSelectorOverloads => __groupByWithElementSelectorOverloads; + public static IReadOnlyMethodInfoSet GroupByWithResultSelectorOverloads => __groupByWithResultSelectorOverloads; + public static IReadOnlyMethodInfoSet LastOverloads => __lastOverloads; + public static IReadOnlyMethodInfoSet LastWithPredicateOverloads => __lastWithPredicateOverloads; + public static IReadOnlyMethodInfoSet LongCountOverloads => __longCountOverloads; + public static IReadOnlyMethodInfoSet MaxOverloads => __maxOverloads; + public static IReadOnlyMethodInfoSet MinOverloads => __minOverloads; + public static IReadOnlyMethodInfoSet SelectManyOverloads => __selectManyOverloads; + public static IReadOnlyMethodInfoSet SingleOverloads => __singleOverloads; + public static IReadOnlyMethodInfoSet SingleOrDefaultOverloads => __singleOrDefaultOverloads; + public static IReadOnlyMethodInfoSet SingleWithPredicateOverloads => __singleWithPredicateOverloads; + public static IReadOnlyMethodInfoSet SumOverloads => __sumOverloads; + public static IReadOnlyMethodInfoSet SumWithSelectorOverloads => __sumWithSelectorOverloads; + // public methods public static MethodInfo MakeSelect(Type tsource, Type tresult) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs index 96177eab363..cf305afaf48 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/StringMethod.cs @@ -28,6 +28,8 @@ internal static class StringMethod private static readonly MethodInfo __anyStringInWithParams; private static readonly MethodInfo __anyStringNinWithEnumerable; private static readonly MethodInfo __anyStringNinWithParams; + private static readonly MethodInfo __compare; + private static readonly MethodInfo __compareWithIgnoreCase; private static readonly MethodInfo __concatWith1Object; private static readonly MethodInfo __concatWith2Objects; private static readonly MethodInfo __concatWith3Objects; @@ -44,6 +46,7 @@ internal static class StringMethod private static readonly MethodInfo __endsWithWithString; private static readonly MethodInfo __endsWithWithStringAndComparisonType; private static readonly MethodInfo __endsWithWithStringAndIgnoreCaseAndCulture; + private static readonly MethodInfo __equalsWithComparisonType; private static readonly MethodInfo __getChars; private static readonly MethodInfo __indexOfAny; private static readonly MethodInfo __indexOfAnyWithStartIndex; @@ -72,8 +75,7 @@ internal static class StringMethod private static readonly MethodInfo __startsWithWithString; private static readonly MethodInfo __startsWithWithStringAndComparisonType; private static readonly MethodInfo __startsWithWithStringAndIgnoreCaseAndCulture; - private static readonly MethodInfo __staticCompare; - private static readonly MethodInfo __staticCompareWithIgnoreCase; + private static readonly MethodInfo __staticEqualsWithComparisonType; private static readonly MethodInfo __stringInWithEnumerable; private static readonly MethodInfo __stringInWithParams; private static readonly MethodInfo __stringNinWithEnumerable; @@ -93,9 +95,40 @@ internal static class StringMethod private static readonly MethodInfo __trimStart; private static readonly MethodInfo __trimWithChars; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __anyStringInOverloads; + private static readonly IReadOnlyMethodInfoSet __anyStringNinOverloads; + private static readonly IReadOnlyMethodInfoSet __compareOverloads; + private static readonly IReadOnlyMethodInfoSet __concatOverloads; + private static readonly IReadOnlyMethodInfoSet __containsOverloads; + private static readonly IReadOnlyMethodInfoSet __endsWithOrStartsWithOverloads; + private static readonly IReadOnlyMethodInfoSet __endsWithOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfAnyOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfBytesOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithCharOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithComparisonTypeOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithCountOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithStartIndexOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithStringOverloads; + private static readonly IReadOnlyMethodInfoSet __indexOfWithStringComparisonOverloads; + private static readonly IReadOnlyMethodInfoSet __splitOverloads; + private static readonly IReadOnlyMethodInfoSet __splitWithCharsOverloads; + private static readonly IReadOnlyMethodInfoSet __splitWithCountOverloads; + private static readonly IReadOnlyMethodInfoSet __splitWithOptionsOverloads; + private static readonly IReadOnlyMethodInfoSet __splitWithStringsOverloads; + private static readonly IReadOnlyMethodInfoSet __startsWithOverloads; + private static readonly IReadOnlyMethodInfoSet __stringInOverloads; + private static readonly IReadOnlyMethodInfoSet __stringNinOverloads; + private static readonly IReadOnlyMethodInfoSet __toLowerOrToUpperOverloads; + private static readonly IReadOnlyMethodInfoSet __toLowerOverloads; + private static readonly IReadOnlyMethodInfoSet __toUpperOverloads; + private static readonly IReadOnlyMethodInfoSet __trimOverloads; + // static constructor static StringMethod() { + // initialize methods before sets of methods #if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER __containsWithChar = ReflectionInfo.Method((string s, char value) => s.Contains(value)); __containsWithCharAndComparisonType = ReflectionInfo.Method((string s, char value, StringComparison comparisonType) => s.Contains(value, comparisonType)); @@ -114,6 +147,8 @@ static StringMethod() __anyStringInWithParams = ReflectionInfo.Method((IEnumerable s, StringOrRegularExpression[] values) => s.AnyStringIn(values)); __anyStringNinWithEnumerable = ReflectionInfo.Method((IEnumerable s, IEnumerable values) => s.AnyStringNin(values)); __anyStringNinWithParams = ReflectionInfo.Method((IEnumerable s, StringOrRegularExpression[] values) => s.AnyStringNin(values)); + __compare = ReflectionInfo.Method((string strA, string strB) => String.Compare(strA, strB)); + __compareWithIgnoreCase = ReflectionInfo.Method((string strA, string strB, bool ignoreCase) => String.Compare(strA, strB, ignoreCase)); __concatWith1Object = ReflectionInfo.Method((object arg) => string.Concat(arg)); __concatWith2Objects = ReflectionInfo.Method((object arg0, object arg1) => string.Concat(arg0, arg1)); __concatWith3Objects = ReflectionInfo.Method((object arg0, object arg1, object arg2) => string.Concat(arg0, arg1, arg2)); @@ -126,6 +161,7 @@ static StringMethod() __endsWithWithString = ReflectionInfo.Method((string s, string value) => s.EndsWith(value)); __endsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.EndsWith(value, comparisonType)); __endsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.EndsWith(value, ignoreCase, culture)); + __equalsWithComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.Equals(value, comparisonType)); __getChars = ReflectionInfo.Method((string s, int index) => s[index]); __indexOfAny = ReflectionInfo.Method((string s, char[] anyOf) => s.IndexOfAny(anyOf)); __indexOfAnyWithStartIndex = ReflectionInfo.Method((string s, char[] anyOf, int startIndex) => s.IndexOfAny(anyOf, startIndex)); @@ -153,8 +189,7 @@ static StringMethod() __startsWithWithString = ReflectionInfo.Method((string s, string value) => s.StartsWith(value)); __startsWithWithStringAndComparisonType = ReflectionInfo.Method((string s, string value, StringComparison comparisonType) => s.StartsWith(value, comparisonType)); __startsWithWithStringAndIgnoreCaseAndCulture = ReflectionInfo.Method((string s, string value, bool ignoreCase, CultureInfo culture) => s.StartsWith(value, ignoreCase, culture)); - __staticCompare = ReflectionInfo.Method((string strA, string strB) => String.Compare(strA, strB)); - __staticCompareWithIgnoreCase = ReflectionInfo.Method((string strA, string strB, bool ignoreCase) => String.Compare(strA, strB, ignoreCase)); + __staticEqualsWithComparisonType = ReflectionInfo.Method((string a, string b, StringComparison comparisonType) => string.Equals(a, b, comparisonType)); __stringInWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringIn(values)); __stringInWithParams = ReflectionInfo.Method((string s, StringOrRegularExpression[] values) => s.StringIn(values)); __stringNinWithEnumerable = ReflectionInfo.Method((string s, IEnumerable values) => s.StringNin(values)); @@ -173,6 +208,234 @@ static StringMethod() __trimEnd = ReflectionInfo.Method((string s, char[] trimChars) => s.TrimEnd(trimChars)); __trimStart = ReflectionInfo.Method((string s, char[] trimChars) => s.TrimStart(trimChars)); __trimWithChars = ReflectionInfo.Method((string s, char[] trimChars) => s.Trim(trimChars)); + + // initialize sets of methods after methods + __anyStringInOverloads = MethodInfoSet.Create( + [ + __anyStringInWithEnumerable, + __anyStringInWithParams + ]); + + __anyStringNinOverloads = MethodInfoSet.Create( + [ + __anyStringNinWithEnumerable, + __anyStringNinWithParams, + ]); + + __compareOverloads = MethodInfoSet.Create( + [ + __compare, + __compareWithIgnoreCase + ]); + + __concatOverloads = MethodInfoSet.Create( + [ + __concatWith1Object, + __concatWith2Objects, + __concatWith2Strings, + __concatWith3Objects, + __concatWith3Strings, + __concatWith4Strings, + __concatWithObjectArray, + __concatWithStringArray + ]); + + __containsOverloads = MethodInfoSet.Create( + [ + __containsWithChar, + __containsWithCharAndComparisonType, + __containsWithString, + __containsWithStringAndComparisonType + ]); + + __endsWithOverloads = MethodInfoSet.Create( + [ + __endsWithWithChar, + __endsWithWithString, + __endsWithWithStringAndComparisonType, + __endsWithWithStringAndIgnoreCaseAndCulture, + ]); + + __indexOfAnyOverloads = MethodInfoSet.Create( + [ + __indexOfAny, + __indexOfAnyWithStartIndex, + __indexOfAnyWithStartIndexAndCount, + ]); + + __indexOfOverloads = MethodInfoSet.Create( + [ + __indexOfAny, + __indexOfAnyWithStartIndex, + __indexOfAnyWithStartIndexAndCount, + __indexOfBytesWithValue, + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithChar, + __indexOfWithCharAndStartIndex, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithString, + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndex, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCountAndComparisonType, + ]); + + __indexOfBytesOverloads = MethodInfoSet.Create( + [ + __indexOfBytesWithValue, + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount + ]); + + __indexOfWithCharOverloads = MethodInfoSet.Create( + [ + __indexOfWithChar, + __indexOfWithCharAndStartIndex, + __indexOfWithCharAndStartIndexAndCount, + ]); + + __indexOfWithComparisonTypeOverloads = MethodInfoSet.Create( + [ + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]); + + __indexOfWithCountOverloads = MethodInfoSet.Create( + [ + __indexOfAnyWithStartIndexAndCount, + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]); + + __indexOfWithStartIndexOverloads = MethodInfoSet.Create( + [ + __indexOfAnyWithStartIndex, + __indexOfAnyWithStartIndexAndCount, + __indexOfBytesWithValueAndStartIndex, + __indexOfBytesWithValueAndStartIndexAndCount, + __indexOfWithCharAndStartIndex, + __indexOfWithCharAndStartIndexAndCount, + __indexOfWithStringAndStartIndex, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]); + + __indexOfWithStringOverloads = MethodInfoSet.Create( + [ + __indexOfWithString, + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndex, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCount, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]); + + __indexOfWithStringComparisonOverloads = MethodInfoSet.Create( + [ + __indexOfWithStringAndComparisonType, + __indexOfWithStringAndStartIndexAndComparisonType, + __indexOfWithStringAndStartIndexAndCountAndComparisonType + ]); + + __splitOverloads = MethodInfoSet.Create( + [ + __splitWithChars, + __splitWithCharsAndCount, + __splitWithCharsAndCountAndOptions, + __splitWithCharsAndOptions, + __splitWithStringsAndCountAndOptions, + __splitWithStringsAndOptions + ]); + + __splitWithCharsOverloads = MethodInfoSet.Create( + [ + __splitWithChars, + __splitWithCharsAndCount, + __splitWithCharsAndCountAndOptions, + __splitWithCharsAndOptions + ]); + + __splitWithCountOverloads = MethodInfoSet.Create( + [ + __splitWithCharsAndCount, + __splitWithCharsAndCountAndOptions, + __splitWithStringsAndCountAndOptions + ]); + + __splitWithOptionsOverloads = MethodInfoSet.Create( + [ + __splitWithCharsAndCountAndOptions, + __splitWithCharsAndOptions, + __splitWithStringsAndCountAndOptions, + __splitWithStringsAndOptions + ]); + + __splitWithStringsOverloads = MethodInfoSet.Create( + [ + __splitWithStringsAndCountAndOptions, + __splitWithStringsAndOptions + ]); + + __startsWithOverloads = MethodInfoSet.Create( + [ + __startsWithWithChar, + __startsWithWithString, + __startsWithWithStringAndComparisonType, + __startsWithWithStringAndIgnoreCaseAndCulture + ]); + + __stringInOverloads = MethodInfoSet.Create( + [ + __stringInWithEnumerable, + __stringInWithParams + ]); + + __stringNinOverloads = MethodInfoSet.Create( + [ + __stringNinWithEnumerable, + __stringNinWithParams + ]); + + __toLowerOverloads = MethodInfoSet.Create( + [ + __toLower, + __toLowerInvariant, + __toLowerWithCulture, + ]); + + __toUpperOverloads = MethodInfoSet.Create( + [ + __toUpper, + __toUpperInvariant, + __toUpperWithCulture, + ]); + + __trimOverloads = MethodInfoSet.Create( + [ + __trim, + __trimEnd, + __trimStart, + __trimWithChars + ]); + + // initialize sets of methods after individual methods + __endsWithOrStartsWithOverloads = MethodInfoSet.Create( + [ + __endsWithOverloads, + __startsWithOverloads + ]); + + __toLowerOrToUpperOverloads = MethodInfoSet.Create( + [ + __toLowerOverloads, + __toUpperOverloads + ]); } // public properties @@ -180,6 +443,8 @@ static StringMethod() public static MethodInfo AnyStringInWithParams => __anyStringInWithParams; public static MethodInfo AnyStringNinWithEnumerable => __anyStringNinWithEnumerable; public static MethodInfo AnyStringNinWithParams => __anyStringNinWithParams; + public static MethodInfo Compare => __compare; + public static MethodInfo CompareWithIgnoreCase => __compareWithIgnoreCase; public static MethodInfo ConcatWith1Object => __concatWith1Object; public static MethodInfo ConcatWith2Objects => __concatWith2Objects; public static MethodInfo ConcatWith3Objects => __concatWith3Objects; @@ -196,6 +461,7 @@ static StringMethod() public static MethodInfo EndsWithWithString => __endsWithWithString; public static MethodInfo EndsWithWithStringAndComparisonType => __endsWithWithStringAndComparisonType; public static MethodInfo EndsWithWithStringAndIgnoreCaseAndCulture => __endsWithWithStringAndIgnoreCaseAndCulture; + public static MethodInfo EqualsWithComparisonType => __equalsWithComparisonType; public static MethodInfo GetChars => __getChars; public static MethodInfo IndexOfAny => __indexOfAny; public static MethodInfo IndexOfAnyWithStartIndex => __indexOfAnyWithStartIndex; @@ -224,8 +490,7 @@ static StringMethod() public static MethodInfo StartsWithWithString => __startsWithWithString; public static MethodInfo StartsWithWithStringAndComparisonType => __startsWithWithStringAndComparisonType; public static MethodInfo StartsWithWithStringAndIgnoreCaseAndCulture => __startsWithWithStringAndIgnoreCaseAndCulture; - public static MethodInfo StaticCompare => __staticCompare; - public static MethodInfo StaticCompareWithIgnoreCase => __staticCompareWithIgnoreCase; + public static MethodInfo StaticEqualsWithComparisonType => __staticEqualsWithComparisonType; public static MethodInfo StringInWithEnumerable => __stringInWithEnumerable; public static MethodInfo StringInWithParams => __stringInWithParams; public static MethodInfo StringNinWithEnumerable => __stringNinWithEnumerable; @@ -244,5 +509,35 @@ static StringMethod() public static MethodInfo TrimEnd => __trimEnd; public static MethodInfo TrimStart => __trimStart; public static MethodInfo TrimWithChars => __trimWithChars; + + // sets of methods + public static IReadOnlyMethodInfoSet AnyStringInOverloads => __anyStringInOverloads; + public static IReadOnlyMethodInfoSet AnyStringNinOverloads => __anyStringNinOverloads; + public static IReadOnlyMethodInfoSet CompareOverloads => __compareOverloads; + public static IReadOnlyMethodInfoSet ConcatOverloads => __concatOverloads; + public static IReadOnlyMethodInfoSet ContainsOverloads => __containsOverloads; + public static IReadOnlyMethodInfoSet EndsWithOrStartsWithOverloads => __endsWithOrStartsWithOverloads; + public static IReadOnlyMethodInfoSet EndsWithOverloads => __endsWithOverloads; + public static IReadOnlyMethodInfoSet IndexOfAnyOverloads => __indexOfAnyOverloads; + public static IReadOnlyMethodInfoSet IndexOfOverloads => __indexOfOverloads; + public static IReadOnlyMethodInfoSet IndexOfBytesOverloads => __indexOfBytesOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithCountOverloads => __indexOfWithCountOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithCharOverloads => __indexOfWithCharOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithComparisonTypeOverloads => __indexOfWithComparisonTypeOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithStartIndexOverloads => __indexOfWithStartIndexOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithStringOverloads => __indexOfWithStringOverloads; + public static IReadOnlyMethodInfoSet IndexOfWithStringComparisonOverloads => __indexOfWithStringComparisonOverloads; + public static IReadOnlyMethodInfoSet SplitOverloads => __splitOverloads; + public static IReadOnlyMethodInfoSet SplitWithCharsOverloads => __splitWithCharsOverloads; + public static IReadOnlyMethodInfoSet SplitWithCountOverloads => __splitWithCountOverloads; + public static IReadOnlyMethodInfoSet SplitWithOptionsOverloads => __splitWithOptionsOverloads; + public static IReadOnlyMethodInfoSet SplitWithStringsOverloads => __splitWithStringsOverloads; + public static IReadOnlyMethodInfoSet StartsWithOverloads => __startsWithOverloads; + public static IReadOnlyMethodInfoSet StringInOverloads => __stringInOverloads; + public static IReadOnlyMethodInfoSet StringNinOverloads => __stringNinOverloads; + public static IReadOnlyMethodInfoSet ToLowerOrToUpperOverloads => __toLowerOrToUpperOverloads; + public static IReadOnlyMethodInfoSet ToLowerOverloads => __toLowerOverloads; + public static IReadOnlyMethodInfoSet ToUpperOverloads => __toUpperOverloads; + public static IReadOnlyMethodInfoSet TrimOverloads => __trimOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleMethod.cs index 653e6655e4c..778e4306e23 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleMethod.cs @@ -30,9 +30,13 @@ internal static class TupleMethod private static readonly MethodInfo __create7; private static readonly MethodInfo __create8; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __createOverloads; + // static constructor static TupleMethod() { + // initialize methods before sets of methods __create1 = ReflectionInfo.Method((object item1) => Tuple.Create(item1)); __create2 = ReflectionInfo.Method((object item1, object item2) => Tuple.Create(item1, item2)); __create3 = ReflectionInfo.Method((object item1, object item2, object item3) => Tuple.Create(item1, item2, item3)); @@ -41,6 +45,19 @@ static TupleMethod() __create6 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6) => Tuple.Create(item1, item2, item3, item4, item5, item6)); __create7 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6, object item7) => Tuple.Create(item1, item2, item3, item4, item5, item6, item7)); __create8 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6, object item7, object item8) => Tuple.Create(item1, item2, item3, item4, item5, item6, item7, item8)); + + // initialize sets of methods after methods + __createOverloads = MethodInfoSet.Create( + [ + __create1, + __create2, + __create3, + __create4, + __create5, + __create6, + __create7, + __create8 + ]); } // public properties @@ -52,5 +69,8 @@ static TupleMethod() public static MethodInfo Create6 => __create6; public static MethodInfo Create7 => __create7; public static MethodInfo Create8 => __create8; + + // sets of methods + public static IReadOnlyMethodInfoSet CreateOverloads => __createOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs new file mode 100644 index 00000000000..80f8502c060 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleConstructor.cs @@ -0,0 +1,29 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class TupleOrValueTupleConstructor +{ + public static bool IsTupleOrValueTupleConstructor(ConstructorInfo constructor) + { + return + constructor != null && + constructor.DeclaringType.IsTupleOrValueTuple(); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs new file mode 100644 index 00000000000..722ca7cd92a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/TupleOrValueTupleMethod.cs @@ -0,0 +1,32 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class TupleOrValueTupleMethod +{ + private static readonly IReadOnlyMethodInfoSet __createOverloads; + + static TupleOrValueTupleMethod() + { + __createOverloads = MethodInfoSet.Create( + [ + TupleMethod.CreateOverloads, + ValueTupleMethod.CreateOverloads + ]); + } + + public static IReadOnlyMethodInfoSet CreateOverloads => __createOverloads; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ValueTupleMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ValueTupleMethod.cs index 9971d53519d..2abb2a4d6d1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ValueTupleMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ValueTupleMethod.cs @@ -30,9 +30,13 @@ internal static class ValueTupleMethod private static readonly MethodInfo __create7; private static readonly MethodInfo __create8; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __createOverloads; + // static constructor static ValueTupleMethod() { + // initialize methods before sets of methods __create1 = ReflectionInfo.Method((object item1) => ValueTuple.Create(item1)); __create2 = ReflectionInfo.Method((object item1, object item2) => ValueTuple.Create(item1, item2)); __create3 = ReflectionInfo.Method((object item1, object item2, object item3) => ValueTuple.Create(item1, item2, item3)); @@ -41,6 +45,19 @@ static ValueTupleMethod() __create6 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6) => ValueTuple.Create(item1, item2, item3, item4, item5, item6)); __create7 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6, object item7) => ValueTuple.Create(item1, item2, item3, item4, item5, item6, item7)); __create8 = ReflectionInfo.Method((object item1, object item2, object item3, object item4, object item5, object item6, object item7, object item8) => ValueTuple.Create(item1, item2, item3, item4, item5, item6, item7, item8)); + + // initialize sets of methods after methods + __createOverloads = MethodInfoSet.Create( + [ + __create1, + __create2, + __create3, + __create4, + __create5, + __create6, + __create7, + __create8 + ]); } // public properties @@ -52,5 +69,8 @@ static ValueTupleMethod() public static MethodInfo Create6 => __create6; public static MethodInfo Create7 => __create7; public static MethodInfo Create8 => __create8; + + // sets of methods + public static IReadOnlyMethodInfoSet CreateOverloads => __createOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs index 693e8762269..f2b5b73f106 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/WindowMethod.cs @@ -141,9 +141,13 @@ internal static class WindowMethod private static readonly MethodInfo __sumWithNullableSingle; private static readonly MethodInfo __sumWithSingle; + // sets of methods + private static readonly IReadOnlyMethodInfoSet __percentileOverloads; + // static constructor static WindowMethod() { + // initialize methods before sets of methods __addToSet = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.AddToSet(selector, window)); __averageWithDecimal = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Average(selector, window)); __averageWithDouble = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Average(selector, window)); @@ -262,6 +266,21 @@ static WindowMethod() __sumWithNullableInt64 = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); __sumWithNullableSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); __sumWithSingle = ReflectionInfo.Method((ISetWindowFieldsPartition partition, Func selector, SetWindowFieldsWindow window) => partition.Sum(selector, window)); + + // initialize sets of methods after methods + __percentileOverloads = MethodInfoSet.Create( + [ + __percentileWithDecimal, + __percentileWithDouble, + __percentileWithInt32, + __percentileWithInt64, + __percentileWithNullableDecimal, + __percentileWithNullableDouble, + __percentileWithNullableInt32, + __percentileWithNullableInt64, + __percentileWithNullableSingle, + __percentileWithSingle + ]); } // public properties @@ -383,5 +402,8 @@ static WindowMethod() public static MethodInfo SumWithNullableInt64 => __sumWithNullableInt64; public static MethodInfo SumWithNullableSingle => __sumWithNullableSingle; public static MethodInfo SumWithSingle => __sumWithSingle; + + // sets of methods + public static IReadOnlyMethodInfoSet PercentileOverloads => __percentileOverloads; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs new file mode 100644 index 00000000000..fd98d43ce2a --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/MissingSerializerFinder.cs @@ -0,0 +1,62 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal class MissingSerializerFinder : ExpressionVisitor +{ + public static Expression FindExpressionWithMissingSerializer(Expression expression, SerializerMap nodeSerializers) + { + var visitor = new MissingSerializerFinder(nodeSerializers); + visitor.Visit(expression); + return visitor._expressionWithMissingSerializer; + } + + private Expression _expressionWithMissingSerializer = null; + private readonly SerializerMap _nodeSerializers; + + public MissingSerializerFinder(SerializerMap nodeSerializers) + { + _nodeSerializers = nodeSerializers; + } + + public Expression ExpressionWithMissingSerializer => _expressionWithMissingSerializer; + + public override Expression Visit(Expression node) + { + if (_nodeSerializers.IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + base.Visit(node); + + if (_expressionWithMissingSerializer == null && + node != null && + _nodeSerializers.IsNotKnown(node)) + { + _expressionWithMissingSerializer = node; + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs new file mode 100644 index 00000000000..c12a8cb8759 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinder.cs @@ -0,0 +1,45 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal static class SerializerFinder +{ + public static void FindSerializers( + Expression expression, + ExpressionTranslationOptions translationOptions, + SerializerMap nodeSerializers) + { + var visitor = new SerializerFinderVisitor(translationOptions, nodeSerializers); + + do + { + visitor.StartPass(); + visitor.Visit(expression); + visitor.EndPass(); + } + while (visitor.IsMakingProgress); + + //#if DEBUG + var expressionWithMissingSerializer = MissingSerializerFinder.FindExpressionWithMissingSerializer(expression, nodeSerializers); + if (expressionWithMissingSerializer != null) + { + throw new ExpressionNotSupportedException(expressionWithMissingSerializer, because: "we were unable to determine which serializer to use for the result"); + } + //#endif + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs new file mode 100644 index 00000000000..3bfed3cee83 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderHelperMethods.cs @@ -0,0 +1,236 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using IOrderedEnumerableSerializer=MongoDB.Driver.Linq.Linq3Implementation.Serializers.IOrderedEnumerableSerializer; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + private void AddNodeSerializer(Expression node, IBsonSerializer serializer) => _nodeSerializers.AddSerializer(node, serializer); + + private bool AreAllKnown(IEnumerable nodes, out IReadOnlyList nodeSerializers) + { + var nodeSerializersList = new List(); + foreach (var node in nodes) + { + if (IsKnown(node, out var nodeSerializer)) + { + nodeSerializersList.Add(nodeSerializer); + } + else + { + nodeSerializers = null; + return false; + } + } + + nodeSerializers = nodeSerializersList; + return true; + } + + private bool IsAnyKnown(IEnumerable nodes, out IBsonSerializer nodeSerializer) + { + foreach (var node in nodes) + { + if (IsKnown(node, out var outSerializer)) + { + nodeSerializer = outSerializer; + return true; + } + } + + nodeSerializer = null; + return false; + } + + private bool IsAnyNotKnown(IEnumerable nodes) + { + return nodes.Any(IsNotKnown); + } + + IBsonSerializer CreateCollectionSerializerFromCollectionSerializer(Type collectionType, IBsonSerializer collectionSerializer) + { + if (collectionSerializer.ValueType == collectionType) + { + return collectionSerializer; + } + + if (collectionSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + var itemSerializer = collectionSerializer.GetItemSerializer(); + return CreateCollectionSerializerFromItemSerializer(collectionType, itemSerializer); + } + + IBsonSerializer CreateCollectionSerializerFromItemSerializer(Type collectionType, IBsonSerializer itemSerializer) + { + if (itemSerializer is IUnknowableSerializer) + { + return UnknowableSerializer.Create(collectionType); + } + + return collectionType switch + { + _ when collectionType.IsArray => ArraySerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => IEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IOrderedEnumerable<>) => IOrderedEnumerableSerializer.Create(itemSerializer), + _ when collectionType.IsConstructedGenericType && collectionType.GetGenericTypeDefinition() == typeof(IQueryable<>) => IQueryableSerializer.Create(itemSerializer), + _ => (BsonSerializer.LookupSerializer(collectionType) as IChildSerializerConfigurable)?.WithChildSerializer(itemSerializer) + }; + } + + private void DeduceBaseTypeAndDerivedTypeSerializers(Expression baseTypeExpression, Expression derivedTypeExpression) + { + IBsonSerializer baseTypeSerializer; + IBsonSerializer derivedTypeSerializer; + + if (IsNotKnown(baseTypeExpression) && IsKnown(derivedTypeExpression, out derivedTypeSerializer)) + { + baseTypeSerializer = derivedTypeSerializer.GetBaseTypeSerializer(baseTypeExpression.Type); + AddNodeSerializer(baseTypeExpression, baseTypeSerializer); + } + + if (IsNotKnown(derivedTypeExpression) && IsKnown(baseTypeExpression, out baseTypeSerializer)) + { + derivedTypeSerializer = baseTypeSerializer.GetDerivedTypeSerializer(baseTypeExpression.Type); + AddNodeSerializer(derivedTypeExpression, derivedTypeSerializer); + } + } + + private void DeduceBooleanSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, BooleanSerializer.Instance); + } + } + + private void DeduceCharSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, CharSerializer.Instance); + } + } + + private void DeduceCollectionAndCollectionSerializers(Expression collectionExpression1, Expression collectionExpression2) + { + IBsonSerializer collectionSerializer1; + IBsonSerializer collectionSerializer2; + + if (IsNotKnown(collectionExpression1) && IsKnown(collectionExpression2, out collectionSerializer2)) + { + collectionSerializer1 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression1.Type, collectionSerializer2); + AddNodeSerializer(collectionExpression1, collectionSerializer1); + } + + if (IsNotKnown(collectionExpression2) && IsKnown(collectionExpression1, out collectionSerializer1)) + { + collectionSerializer2 = CreateCollectionSerializerFromCollectionSerializer(collectionExpression2.Type, collectionSerializer1); + AddNodeSerializer(collectionExpression2, collectionSerializer2); + } + } + + private void DeduceCollectionAndItemSerializers(Expression collectionExpression, Expression itemExpression) + { + DeduceItemAndCollectionSerializers(itemExpression, collectionExpression); + } + + private void DeduceItemAndCollectionSerializers(Expression itemExpression, Expression collectionExpression) + { + if (IsNotKnown(itemExpression) && IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + + if (IsNotKnown(collectionExpression) && IsKnown(itemExpression, out itemSerializer)) + { + var collectionSerializer = CreateCollectionSerializerFromItemSerializer(collectionExpression.Type, itemSerializer); + if (collectionSerializer != null) + { + AddNodeSerializer(collectionExpression, collectionSerializer); + } + } + } + + private void DeduceSerializer(Expression node, IBsonSerializer serializer) + { + if (IsNotKnown(node) && serializer != null) + { + AddNodeSerializer(node, serializer); + } + } + + private void DeduceSerializers(Expression expression1, Expression expression2) + { + if (IsNotKnown(expression1) && IsKnown(expression2, out var expression2Serializer) && expression2Serializer.ValueType == expression1.Type) + { + AddNodeSerializer(expression1, expression2Serializer); + } + + if (IsNotKnown(expression2) && IsKnown(expression1, out var expression1Serializer)&& expression1Serializer.ValueType == expression2.Type) + { + AddNodeSerializer(expression2, expression1Serializer); + } + } + + private void DeduceStringSerializer(Expression node) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, StringSerializer.Instance); + } + } + + private void DeduceUnknowableSerializer(Expression node) + { + if (IsNotKnown(node)) + { + var unknowableSerializer = UnknowableSerializer.Create(node.Type); + AddNodeSerializer(node, unknowableSerializer); + } + } + + private bool IsItemSerializerKnown(Expression node, out IBsonSerializer itemSerializer) + { + if (IsKnown(node, out var nodeSerializer) && + nodeSerializer is IBsonArraySerializer arraySerializer && + arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + itemSerializer = itemSerializationInfo.Serializer; + return true; + } + + itemSerializer = null; + return false; + } + + private bool IsKnown(Expression node) => _nodeSerializers.IsKnown(node); + + private bool IsKnown(Expression node, out IBsonSerializer nodeSerializer) => _nodeSerializers.IsKnown(node, out nodeSerializer); + + private bool IsNotKnown(Expression node) => _nodeSerializers.IsNotKnown(node); +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs new file mode 100644 index 00000000000..15bc1ca0674 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderNewExpressionSerializerCreator.cs @@ -0,0 +1,201 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + public IBsonSerializer CreateNewExpressionSerializer( + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings) + { + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); + + if (constructorInfo != null && creatorMap != null) + { + var constructorParameters = constructorInfo.GetParameters(); + var creatorMapParameters = creatorMap.Arguments?.ToArray(); + if (constructorParameters.Length > 0) + { + if (creatorMapParameters == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); + } + + if (creatorMapParameters.Length != constructorParameters.Length) + { + throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); + } + + for (var i = 0; i < creatorMapParameters.Length; i++) + { + var creatorMapParameter = creatorMapParameters[i]; + var constructorArgumentExpression = constructorArguments[i]; + if (!IsKnown(constructorArgumentExpression, out var constructorArgumentSerializer)) + { + return null; + } + var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); + EnsureDefaultValue(memberMap); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + } + + if (bindings != null) + { + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var memberMap = FindMemberMap(expression, classMap, member.Name); + var valueExpression = memberAssignment.Expression; + if (!IsKnown(valueExpression, out var valueSerializer)) + { + return null; + } + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueSerializer); + memberMap.SetSerializer(memberSerializer); + } + } + + classMap.Freeze(); + + var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); + return (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); + } + + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) + { + BsonClassMap baseClassMap = null; + if (classType.BaseType != null) + { + baseClassMap = CreateClassMap(classType.BaseType, null, out _); + } + + var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); + var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); + var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); + if (constructorInfo != null) + { + creatorMap = classMap.MapConstructor(constructorInfo); + } + else + { + creatorMap = null; + } + + classMap.AutoMap(); + classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here + + return classMap; + } + + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) + { + var declaringClassMap = classMap; + while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) + { + declaringClassMap = declaringClassMap.BaseClassMap; + + if (declaringClassMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + } + } + + foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + { + if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + { + return memberMap; + } + } + + return declaringClassMap.MapMember(creatorMapParameter); + + static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) + { + var memberInfo = memberMap.MemberInfo; + return + memberInfo.MemberType == creatorMapParameter.MemberType && + memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); + } + } + + private static void EnsureDefaultValue(BsonMemberMap memberMap) + { + if (memberMap.IsDefaultValueSpecified) + { + return; + } + + var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; + memberMap.SetDefaultValue(defaultValue); + } + + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + { + foreach (var memberMap in classMap.DeclaredMemberMaps) + { + if (memberMap.MemberName == memberName) + { + return memberMap; + } + } + + if (classMap.BaseClassMap != null) + { + return FindMemberMap(expression, classMap.BaseClassMap, memberName); + } + + throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs new file mode 100644 index 00000000000..8deba5b6492 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitBinary.cs @@ -0,0 +1,173 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitBinary(BinaryExpression node) + { + base.VisitBinary(node); + + var @operator = node.NodeType; + var leftExpression = node.Left; + var rightExpression = node.Right; + + if (node.NodeType == ExpressionType.Add && node.Type == typeof(string)) + { + DeduceStringSerializer(node); + return node; + } + + if (IsSymmetricalBinaryOperator(@operator)) + { + // expr1 op expr2 => expr1: expr2Serializer or expr2: expr1Serializer + DeduceSerializers(leftExpression, rightExpression); + } + + if (@operator == ExpressionType.ArrayIndex) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + IBsonSerializer itemSerializer; + if (leftSerializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + var index = rightExpression.GetConstantValue(node); + itemSerializer = polymorphicArraySerializer.GetItemSerializer(index); + } + else + { + itemSerializer = leftSerializer.GetItemSerializer(); + } + + // expr[index] => node: itemSerializer + AddNodeSerializer(node, itemSerializer); + } + } + + if (@operator == ExpressionType.Coalesce) + { + if (IsNotKnown(node) && + IsKnown(leftExpression, out var leftSerializer)) + { + if (leftSerializer.ValueType == node.Type) + { + AddNodeSerializer(node, leftSerializer); + } + else if ( + leftSerializer is INullableSerializer nullableSerializer && + nullableSerializer.ValueSerializer is var nullableSerializerValueSerializer && + nullableSerializerValueSerializer.ValueType == node.Type) + { + AddNodeSerializer(node, nullableSerializerValueSerializer); + } + else + { + DeduceUnknowableSerializer(node); // coalesce will be executed client-side + } + } + } + + if (leftExpression.IsConvert(out var leftConvertOperand) && + rightExpression.IsConvert(out var rightConvertOperand) && + leftConvertOperand.Type == rightConvertOperand.Type) + { + DeduceSerializers(leftConvertOperand, rightConvertOperand); + } + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(node, @operator); + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + static IBsonSerializer GetResultSerializer(Expression node, ExpressionType @operator) + { + switch (@operator) + { + case ExpressionType.And: + case ExpressionType.ExclusiveOr: + case ExpressionType.Or: + switch (node.Type) + { + case Type t when t == typeof(bool): return BooleanSerializer.Instance; + case Type t when t == typeof(int): return Int32Serializer.Instance; + } + goto default; + + case ExpressionType.AndAlso: + case ExpressionType.Equal: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.NotEqual: + case ExpressionType.OrElse: + case ExpressionType.TypeEqual: + return BooleanSerializer.Instance; + + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + if (StandardSerializers.TryGetSerializer(node.Type, out var resultSerializer)) + { + return resultSerializer; + } + goto default; + + default: + return null; + } + } + + static bool IsSymmetricalBinaryOperator(ExpressionType @operator) + => @operator is + ExpressionType.Add or + ExpressionType.AddChecked or + ExpressionType.And or + ExpressionType.AndAlso or + ExpressionType.Coalesce or + ExpressionType.Divide or + ExpressionType.Equal or + ExpressionType.GreaterThan or + ExpressionType.GreaterThanOrEqual or + ExpressionType.Modulo or + ExpressionType.Multiply or + ExpressionType.MultiplyChecked or + ExpressionType.Or or + ExpressionType.OrElse or + ExpressionType.Subtract or + ExpressionType.SubtractChecked; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs new file mode 100644 index 00000000000..cdfbe59e81d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConditional.cs @@ -0,0 +1,40 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitConditional(ConditionalExpression node) + { + var ifTrueExpression = node.IfTrue; + var ifFalseExpression = node.IfFalse; + + DeduceConditionalSerializers(); + base.VisitConditional(node); + DeduceConditionalSerializers(); + + return node; + + void DeduceConditionalSerializers() + { + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifFalseExpression); + DeduceBaseTypeAndDerivedTypeSerializers(node, ifTrueExpression); // call a second time in case ifFalse is the only known serializer + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs new file mode 100644 index 00000000000..7aede943349 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitConstant.cs @@ -0,0 +1,41 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitConstant(ConstantExpression node) + { + if (IsNotKnown(node) && _useDefaultSerializerForConstants) + { + if (StandardSerializers.TryGetSerializer(node.Type, out var standardSerializer)) + { + AddNodeSerializer(node, standardSerializer); + } + else + { + var registeredSerializer = BsonSerializer.LookupSerializer(node.Type); // TODO: don't use static registry + AddNodeSerializer(node, registeredSerializer); + } + } + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs new file mode 100644 index 00000000000..9245fa024ad --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitIndex.cs @@ -0,0 +1,91 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitIndex(IndexExpression node) + { + base.VisitIndex(node); + + var collectionExpression = node.Object; + var indexer = node.Indexer; + var arguments = node.Arguments; + + if (IsBsonValueIndexer()) + { + DeduceSerializer(node, BsonValueSerializer.Instance); + } + else if (IsDictionaryIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + DeduceSerializer(node, valueSerializer); + } + } + // check array indexer AFTER dictionary indexer + else if (IsCollectionIndexer()) + { + if (IsKnown(collectionExpression, out var collectionSerializer) && + collectionSerializer is IBsonArraySerializer arraySerializer) + { + var itemSerializer = arraySerializer.GetItemSerializer(); + DeduceSerializer(node, itemSerializer); + } + } + // handle generic cases? + + return node; + + bool IsCollectionIndexer() + { + return + arguments.Count == 1 && + arguments[0] is var index && + index.Type == typeof(int); + } + + bool IsBsonValueIndexer() + { + var declaringType = indexer.DeclaringType; + return + (declaringType == typeof(BsonValue) || declaringType.IsSubclassOf(typeof(BsonValue))) && + arguments.Count == 1 && + arguments[0] is var index && + (index.Type == typeof(int) || index.Type == typeof(string)); + } + + bool IsDictionaryIndexer() + { + return + collectionExpression.Type.ImplementsDictionaryInterface(out var keyType, out var valueType) && + arguments.Count == 1 && + arguments[0] is var indexExpression && + indexExpression.Type == keyType && + indexer.PropertyType == valueType; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs new file mode 100644 index 00000000000..df044fc4060 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitLambda.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitLambda(Expression node) + { + if (IsNotKnown(node)) + { + var ignoreNodeSerializer = IgnoreNodeSerializer.Create(node.Type); + AddNodeSerializer(node, ignoreNodeSerializer); + } + + return base.VisitLambda(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs new file mode 100644 index 00000000000..dcac29e1792 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitListInit.cs @@ -0,0 +1,39 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitListInit(ListInitExpression node) + { + var newExpression = node.NewExpression; + var initializers = node.Initializers; + + DeduceListInitSerializers(); + base.VisitListInit(node); + DeduceListInitSerializers(); + + return node; + + void DeduceListInitSerializers() + { + // TODO: handle initializers? + DeduceSerializers(node, newExpression); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs new file mode 100644 index 00000000000..e5fc66f555c --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMember.cs @@ -0,0 +1,234 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitMember(MemberExpression node) + { + IBsonSerializer containingSerializer; + var member = node.Member; + var declaringType = member.DeclaringType; + var memberName = member.Name; + + base.VisitMember(node); + + if (IsNotKnown(node)) + { + var containingExpression = node.Expression; + if (IsKnown(containingExpression, out containingSerializer)) + { + // TODO: are there are other cases that still need to be handled? + + var resultSerializer = node.Member switch + { + _ when declaringType == typeof(BsonValue) => GetBsonValuePropertySerializer(), + _ when IsCollectionCountOrLengthProperty() => GetCollectionCountOrLengthPropertySerializer(), + _ when declaringType == typeof(DateTime) => GetDateTimePropertySerializer(), + _ when declaringType.IsConstructedGenericType && declaringType.GetGenericTypeDefinition() == typeof(Dictionary<,>) => GetDictionaryPropertySerializer(), + _ when declaringType.IsConstructedGenericType && declaringType.GetGenericTypeDefinition() == typeof(IDictionary<,>) => GetIDictionaryPropertySerializer(), + _ when declaringType.IsNullable() => GetNullablePropertySerializer(), + _ when declaringType.IsTupleOrValueTuple() => GetTupleOrValueTuplePropertySerializer(), + _ => GetPropertySerializer() + }; + + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetBsonValuePropertySerializer() + { + return memberName switch + { + "AsBoolean" => BooleanSerializer.Instance, + "AsBsonArray" => BsonArraySerializer.Instance, + "AsBsonBinaryData" => BsonBinaryDataSerializer.Instance, + "AsBsonDateTime" => BsonDateTimeSerializer.Instance, + "AsBsonDocument" => BsonDocumentSerializer.Instance, + "AsBsonJavaScript" => BsonJavaScriptSerializer.Instance, + "AsBsonJavaScriptWithScope" => BsonJavaScriptWithScopeSerializer.Instance, + "AsBsonMaxKey" => BsonMaxKeySerializer.Instance, + "AsBsonMinKey" => BsonMinKeySerializer.Instance, + "AsBsonNull" => BsonNullSerializer.Instance, + "AsBsonRegularExpression" => BsonRegularExpressionSerializer.Instance, + "AsBsonSymbol" => BsonSymbolSerializer.Instance, + "AsBsonTimestamp" => BsonTimestampSerializer.Instance, + "AsBsonUndefined" => BsonUndefinedSerializer.Instance, + "AsBsonValue" => BsonValueSerializer.Instance, + "AsByteArray" => ByteArraySerializer.Instance, + "AsDecimal128" => Decimal128Serializer.Instance, + "AsDecimal" => DecimalSerializer.Instance, + "AsDouble" => DoubleSerializer.Instance, + "AsGuid" => GuidSerializer.StandardInstance, + "AsInt32" => Int32Serializer.Instance, + "AsInt64" => Int64Serializer.Instance, + "AsLocalTime" => DateTimeSerializer.LocalInstance, + "AsNullableBoolean" => NullableSerializer.NullableBooleanInstance, + "AsNullableDecimal128" => NullableSerializer.NullableDecimal128Instance, + "AsNullableDecimal" => NullableSerializer.NullableDecimalInstance, + "AsNullableDouble" => NullableSerializer.NullableDoubleInstance, + "AsNullableGuid" => NullableSerializer.NullableStandardGuidInstance, + "AsNullableInt32" => NullableSerializer.NullableInt32Instance, + "AsNullableInt64" => NullableSerializer.NullableInt64Instance, + "AsNullableLocalTime" => NullableSerializer.NullableLocalDateTimeInstance, + "AsNullableObjectId" => NullableSerializer.NullableObjectIdInstance, + "AsNullableUniversalTime" => NullableSerializer.NullableUtcDateTimeInstance, + "AsObjectId" => ObjectIdSerializer.Instance, + "AsRegex" => RegexSerializer.RegularExpressionInstance, + "AsString" => StringSerializer.Instance, + "AsUniversalTime" => DateTimeSerializer.UtcInstance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetCollectionCountOrLengthPropertySerializer() + { + return Int32Serializer.Instance; + } + + IBsonSerializer GetDateTimePropertySerializer() + { + return memberName switch + { + "Date" => DateTimeSerializer.Instance, + "Day" => Int32Serializer.Instance, + "DayOfWeek" => new EnumSerializer(BsonType.Int32), + "DayOfYear" => Int32Serializer.Instance, + "Hour" => Int32Serializer.Instance, + "Millisecond" => Int32Serializer.Instance, + "Minute" => Int32Serializer.Instance, + "Month" => Int32Serializer.Instance, + "Now" => DateTimeSerializer.Instance, + "Second" => Int32Serializer.Instance, + "Ticks" => Int64Serializer.Instance, + "TimeOfDay" => new TimeSpanSerializer(BsonType.Int64, TimeSpanUnits.Milliseconds), + "Today" => DateTimeSerializer.Instance, + "UtcNow" => DateTimeSerializer.Instance, + "Year" => Int32Serializer.Instance, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetDictionaryPropertySerializer() + { + if (containingSerializer.GetValueSerializerIfWrapped() is not IBsonDictionarySerializer dictionarySerializer) + { + throw new ExpressionNotSupportedException(node, because: "dictionary serializer does not implement IBsonDictionarySerializer"); + } + + var keySerializer = dictionarySerializer.KeySerializer; + var valueSerializer = dictionarySerializer.ValueSerializer; + + return memberName switch + { + "Keys" => DictionaryKeyCollectionSerializer.Create(keySerializer, valueSerializer), + "Values" => DictionaryValueCollectionSerializer.Create(keySerializer, valueSerializer), + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetIDictionaryPropertySerializer() + { + if (containingSerializer is not IBsonDictionarySerializer dictionarySerializer) + { + throw new ExpressionNotSupportedException(node, because: "IDictionarySerializer does not implement IBsonDictionarySerializer"); + } + + var keySerializer = dictionarySerializer.KeySerializer; + var valueSerializer = dictionarySerializer.ValueSerializer; + + return memberName switch + { + "Keys" => ICollectionSerializer.Create(keySerializer), + "Values" => ICollectionSerializer.Create(valueSerializer), + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetNullablePropertySerializer() + { + return memberName switch + { + "HasValue" => BooleanSerializer.Instance, + "Value" => (containingSerializer as INullableSerializer)?.ValueSerializer, + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + IBsonSerializer GetPropertySerializer() + { + if (containingSerializer is not IBsonDocumentSerializer documentSerializer) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonDocumentSerializer)} interface"); + } + + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + // TODO: return UnknowableSerializer??? + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not support a member named: {memberName}"); + } + + return memberSerializationInfo.Serializer; + } + + IBsonSerializer GetTupleOrValueTuplePropertySerializer() + { + if (containingSerializer is not IBsonTupleSerializer tupleSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {containingSerializer.GetType()} does not implement the {nameof(IBsonTupleSerializer)} interface"); + } + + return memberName switch + { + "Item1" => tupleSerializer.GetItemSerializer(1), + "Item2" => tupleSerializer.GetItemSerializer(2), + "Item3" => tupleSerializer.GetItemSerializer(3), + "Item4" => tupleSerializer.GetItemSerializer(4), + "Item5" => tupleSerializer.GetItemSerializer(5), + "Item6" => tupleSerializer.GetItemSerializer(6), + "Item7" => tupleSerializer.GetItemSerializer(7), + "Rest" => tupleSerializer.GetItemSerializer(8), + // TODO: return UnknowableSerializer??? + _ => throw new ExpressionNotSupportedException(node, because: $"Unexpected member name: {memberName}") + }; + } + + bool IsCollectionCountOrLengthProperty() + { + return + (declaringType.ImplementsInterface(typeof(IEnumerable)) || declaringType == typeof(BitArray)) && + node.Type == typeof(int) && + (member.Name == "Count" || member.Name == "Length"); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs new file mode 100644 index 00000000000..b90992d9be4 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMemberInit.cs @@ -0,0 +1,97 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitMemberInit(MemberInitExpression node) + { + if (IsKnown(node, out var nodeSerializer)) + { + var newExpression = node.NewExpression; + if (newExpression != null) + { + if (IsNotKnown(newExpression)) + { + AddNodeSerializer(newExpression, nodeSerializer); + } + } + + if (node.Bindings.Count > 0) + { + if (nodeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new ExpressionNotSupportedException(node, because: $"serializer type {nodeSerializer.GetType()} does not implement IBsonDocumentSerializer interface"); + } + + foreach (var binding in node.Bindings) + { + if (binding is MemberAssignment memberAssignment) + { + if (IsNotKnown(memberAssignment.Expression)) + { + var member = memberAssignment.Member; + var memberName = member.Name; + if (!documentSerializer.TryGetMemberSerializationInfo(memberName, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(node, because: $"type {member.DeclaringType} does not have a member named: {memberName}"); + } + var expressionSerializer = memberSerializationInfo.Serializer; + + if (expressionSerializer.ValueType != memberAssignment.Expression.Type && + expressionSerializer.ValueType.IsAssignableFrom(memberAssignment.Expression.Type)) + { + expressionSerializer = expressionSerializer.GetDerivedTypeSerializer(memberAssignment.Expression.Type); + } + + // member = expression => expression: memberSerializer (or derivedTypeSerializer) + AddNodeSerializer(memberAssignment.Expression, expressionSerializer); + } + } + } + } + } + + base.VisitMemberInit(node); + + if (IsNotKnown(node)) + { + var resultSerializer = GetResultSerializer(); + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + + return node; + + IBsonSerializer GetResultSerializer() + { + if (node.Type == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + var newExpression = node.NewExpression; + var bindings = node.Bindings; + return CreateNewExpressionSerializer(node, newExpression, bindings); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs new file mode 100644 index 00000000000..804f47c6ee8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs @@ -0,0 +1,2539 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + private static readonly IReadOnlyMethodInfoSet __averageOrMedianOrPercentileOverloads = MethodInfoSet.Create( + [ + EnumerableOrQueryableMethod.AverageOverloads, + MongoEnumerableMethod.MedianOverloads, + MongoEnumerableMethod.PercentileOverloads, + WindowMethod.PercentileOverloads + ]); + + private static readonly IReadOnlyMethodInfoSet __averageOrMedianOrPercentileWithSelectorOverloads = MethodInfoSet.Create( + [ + EnumerableOrQueryableMethod.AverageWithSelectorOverloads, + MongoEnumerableMethod.MedianWithSelectorOverloads, + MongoEnumerableMethod.PercentileWithSelectorOverloads, + WindowMethod.PercentileOverloads + ]); + + private static readonly IReadOnlyMethodInfoSet __whereOverloads = MethodInfoSet.Create( + [ + EnumerableOrQueryableMethod.Where, + [MongoEnumerableMethod.WhereWithLimit] + ]); + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var method = node.Method; + var arguments = node.Arguments; + + DeduceMethodCallSerializers(); + base.VisitMethodCall(node); + DeduceMethodCallSerializers(); + + return node; + + void DeduceMethodCallSerializers() + { + switch (node.Method.Name) + { + case "Abs": DeduceAbsMethodSerializers(); break; + case "Add": DeduceAddMethodSerializers(); break; + case "AddDays": DeduceAddDaysMethodSerializers(); break; + case "AddHours": DeduceAddHoursMethodSerializers(); break; + case "AddMilliseconds": DeduceAddMillisecondsMethodSerializers(); break; + case "AddMinutes": DeduceAddMinutesMethodSerializers(); break; + case "AddMonths": DeduceAddMonthsMethodSerializers(); break; + case "AddQuarters": DeduceAddQuartersMethodSerializers(); break; + case "AddSeconds": DeduceAddSecondsMethodSerializers(); break; + case "AddTicks": DeduceAddTicksMethodSerializers(); break; + case "AddWeeks": DeduceAddWeeksMethodSerializers(); break; + case "AddYears": DeduceAddYearsMethodSerializers(); break; + case "Aggregate": DeduceAggregateMethodSerializers(); break; + case "All": DeduceAllMethodSerializers(); break; + case "Any": DeduceAnyMethodSerializers(); break; + case "AppendStage": DeduceAppendStageMethodSerializers(); break; + case "As": DeduceAsMethodSerializers(); break; + case "AsQueryable": DeduceAsQueryableMethodSerializers(); break; + case "Concat": DeduceConcatMethodSerializers(); break; + case "Constant": DeduceConstantMethodSerializers(); break; + case "Contains": DeduceContainsMethodSerializers(); break; + case "ContainsKey": DeduceContainsKeyMethodSerializers(); break; + case "ContainsValue": DeduceContainsValueMethodSerializers(); break; + case "Convert": DeduceConvertMethodSerializers(); break; + case "Create": DeduceCreateMethodSerializers(); break; + case "DefaultIfEmpty": DeduceDefaultIfEmptyMethodSerializers(); break; + case "DegreesToRadians": DeduceDegreesToRadiansMethodSerializers(); break; + case "Distinct": DeduceDistinctMethodSerializers(); break; + case "Documents": DeduceDocumentsMethodSerializers(); break; + case "Equals": DeduceEqualsMethodSerializers(); break; + case "Except": DeduceExceptMethodSerializers(); break; + case "Exists": DeduceExistsMethodSerializers(); break; + case "Exp": DeduceExpMethodSerializers(); break; + case "Field": DeduceFieldMethodSerializers(); break; + case "get_Item": DeduceGetItemMethodSerializers(); break; + case "get_Chars": DeduceGetCharsMethodSerializers(); break; + case "GroupBy": DeduceGroupByMethodSerializers(); break; + case "GroupJoin": DeduceGroupJoinMethodSerializers(); break; + case "Inject": DeduceInjectMethodSerializers(); break; + case "Intersect": DeduceIntersectMethodSerializers(); break; + case "IsMatch": DeduceIsMatchMethodSerializers(); break; + case "IsSubsetOf": DeduceIsSubsetOfMethodSerializers(); break; + case "Join": DeduceJoinMethodSerializers(); break; + case "Lookup": DeduceLookupMethodSerializers(); break; + case "OfType": DeduceOfTypeMethodSerializers(); break; + case "Parse": DeduceParseMethodSerializers(); break; + case "Pow": DeducePowMethodSerializers(); break; + case "RadiansToDegrees": DeduceRadiansToDegreesMethodSerializers(); break; + case "Range": DeduceRangeMethodSerializers(); break; + case "Repeat": DeduceRepeatMethodSerializers(); break; + case "Reverse": DeduceReverseMethodSerializers(); break; + case "Round": DeduceRoundMethodSerializers(); break; + case "Select": DeduceSelectMethodSerializers(); break; + case "SelectMany": DeduceSelectManySerializers(); break; + case "SequenceEqual": DeduceSequenceEqualMethodSerializers(); break; + case "SetEquals": DeduceSetEqualsMethodSerializers(); break; + case "SetWindowFields": DeduceSetWindowFieldsMethodSerializers(); break; + case "Shift": DeduceShiftMethodSerializers(); break; + case "Split": DeduceSplitMethodSerializers(); break; + case "Sqrt": DeduceSqrtMethodSerializers(); break; + case "StringIn": DeduceStringInMethodSerializers(); break; + case "StrLenBytes": DeduceStrLenBytesMethodSerializers(); break; + case "Subtract": DeduceSubtractMethodSerializers(); break; + case "Sum": DeduceSumMethodSerializers(); break; + case "ToArray": DeduceToArrayMethodSerializers(); break; + case "ToList": DeduceToListSerializers(); break; + case "ToString": DeduceToStringSerializers(); break; + case "Truncate": DeduceTruncateSerializers(); break; + case "Union": DeduceUnionSerializers(); break; + case "Week": DeduceWeekSerializers(); break; + case "Where": DeduceWhereSerializers(); break; + case "Zip": DeduceZipSerializers(); break; + + case "Acos": + case "Acosh": + case "Asin": + case "Asinh": + case "Atan": + case "Atanh": + case "Atan2": + case "Cos": + case "Cosh": + case "Sin": + case "Sinh": + case "Tan": + case "Tanh": + DeduceTrigonometricMethodSerializers(); + break; + + case "AllElements": + case "AllMatchingElements": + case "FirstMatchingElement": + DeduceMatchingElementsMethodSerializers(); + break; + + case "Append": + case "Prepend": + DeduceAppendOrPrependMethodSerializers(); + break; + + case "Average": + case "Median": + case "Percentile": + DeduceAverageOrMedianOrPercentileMethodSerializers(); + break; + + case "Bottom": + case "BottomN": + case "FirstN": + case "LastN": + case "MaxN": + case "MinN": + case "Top": + case "TopN": + DeducePickMethodSerializers(); + break; + + case "Ceiling": + case "Floor": + DeduceCeilingOrFloorMethodSerializers(); + break; + + case "Compare": + case "CompareTo": + DeduceCompareOrCompareToMethodSerializers(); + break; + + case "Count": + case "LongCount": + DeduceCountMethodSerializers(); + break; + + case "ElementAt": + case "ElementAtOrDefault": + DeduceElementAtMethodSerializers(); + break; + + case "EndsWith": + case "StartsWith": + DeduceEndsWithOrStartsWithMethodSerializers(); + break; + + case "First": + case "FirstOrDefault": + case "Last": + case "LastOrDefault": + case "Single": + case "SingleOrDefault": + DeduceFirstOrLastOrSingleMethodsSerializers(); + break; + + case "IndexOf": + case "IndexOfBytes": + DeduceIndexOfMethodSerializers(); + break; + + case "IsMissing": + case "IsNullOrMissing": + DeduceIsMissingOrIsNullOrMissingMethodSerializers(); + break; + + case "IsNullOrEmpty": + case "IsNullOrWhiteSpace": + DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers(); + break; + + case "Ln": + case "Log": + case "Log10": + DeduceLogMethodSerializers(); + break; + + case "Max": + case "Min": + DeduceMaxOrMinMethodSerializers(); + break; + + case "OrderBy": + case "OrderByDescending": + case "ThenBy": + case "ThenByDescending": + DeduceOrderByMethodSerializers(); + break; + + case "Skip": + case "SkipWhile": + case "Take": + case "TakeWhile": + DeduceSkipOrTakeMethodSerializers(); + break; + + case "StandardDeviationPopulation": + case "StandardDeviationSample": + DeduceStandardDeviationMethodSerializers(); + break; + + case "Substring": + case "SubstrBytes": + DeduceSubstringMethodSerializers(); + break; + + case "ToLower": + case "ToLowerInvariant": + case "ToUpper": + case "ToUpperInvariant": + DeduceToLowerOrToUpperSerializers(); + break; + + default: + DeduceUnknownMethodSerializer(); + break; + } + } + + void DeduceAbsMethodSerializers() + { + if (method.IsOneOf(MathMethod.AbsOverloads)) + { + var valueExpression = arguments[0]; + DeduceSerializers(node, valueExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.Add, DateTimeMethod.AddWithTimezone, DateTimeMethod.AddWithUnit, DateTimeMethod.AddWithUnitAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddDaysMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddDays, DateTimeMethod.AddDaysWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddHoursMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddHours, DateTimeMethod.AddHoursWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMillisecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMilliseconds, DateTimeMethod.AddMillisecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMinutesMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMinutes, DateTimeMethod.AddMinutesWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddMonthsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddMonths, DateTimeMethod.AddMonthsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddQuartersMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddQuarters, DateTimeMethod.AddQuartersWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddSecondsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddSeconds, DateTimeMethod.AddSecondsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddTicksMethodSerializers() + { + if (method.Is(DateTimeMethod.AddTicks)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddWeeksMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddWeeks, DateTimeMethod.AddWeeksWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAddYearsMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.AddYears, DateTimeMethod.AddYearsWithTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAggregateMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithFunc)) + { + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(funcAccumulatorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(funcLambda.Body, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedAndFunc)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(node, funcLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedFuncAndResultSelector)) + { + var seedExpression = arguments[1]; + var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var funcAccumulatorParameter = funcLambda.Parameters[0]; + var funcSourceItemParameter = funcLambda.Parameters[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; + + DeduceSerializers(seedExpression, funcLambda.Body); + DeduceSerializers(funcAccumulatorParameter, funcLambda.Body); + DeduceItemAndCollectionSerializers(funcSourceItemParameter, sourceExpression); + DeduceSerializers(resultSelectorAccumulatorParameter, funcLambda.Body); + DeduceSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAllMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.AllWithPredicate, QueryableMethod.AllWithPredicate)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAnyMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyWithPredicate)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendOrPrependMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.AppendOrPrepend)) + { + var sourceExpression = arguments[0]; + var elementExpression = arguments[1]; + + DeduceItemAndCollectionSerializers(elementExpression, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsMethodSerializers() + { + if (method.Is(MongoQueryableMethod.As)) + { + if (IsNotKnown(node)) + { + var resultSerializerExpression = arguments[1]; + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + if (resultItemSerializer == null) + { + var resultItemType = method.GetGenericArguments()[1]; + resultItemSerializer = BsonSerializer.LookupSerializer(resultItemType); + } + + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAppendStageMethodSerializers() + { + if (method.Is(MongoQueryableMethod.AppendStage)) + { + if (IsNotKnown(node)) + { + var sourceExpression = arguments[0]; + var stageExpression = arguments[1]; + var resultSerializerExpression = arguments[2]; + + if (stageExpression is not ConstantExpression stageConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "stage argument must be a constant"); + } + var stageDefinition = (IPipelineStageDefinition)stageConstantExpression.Value; + + if (resultSerializerExpression is not ConstantExpression resultSerializerConstantExpression) + { + throw new ExpressionNotSupportedException(node, because: "resultSerializer argument must be a constant"); + } + var resultItemSerializer = (IBsonSerializer)resultSerializerConstantExpression.Value; + + if (resultItemSerializer == null && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var serializerRegistry = BsonSerializer.SerializerRegistry; // TODO: get correct registry + var translationOptions = new ExpressionTranslationOptions(); // TODO: get correct translation options + var renderedStage = stageDefinition.Render(sourceItemSerializer, serializerRegistry, translationOptions); + resultItemSerializer = renderedStage.OutputSerializer; + } + + if (resultItemSerializer != null) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAsQueryableMethodSerializers() + { + if (method.Is(QueryableMethod.AsQueryable)) + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultSerializer = NestedAsQueryableSerializer.Create(sourceItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceAverageOrMedianOrPercentileMethodSerializers() + { + if (method.IsOneOf(__averageOrMedianOrPercentileOverloads)) + { + if (method.IsOneOf(__averageOrMedianOrPercentileWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + } + + if (IsNotKnown(node)) + { + var nodeSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCeilingOrFloorMethodSerializers() + { + if (method.IsOneOf(MathMethod.CeilingWithDecimal, MathMethod.CeilingWithDouble, MathMethod.FloorWithDecimal, MathMethod.FloorWithDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCompareOrCompareToMethodSerializers() + { + if (method.IsStaticCompareMethod() || + method.IsInstanceCompareToMethod() || + method.IsOneOf(StringMethod.CompareOverloads)) + { + var valueExpression = method.IsStatic ? arguments[0] : node.Object; + var comparandExpression = method.IsStatic ? arguments[1] : arguments[0]; + DeduceSerializers(valueExpression, comparandExpression); + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceConcatMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Concat, QueryableMethod.Concat)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(firstExpression, secondExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else if (method.IsOneOf(StringMethod.ConcatOverloads)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceConstantMethodSerializers() + { + if (method.IsOneOf(MqlMethod.ConstantWithRepresentation, MqlMethod.ConstantWithSerializer)) + { + var valueExpression = arguments[0]; + IBsonSerializer serializer = null; + + if (IsNotKnown(node) || IsNotKnown(valueExpression)) + { + if (method.Is(MqlMethod.ConstantWithRepresentation)) + { + var representationExpression = arguments[1]; + + var representation = representationExpression.GetConstantValue(node); + var defaultSerializer = BsonSerializer.LookupSerializer(valueExpression.Type); // TODO: don't use BsonSerializer + if (defaultSerializer is IRepresentationConfigurable representationConfigurableSerializer) + { + serializer = representationConfigurableSerializer.WithRepresentation(representation); + } + } + else if (method.Is(MqlMethod.ConstantWithSerializer)) + { + var serializerExpression = arguments[1]; + serializer = serializerExpression.GetConstantValue(node); + } + } + + DeduceSerializer(valueExpression, serializer); + DeduceSerializer(node, serializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsKeyMethodSerializers() + { + if (IsDictionaryContainsKeyExpression(out var keyExpression)) + { + var dictionaryExpression = node.Object; + if (IsNotKnown(keyExpression) && IsKnown(dictionaryExpression, out var dictionarySerializer)) + { + var keySerializer = (dictionarySerializer as IBsonDictionarySerializer)?.KeySerializer; + AddNodeSerializer(keyExpression, keySerializer); + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsMethodSerializers() + { + if (method.IsOneOf(StringMethod.ContainsOverloads)) + { + DeduceReturnsBooleanSerializer(); + } + else if (EnumerableMethod.IsContainsMethod(node, out var collectionExpression, out var itemExpression)) + { + DeduceCollectionAndItemSerializers(collectionExpression, itemExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceContainsValueMethodSerializers() + { + if (IsContainsValueInstanceMethod(out var collectionExpression, out var valueExpression)) + { + if (IsNotKnown(valueExpression) && + IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonDictionarySerializer dictionarySerializer) + { + var valueSerializer = dictionarySerializer.ValueSerializer; + AddNodeSerializer(valueExpression, valueSerializer); + } + } + + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsContainsValueInstanceMethod(out Expression collectionExpression, out Expression valueExpression) + { + if (method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "ContainsValue" && + method.GetParameters() is var parameters && + parameters.Length == 1) + { + collectionExpression = node.Object; + valueExpression = arguments[0]; + return true; + } + + collectionExpression = null; + valueExpression = null; + return false; + } + } + + void DeduceConvertMethodSerializers() + { + if (method.Is(MqlMethod.Convert)) + { + if (IsNotKnown(node)) + { + var toType = method.GetGenericArguments()[1]; + var resultSerializer = GetResultSerializer(node, toType); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + + static IBsonSerializer GetResultSerializer(Expression expression, Type toType) + { + // TODO: should we use StandardSerializers at least for the subset of types where it would return the correct serializer? + + var isNullable = toType.IsNullable(); + var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; + + var valueSerializer = (IBsonSerializer)(Type.GetTypeCode(valueType) switch + { + TypeCode.Boolean => BooleanSerializer.Instance, + TypeCode.Byte => ByteSerializer.Instance, + TypeCode.Char => StringSerializer.Instance, + TypeCode.DateTime => DateTimeSerializer.Instance, + TypeCode.Decimal => DecimalSerializer.Instance, + TypeCode.Double => DoubleSerializer.Instance, + TypeCode.Int16 => Int16Serializer.Instance, + TypeCode.Int32 => Int32Serializer.Instance, + TypeCode.Int64 => Int64Serializer.Instance, + TypeCode.SByte => SByteSerializer.Instance, + TypeCode.Single => SingleSerializer.Instance, + TypeCode.String => StringSerializer.Instance, + TypeCode.UInt16 => UInt16Serializer.Instance, + TypeCode.UInt32 => Int32Serializer.Instance, + TypeCode.UInt64 => UInt64Serializer.Instance, + + _ when valueType == typeof(byte[]) => ByteArraySerializer.Instance, + _ when valueType == typeof(BsonBinaryData) => BsonBinaryDataSerializer.Instance, + _ when valueType == typeof(Decimal128) => Decimal128Serializer.Instance, + _ when valueType == typeof(Guid) => GuidSerializer.StandardInstance, + _ when valueType == typeof(ObjectId) => ObjectIdSerializer.Instance, + + _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") + }); + + return isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer; + } + } + + void DeduceCreateMethodSerializers() + { +#if NET6_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (method.Is(KeyValuePairMethod.Create)) + { + if (IsAnyNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + var keyExpression = arguments[0]; + var valueExpression = arguments[1]; + + if (nodeSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + DeduceSerializer(keyExpression, keySerializer); + DeduceSerializer(valueExpression, valueSerializer); + } + } + + if (IsNotKnown(node) && AreAllKnown(arguments, out var argumentSerializers)) + { + var keySerializer = argumentSerializers[0]; + var valueSerializer = argumentSerializers[1]; + var keyValuePairSerializer = KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + AddNodeSerializer(node, keyValuePairSerializer); + } + } + else + #endif + if (method.IsOneOf(TupleOrValueTupleMethod.CreateOverloads)) + { + if (IsAnyNotKnown(arguments) && IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IBsonTupleSerializer tupleSerializer) + { + for (var i = 1; i <= arguments.Count; i++) + { + var argumentExpression = arguments[i]; + if (IsNotKnown(argumentExpression)) + { + var itemSerializer = tupleSerializer.GetItemSerializer(i); + if (i == 8) + { + itemSerializer = (itemSerializer as IBsonTupleSerializer)?.GetItemSerializer(1); + } + AddNodeSerializer(argumentExpression, itemSerializer); + } + } + } + } + + if (IsNotKnown(node) && AreAllKnown(arguments, out var argumentSerializers)) + { + var tupleType = method.ReturnType; + + if (arguments.Count == 8) + { + var item8Expression = arguments[7]; + var item8Type = item8Expression.Type; + var item8Serializer = argumentSerializers[7]; + var restTupleType = (tupleType.IsTuple() ? typeof(Tuple<>) : typeof(ValueTuple<>)).MakeGenericType(item8Type); + var restSerializer = TupleOrValueTupleSerializer.Create(restTupleType, [item8Serializer]); + argumentSerializers = argumentSerializers.Take(7).Append(restSerializer).ToArray(); + } + + var tupleSerializer = TupleOrValueTupleSerializer.Create(tupleType, argumentSerializers); + AddNodeSerializer(node, tupleSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceCountMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.CountOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.CountWithPredicateOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDefaultIfEmptyMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.DefaultIfEmpty, EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmpty, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableMethod.DefaultIfEmptyWithDefaultValue, QueryableMethod.DefaultIfEmptyWithDefaultValue)) + { + var defaultValueExpression = arguments[1]; + DeduceItemAndCollectionSerializers(defaultValueExpression, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDegreesToRadiansMethodSerializers() + { + if (method.Is(MongoDBMathMethod.DegreesToRadians)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDistinctMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Distinct, QueryableMethod.Distinct)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceDocumentsMethodSerializers() + { + if (method.IsOneOf(MongoQueryableMethod.Documents, MongoQueryableMethod.DocumentsWithSerializer)) + { + if (IsNotKnown(node)) + { + IBsonSerializer documentSerializer; + if (method.Is(MongoQueryableMethod.DocumentsWithSerializer)) + { + var documentSerializerExpression = arguments[2]; + documentSerializer = documentSerializerExpression.GetConstantValue(node); + } + else + { + var documentsParameter = method.GetParameters()[1]; + var documentType = documentsParameter.ParameterType.GetElementType(); + documentSerializer = BsonSerializer.LookupSerializer(documentType); // TODO: don't use static registry + } + + var nodeSerializer = IQueryableSerializer.Create(documentSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceElementAtMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.ElementAtOverloads)) + { + var sourceExpression = arguments[0]; + DeduceItemAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEqualsMethodSerializers() + { + if (IsEqualsReturningBooleanMethod(out var expression1, out var expression2)) + { + DeduceSerializers(expression1, expression2); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsEqualsReturningBooleanMethod(out Expression expression1, out Expression expression2) + { + if (method.Name == "Equals" && + method.ReturnType == typeof(bool) && + method.IsPublic) + { + if (method.IsStatic && + arguments.Count == 2) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + + if (!method.IsStatic && + arguments.Count == 1) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.EqualsWithComparisonType)) + { + expression1 = node.Object; + expression2 = arguments[0]; + return true; + } + + if (method.Is(StringMethod.StaticEqualsWithComparisonType)) + { + expression1 = arguments[0]; + expression2 = arguments[1]; + return true; + } + } + + expression1 = null; + expression2 = null; + return false; + } + } + + void DeduceExceptMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Except, QueryableMethod.Except)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + DeduceCollectionAndCollectionSerializers(secondExpression, firstExpression); + DeduceCollectionAndCollectionSerializers(node, firstExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExistsMethodSerializers() + { + if (method.Is(ArrayMethod.Exists) || ListMethod.IsExistsMethod(method)) + { + var collectionExpression = method.IsStatic ? arguments[0] : node.Object; + var predicateExpression = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, method.IsStatic ? arguments[1] : arguments[0]); + var predicateParameter = predicateExpression.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, collectionExpression); + DeduceReturnsBooleanSerializer(); + } + else if (method.Is(MqlMethod.Exists)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceExpMethodSerializers() + { + if (method.Is(MathMethod.Exp)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFieldMethodSerializers() + { + if (method.Is(MqlMethod.Field)) + { + if (IsNotKnown(node)) + { + var fieldSerializerExpression = arguments[2]; + var fieldSerializer = fieldSerializerExpression.GetConstantValue(node); + if (fieldSerializer == null) + { + throw new ExpressionNotSupportedException(node, because: "fieldSerializer is null"); + } + + AddNodeSerializer(node, fieldSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceFirstOrLastOrSingleMethodsSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOrSingleOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOrSingleWithPredicateOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGetItemMethodSerializers() + { + if (IsNotKnown(node)) + { + if (BsonValueMethod.IsGetItemWithIntMethod(method) || BsonValueMethod.IsGetItemWithStringMethod(method)) + { + AddNodeSerializer(node, BsonValueSerializer.Instance); + } + else if (IsInstanceGetItemMethod(out var collectionExpression, out var indexExpression)) + { + if (IsKnown(collectionExpression, out var collectionSerializer)) + { + if (collectionSerializer is IBsonArraySerializer arraySerializer && + indexExpression.Type == typeof(int) && + arraySerializer.GetItemSerializer() is var itemSerializer && + itemSerializer.ValueType == method.ReturnType) + { + AddNodeSerializer(node, itemSerializer); + } + else if ( + collectionSerializer is IBsonDictionarySerializer dictionarySerializer && + dictionarySerializer.KeySerializer is var keySerializer && + dictionarySerializer.ValueSerializer is var valueSerializer && + keySerializer.ValueType == indexExpression.Type && + valueSerializer.ValueType == method.ReturnType) + { + AddNodeSerializer(node, valueSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsInstanceGetItemMethod(out Expression collectionExpression, out Expression indexExpression) + { + if (method.IsStatic == false && + method.Name == "get_Item") + { + collectionExpression = node.Object; + indexExpression = arguments[0]; + return true; + } + + collectionExpression = null; + indexExpression = null; + return false; + } + } + + void DeduceGetCharsMethodSerializers() + { + if (method.Is(StringMethod.GetChars)) + { + DeduceCharSerializer(node); + } + + DeduceUnknowableSerializer(node); + } + + void DeduceGroupByMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByOverloads)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + + if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelector)) + { + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsItemSerializerKnown(sourceExpression, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorAndElementSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + if (IsNotKnown(node) && IsKnown(keySelectorLambda.Body, out var keySerializer) && IsKnown(elementSelectorLambda.Body, out var elementSerializer)) + { + var groupingSerializer = IGroupingSerializer.Create(keySerializer, elementSerializer); + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, groupingSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorAndResultSelector)) + { + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndCollectionSerializers(resultSelectorElementsParameter, sourceExpression); + DeduceResultSerializer(resultSelectorLambda.Body); + } + else if (method.IsOneOf(EnumerableOrQueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector)) + { + var elementSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var elementSelectorParameter = elementSelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var resultSelectorKeyParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorElementsParameter = resultSelectorLambda.Parameters[1]; + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceItemAndCollectionSerializers(elementSelectorParameter, sourceExpression); + DeduceSerializers(resultSelectorKeyParameter, keySelectorLambda.Body); + DeduceCollectionAndItemSerializers(resultSelectorElementsParameter, elementSelectorLambda.Body); + DeduceResultSerializer(resultSelectorLambda.Body); + } + + void DeduceResultSerializer(Expression resultExpression) + { + if (IsNotKnown(node) && IsKnown(resultExpression, out var resultSerializer)) + { + var nodeSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceGroupJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.GroupJoin, QueryableMethod.GroupJoin)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceCollectionAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIndexOfMethodSerializers() + { + if (method.IsOneOf(StringMethod.IndexOfOverloads)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceInjectMethodSerializers() + { + if (method.Is(LinqExtensionsMethod.Inject)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIntersectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Intersect, QueryableMethod.Intersect)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMatchMethodSerializers() + { + if (method.Is(RegexMethod.StaticIsMatch)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsMissingOrIsNullOrMissingMethodSerializers() + { + if (method.IsOneOf(MqlMethod.IsMissing, MqlMethod.IsNullOrMissing)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsSubsetOfMethodSerializers() + { + if (IsSubsetOfMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + + static bool IsSubsetOfMethod(MethodInfo method) + { + var declaringType = method.DeclaringType; + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic == false && + method.ReturnType == typeof(bool) && + method.Name == "IsSubsetOf" && + parameters.Length == 1 && + parameters[0] is var otherParameter && + declaringType.ImplementsIEnumerable(out var declaringTypeItemType) && + otherParameter.ParameterType.ImplementsIEnumerable(out var otherTypeItemType) && + otherTypeItemType == declaringTypeItemType; + } + } + + void DeduceJoinMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Join, QueryableMethod.Join)) + { + var outerExpression = arguments[0]; + var innerExpression = arguments[1]; + var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single(); + var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single(); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorInnerItemsParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression); + DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression); + DeduceItemAndCollectionSerializers(resultSelectorInnerItemsParameter, innerExpression); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers() + { + if (method.IsOneOf(StringMethod.IsNullOrEmpty, StringMethod.IsNullOrWhiteSpace)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLogMethodSerializers() + { + if (method.IsOneOf(MathMethod.LogOverloads)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceLookupMethodSerializers() + { + if (method.IsOneOf(MongoQueryableMethod.LookupOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(documentsLambda.Body, out var documentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, documentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(foreignFieldLambdaParameter, documentsLambda.Body); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaForeignQueryableParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineDocumentSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineDocumentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithDocumentsAndPipeline)) + { + var documentsLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var documentsLambdaParameter = documentsLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaSourceParameter = pipelineLambda.Parameters[0]; + var pipelineLambdaQueryableDocumentParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(documentsLambdaParameter, sourceExpression); + DeduceItemAndCollectionSerializers(pipelineLambdaSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(pipelineLambdaQueryableDocumentParameter, documentsLambda.Body); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + + if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, foreignDocumentSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline)) + { + var fromExpression = arguments[1]; + var fromCollection = fromExpression.GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var localFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var localFieldLambdaParameter = localFieldLambda.Parameters.Single(); + var foreignFieldLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); + var foreignFieldLambdaParameter = foreignFieldLambda.Parameters.Single(); + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(localFieldLambdaParameter, sourceExpression); + DeduceSerializer(foreignFieldLambdaParameter, foreignDocumentSerializer); + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddNodeSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultsSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultsSerializer)); + } + } + else if (method.Is(MongoQueryableMethod.LookupWithFromAndPipeline)) + { + var fromCollection = arguments[1].GetConstantValue(node); + var foreignDocumentSerializer = fromCollection.DocumentSerializer; + var pipelineLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var pipelineLambdaLocalParameter = pipelineLambda.Parameters[0]; + var pipelineLamdbaForeignQueryableParameter = pipelineLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(pipelineLambdaLocalParameter, sourceExpression); + + if (IsNotKnown(pipelineLamdbaForeignQueryableParameter)) + { + var foreignQueryableSerializer = IQueryableSerializer.Create(foreignDocumentSerializer); + AddNodeSerializer(pipelineLamdbaForeignQueryableParameter, foreignQueryableSerializer); + } + + if (IsNotKnown(node) && + IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer) && + IsItemSerializerKnown(pipelineLambda.Body, out var pipelineItemSerializer)) + { + var lookupResultSerializer = LookupResultSerializer.Create(sourceItemSerializer, pipelineItemSerializer); + AddNodeSerializer(node, IQueryableSerializer.Create(lookupResultSerializer)); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMatchingElementsMethodSerializers() + { + if (method.IsOneOf(MongoEnumerableMethod.AllElements, MongoEnumerableMethod.AllMatchingElements, MongoEnumerableMethod.FirstMatchingElement)) + { + DeduceReturnsOneSourceItemSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceMaxOrMinMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.MaxOrMinOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.MaxOrMinWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + DeduceSerializers(node, selectorLambda.Body); + } + else + { + DeduceReturnsOneSourceItemSerializer(); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOfTypeMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.OfType, QueryableMethod.OfType)) + { + var sourceExpression = arguments[0]; + var resultType = method.GetGenericArguments()[0]; + + if (IsNotKnown(node) && IsItemSerializerKnown(sourceExpression, out var sourceItemSerializer)) + { + var resultItemSerializer = sourceItemSerializer.GetDerivedTypeSerializer(resultType); + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceOrderByMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.OrderByOrThenByOverloads)) + { + var sourceExpression = arguments[0]; + var keySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var keySelectorParameter = keySelectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(keySelectorParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeducePickMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.PickOverloads)) + { + if (method.IsOneOf(EnumerableMethod.PickWithSortByOverloads)) + { + var sortByExpression = arguments[1]; + if (IsNotKnown(sortByExpression)) + { + var ignoreSubTreeSerializer = IgnoreSubtreeSerializer.Create(sortByExpression.Type); + AddNodeSerializer(sortByExpression, ignoreSubTreeSerializer); + } + } + + var sourceExpression = arguments[0]; + if (IsKnown(sourceExpression, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + + var selectorExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortByOverloads) ? 2 : 1]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, selectorExpression); + var selectorSourceItemParameter = selectorLambda.Parameters.Single(); + if (IsNotKnown(selectorSourceItemParameter)) + { + AddNodeSerializer(selectorSourceItemParameter, sourceItemSerializer); + } + } + + if (method.IsOneOf(EnumerableMethod.PickWithComputedNOverloads)) + { + var keyExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortByOverloads) ? 3 : 2]; + if (IsKnown(keyExpression, out var keySerializer)) + { + var nExpression = arguments[method.IsOneOf(EnumerableMethod.PickWithSortByOverloads) ? 4 : 3]; + var nLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, nExpression); + var nLambdaKeyParameter = nLambda.Parameters.Single(); + + if (IsNotKnown(nLambdaKeyParameter)) + { + AddNodeSerializer(nLambdaKeyParameter, keySerializer); + } + } + } + + if (IsNotKnown(node)) + { + var selectorExpressionIndex = method switch + { + _ when method.Is(EnumerableMethod.Bottom) => 2, + _ when method.Is(EnumerableMethod.BottomN) => 2, + _ when method.Is(EnumerableMethod.BottomNWithComputedN) => 2, + _ when method.Is(EnumerableMethod.FirstN) => 1, + _ when method.Is(EnumerableMethod.FirstNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.LastN) => 1, + _ when method.Is(EnumerableMethod.LastNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MaxN) => 1, + _ when method.Is(EnumerableMethod.MaxNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.MinN) => 1, + _ when method.Is(EnumerableMethod.MinNWithComputedN) => 1, + _ when method.Is(EnumerableMethod.Top) => 2, + _ when method.Is(EnumerableMethod.TopN) => 2, + _ when method.Is(EnumerableMethod.TopNWithComputedN) => 2, + _ => throw new ArgumentException($"Unrecognized method: {method.Name}.") + }; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[selectorExpressionIndex]); + + if (IsKnown(selectorLambda.Body, out var selectorItemSerializer)) + { + var nodeSerializer = method.IsOneOf(EnumerableMethod.Bottom, EnumerableMethod.Top) ? + selectorItemSerializer : + IEnumerableSerializer.Create(selectorItemSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceParseMethodSerializers() + { + if (IsNotKnown(node)) + { + if (IsParseMethod(method)) + { + var nodeSerializer = GetParseResultSerializer(method.DeclaringType); + AddNodeSerializer(node, nodeSerializer); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + static bool IsParseMethod(MethodInfo method) + { + var parameters = method.GetParameters(); + return + method.IsPublic && + method.IsStatic && + method.ReturnType == method.DeclaringType && + parameters.Length == 1 && + parameters[0].ParameterType == typeof(string); + } + + static IBsonSerializer GetParseResultSerializer(Type declaringType) + { + return declaringType switch + { + _ when declaringType == typeof(DateTime) => DateTimeSerializer.Instance, + _ when declaringType == typeof(decimal) => DecimalSerializer.Instance, + _ when declaringType == typeof(double) => DoubleSerializer.Instance, + _ when declaringType == typeof(int) => Int32Serializer.Instance, + _ when declaringType == typeof(short) => Int64Serializer.Instance, + _ => UnknowableSerializer.Create(declaringType) + }; + } + } + + void DeducePowMethodSerializers() + { + if (method.Is(MathMethod.Pow)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRadiansToDegreesMethodSerializers() + { + if (method.Is(MongoDBMathMethod.RadiansToDegrees)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReturnsBooleanSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, BooleanSerializer.Instance); + } + } + + void DeduceReturnsDateTimeSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DateTimeSerializer.UtcInstance); + } + } + + void DeduceReturnsDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DecimalSerializer.Instance); + } + } + + void DeduceReturnsDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, DoubleSerializer.Instance); + } + } + + void DeduceReturnsInt32Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int32Serializer.Instance); + } + } + + void DeduceReturnsInt64Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int64Serializer.Instance); + } + } + + void DeduceReturnsNullableDecimalSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableDecimalInstance); + } + } + + void DeduceReturnsNullableDoubleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableDoubleInstance); + } + } + + void DeduceReturnsNullableInt32Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableInt32Instance); + } + } + + void DeduceReturnsNullableInt64Serializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableInt64Instance); + } + } + + void DeduceReturnsNullableSingleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, NullableSerializer.NullableSingleInstance); + } + } + + void DeduceReturnsNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, numericSerializer); + } + } + + void DeduceReturnsNumericOrNullableNumericSerializer() + { + if (IsNotKnown(node) && node.Type.IsNumericOrNullableNumeric()) + { + var numericSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, numericSerializer); + } + } + + void DeduceReturnsOneSourceItemSerializer() + { + var sourceExpression = arguments[0]; + + if (IsNotKnown(node) && IsKnown(sourceExpression, out var sourceSerializer)) + { + var nodeSerializer = sourceSerializer is IUnknowableSerializer ? + UnknowableSerializer.Create(node.Type) : + ArraySerializerHelper.GetItemSerializer(sourceSerializer); + AddNodeSerializer(node, nodeSerializer); + } + } + + void DeduceReturnsSingleSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, SingleSerializer.Instance); + } + } + + void DeduceReturnsStringSerializer() + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, StringSerializer.Instance); + } + } + + void DeduceReturnsTimeSpanSerializer(TimeSpanUnits units) + { + if (IsNotKnown(node)) + { + var resultSerializer = new TimeSpanSerializer(BsonType.Int64, units); + AddNodeSerializer(node, resultSerializer); + } + } + + void DeduceRangeMethodSerializers() + { + if (method.Is(EnumerableMethod.Range)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRepeatMethodSerializers() + { + if (method.Is(EnumerableMethod.Repeat)) + { + var elementExpression = arguments[0]; + DeduceCollectionAndItemSerializers(node, elementExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceReverseMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Reverse, QueryableMethod.Reverse)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceRoundMethodSerializers() + { + if (method.IsOneOf(MathMethod.RoundWithDecimal, MathMethod.RoundWithDecimalAndDecimals, MathMethod.RoundWithDouble, MathMethod.RoundWithDoubleAndDigits)) + { + if (IsNotKnown(node)) + { + var resultSerializer = StandardSerializers.GetSerializer(node.Type); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.Select, QueryableMethod.Select)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + DeduceCollectionAndItemSerializers(node, selectorLambda.Body); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSelectManySerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelector)) + { + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceParameter = selectorLambda.Parameters.Single(); + + DeduceItemAndCollectionSerializers(selectorSourceParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, selectorLambda.Body); + } + + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithCollectionSelectorAndResultSelector)) + { + var collectionSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + + var collectionSelectorSourceItemParameter = collectionSelectorLambda.Parameters.Single(); + var resultSelectorSourceItemParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorCollectionItemParameter = resultSelectorLambda.Parameters[1]; + + DeduceItemAndCollectionSerializers(collectionSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorSourceItemParameter, sourceExpression); + DeduceItemAndCollectionSerializers(resultSelectorCollectionItemParameter, collectionSelectorLambda.Body); + DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSequenceEqualMethodSerializers() + { + if (method.IsOneOf(EnumerableMethod.SequenceEqual, QueryableMethod.SequenceEqual)) + { + var source1Expression = arguments[0]; + var source2Expression = arguments[1]; + + DeduceCollectionAndCollectionSerializers(source1Expression, source2Expression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSetEqualsMethodSerializers() + { + if (ISetMethod.IsSetEqualsMethod(method)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSetWindowFieldsMethodSerializers() + { + if (method.Is(EnumerableMethod.First)) + { + var objectExpression = node.Object; + var otherExpression = arguments[0]; + + DeduceCollectionAndCollectionSerializers(objectExpression, otherExpression); + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceShiftMethodSerializers() + { + if (method.IsOneOf(WindowMethod.Shift, WindowMethod.ShiftWithDefaultValue)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorSourceItemParameter = selectorLambda.Parameters[0]; + + DeduceItemAndCollectionSerializers(selectorSourceItemParameter, sourceExpression); + + if (IsNotKnown(node) && IsKnown(selectorLambda.Body, out var resultSerializer)) + { + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSplitMethodSerializers() + { + if (method.IsOneOf(StringMethod.SplitOverloads)) + { + if (IsNotKnown(node)) + { + var nodeSerializer = ArraySerializer.Create(StringSerializer.Instance); + AddNodeSerializer(node, nodeSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSqrtMethodSerializers() + { + if (method.Is(MathMethod.Sqrt)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStandardDeviationMethodSerializers() + { + if (method.IsOneOf(MongoEnumerableMethod.StandardDeviationOverloads)) + { + if (method.IsOneOf(MongoEnumerableMethod.StandardDeviationWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorItemParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorItemParameter, sourceExpression); + } + + DeduceReturnsNumericOrNullableNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceEndsWithOrStartsWithMethodSerializers() + { + if (method.IsOneOf(StringMethod.EndsWithOrStartsWithOverloads)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStringInMethodSerializers() + { + if (method.IsOneOf(StringMethod.StringInWithEnumerable, StringMethod.StringInWithParams)) + { + DeduceReturnsBooleanSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceStrLenBytesMethodSerializers() + { + if (method.Is(StringMethod.StrLenBytes)) + { + DeduceReturnsInt32Serializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubstringMethodSerializers() + { + if (method.IsOneOf(StringMethod.Substring, StringMethod.SubstringWithLength, StringMethod.SubstrBytes)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSubtractMethodSerializers() + { + if (method.IsOneOf(DateTimeMethod.SubtractReturningDateTimeOverloads)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(DateTimeMethod.SubtractReturningInt64Overloads)) + { + DeduceReturnsInt64Serializer(); + } + else if (method.IsOneOf(DateTimeMethod.SubtractReturningTimeSpanWithMillisecondsUnitsOverloads)) + { + var units = TimeSpanUnits.Milliseconds; + DeduceReturnsTimeSpanSerializer(units); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSumMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SumOverloads)) + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SumWithSelectorOverloads)) + { + var sourceExpression = arguments[0]; + var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var selectorParameter = selectorLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression); + } + + var returnType = node.Type; + switch (returnType) + { + case not null when returnType == typeof(decimal): DeduceReturnsDecimalSerializer(); break; + case not null when returnType == typeof(double): DeduceReturnsDoubleSerializer(); break; + case not null when returnType == typeof(int): DeduceReturnsInt32Serializer(); break; + case not null when returnType == typeof(long): DeduceReturnsInt64Serializer(); break; + case not null when returnType == typeof(float): DeduceReturnsSingleSerializer(); break; + case not null when returnType == typeof(decimal?): DeduceReturnsNullableDecimalSerializer(); break; + case not null when returnType == typeof(double?): DeduceReturnsNullableDoubleSerializer(); break; + case not null when returnType == typeof(int?): DeduceReturnsNullableInt32Serializer(); break; + case not null when returnType == typeof(long?): DeduceReturnsNullableInt64Serializer(); break; + case not null when returnType == typeof(float?): DeduceReturnsNullableSingleSerializer(); break; + + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceSkipOrTakeMethodSerializers() + { + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipOrTakeOverloads)) + { + var sourceExpression = arguments[0]; + + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipWhileOrTakeWhile)) + { + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + } + + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToArrayMethodSerializers() + { + if (IsToArrayMethod(out var sourceExpression)) + { + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + + bool IsToArrayMethod(out Expression sourceExpression) + { + if (method.IsPublic && + method.Name == "ToArray" && + method.GetParameters().Length == (method.IsStatic ? 1 : 0)) + { + sourceExpression = method.IsStatic ? arguments[0] : node.Object; + return true; + } + + sourceExpression = null; + return false; + } + } + + void DeduceToListSerializers() + { + if (IsNotKnown(node)) + { + var source = method.IsStatic ? arguments[0] : node.Object; + if (IsKnown(source, out var sourceSerializer)) + { + var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer); + var resultSerializer = ListSerializer.Create(sourceItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + } + + void DeduceToLowerOrToUpperSerializers() + { + if (method.IsOneOf(StringMethod.ToLowerOrToUpperOverloads)) + { + DeduceReturnsStringSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceToStringSerializers() + { + DeduceReturnsStringSerializer(); + } + + void DeduceTrigonometricMethodSerializers() + { + if (method.IsOneOf(MathMethod.TrigonometricMethods)) + { + DeduceReturnsDoubleSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceTruncateSerializers() + { + if (method.IsOneOf(DateTimeMethod.Truncate, DateTimeMethod.TruncateWithBinSize, DateTimeMethod.TruncateWithBinSizeAndTimezone)) + { + DeduceReturnsDateTimeSerializer(); + } + else if (method.IsOneOf(MathMethod.TruncateDecimal, MathMethod.TruncateDouble)) + { + DeduceReturnsNumericSerializer(); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnionSerializers() + { + if (method.IsOneOf(EnumerableMethod.Union, QueryableMethod.Union)) + { + var sourceExpression = arguments[0]; + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceUnknownMethodSerializer() + { + DeduceUnknowableSerializer(node); + } + + void DeduceWeekSerializers() + { + if (method.IsOneOf(DateTimeMethod.Week, DateTimeMethod.WeekWithTimezone)) + { + if (IsNotKnown(node)) + { + AddNodeSerializer(node, Int32Serializer.Instance); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceWhereSerializers() + { + if (method.IsOneOf(__whereOverloads)) + { + var sourceExpression = arguments[0]; + var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); + var predicateParameter = predicateLambda.Parameters.Single(); + DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression); + DeduceCollectionAndCollectionSerializers(node, sourceExpression); + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + void DeduceZipSerializers() + { + if (method.IsOneOf(EnumerableMethod.Zip, QueryableMethod.Zip)) + { + var firstExpression = arguments[0]; + var secondExpression = arguments[1]; + var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); + var resultSelectorFirstParameter = resultSelectorLambda.Parameters[0]; + var resultSelectorSecondParameter = resultSelectorLambda.Parameters[1]; + + if (IsNotKnown(resultSelectorFirstParameter) && IsKnown(firstExpression, out var firstSerializer)) + { + var firstItemSerializer = ArraySerializerHelper.GetItemSerializer(firstSerializer); + AddNodeSerializer(resultSelectorFirstParameter, firstItemSerializer); + } + + if (IsNotKnown(resultSelectorSecondParameter) && IsKnown(secondExpression, out var secondSerializer)) + { + var secondItemSerializer = ArraySerializerHelper.GetItemSerializer(secondSerializer); + AddNodeSerializer(resultSelectorSecondParameter, secondItemSerializer); + } + + if (IsNotKnown(node) && IsKnown(resultSelectorLambda.Body, out var resultItemSerializer)) + { + var resultSerializer = IEnumerableOrIQueryableSerializer.Create(node.Type, resultItemSerializer); + AddNodeSerializer(node, resultSerializer); + } + } + else + { + DeduceUnknownMethodSerializer(); + } + } + + bool IsDictionaryContainsKeyExpression(out Expression keyExpression) + { + if (DictionaryMethod.IsContainsKeyMethod(method)) + { + keyExpression = arguments[0]; + return true; + } + + keyExpression = null; + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs new file mode 100644 index 00000000000..913ef139d13 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNew.cs @@ -0,0 +1,140 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitNew(NewExpression node) + { + var constructor = node.Constructor; + var arguments = node.Arguments; + IBsonSerializer nodeSerializer; + + if (IsKnown(node, out nodeSerializer) && + arguments.Any(IsNotKnown)) + { + if (!typeof(BsonValue).IsAssignableFrom(node.Type) && + nodeSerializer is IBsonDocumentSerializer) + { + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(node, node.Constructor); + for (var i = 0; i < matchingMemberSerializationInfos.Count; i++) + { + var argument = arguments[i]; + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; + + if (IsNotKnown(argument)) + { + // arg => arg: matchingMemberSerializer + AddNodeSerializer(argument, matchingMemberSerializationInfo.Serializer); + } + } + } + } + + base.VisitNew(node); + + if (IsNotKnown(node)) + { + nodeSerializer = CreateSerializer(constructor); + if (nodeSerializer != null) + { + AddNodeSerializer(node, nodeSerializer); + } + } + + return node; + + IBsonSerializer CreateSerializer(ConstructorInfo constructor) + { + if (constructor == null) + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + else if (constructor.DeclaringType == typeof(BsonDocument)) + { + return BsonDocumentSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(BsonValue)) + { + return BsonValueSerializer.Instance; + } + else if (constructor.DeclaringType == typeof(DateTime)) + { + return DateTimeSerializer.Instance; + } + else if (DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer) && + itemSerializer.IsKeyValuePairSerializer(out _, out _, out var keySerializer, out var valueSerializer)) + { + return DictionarySerializer.Create(DictionaryRepresentation.Document, keySerializer, valueSerializer); + } + } + else if (HashSetConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return HashSetSerializer.Create(itemSerializer); + } + } + else if (ListConstructor.IsWithCollectionConstructor(constructor)) + { + var collectionExpression = arguments[0]; + if (IsItemSerializerKnown(collectionExpression, out var itemSerializer)) + { + return ListSerializer.Create(itemSerializer); + } + } + else if (KeyValuePairConstructor.IsWithKeyAndValueConstructor(constructor)) + { + var key = arguments[0]; + var value = arguments[1]; + if (IsKnown(key, out var keySerializer) && + IsKnown(value, out var valueSerializer)) + { + return KeyValuePairSerializer.Create(BsonType.Document, keySerializer, valueSerializer); + } + } + else if (TupleOrValueTupleConstructor.IsTupleOrValueTupleConstructor(constructor)) + { + if (AreAllKnown(arguments, out var argumentSerializers)) + { + return TupleOrValueTupleSerializer.Create(constructor.DeclaringType, argumentSerializers); + } + } + else + { + return CreateNewExpressionSerializer(node, node, bindings: null); + } + + return null; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs new file mode 100644 index 00000000000..5c9d30d8946 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitNewArray.cs @@ -0,0 +1,146 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitNewArray(NewArrayExpression node) + { + DeduceNewArraySerializers(); + base.VisitNewArray(node); + DeduceNewArraySerializers(); + + return node; + + void DeduceNewArraySerializers() + { + switch (node.NodeType) + { + case ExpressionType.NewArrayBounds: + DeduceNewArrayBoundsSerializers(); + break; + + case ExpressionType.NewArrayInit: + DeduceNewArrayInitSerializers(); + break; + } + } + + void DeduceNewArrayBoundsSerializers() + { + throw new NotImplementedException(); + } + + void DeduceNewArrayInitSerializers() + { + var itemExpressions = node.Expressions; + IBsonSerializer itemSerializer; + + if (IsAnyNotKnown(itemExpressions) && IsKnown(node, out var arraySerializer)) + { + if (arraySerializer is IPolymorphicArraySerializer polymorphicArraySerializer) + { + for (var i = 0; i < itemExpressions.Count; i++) + { + var itemExpression = itemExpressions[i]; + if (IsNotKnown(itemExpression)) + { + itemSerializer = polymorphicArraySerializer.GetItemSerializer(i); + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + else + { + itemSerializer = arraySerializer.GetItemSerializer(); + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + } + + if (IsAnyNotKnown(itemExpressions) && IsAnyKnown(itemExpressions, out itemSerializer)) + { + var firstItemType = itemExpressions.First().Type; + if (itemExpressions.All(e => e.Type == firstItemType)) + { + foreach (var itemExpression in itemExpressions) + { + if (IsNotKnown(itemExpression)) + { + AddNodeSerializer(itemExpression, itemSerializer); + } + } + } + } + + if (IsNotKnown(node)) + { + if (AreAllKnown(itemExpressions, out var itemSerializers)) + { + if (AllItemSerializersAreEqual(itemSerializers, out itemSerializer)) + { + arraySerializer = ArraySerializer.Create(itemSerializer); + } + else + { + var itemType = node.Type.GetElementType(); + arraySerializer = PolymorphicArraySerializer.Create(itemType, itemSerializers); + } + AddNodeSerializer(node, arraySerializer); + } + } + + static bool AllItemSerializersAreEqual(IReadOnlyList itemSerializers, out IBsonSerializer itemSerializer) + { + switch (itemSerializers.Count) + { + case 0: + itemSerializer = null; + return false; + case 1: + itemSerializer = itemSerializers[0]; + return true; + default: + var firstItemSerializer = itemSerializers[0]; + if (itemSerializers.Skip(1).All(s => s.Equals(firstItemSerializer))) + { + itemSerializer = firstItemSerializer; + return true; + } + else + { + itemSerializer = null; + return false; + } + } + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs new file mode 100644 index 00000000000..40ec74177ab --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitTypeBinary.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitTypeBinary(TypeBinaryExpression node) + { + base.VisitTypeBinary(node); + + DeduceBooleanSerializer(node); + + return node; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs new file mode 100644 index 00000000000..96418e73305 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitUnary.cs @@ -0,0 +1,306 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor +{ + protected override Expression VisitUnary(UnaryExpression node) + { + var unaryOperator = node.NodeType; + var operand = node.Operand; + + base.VisitUnary(node); + + switch (unaryOperator) + { + case ExpressionType.Negate: + DeduceNegateSerializers(); // TODO: fold into general case? + break; + + default: + DeduceUnaryOperatorSerializers(); + break; + } + + return node; + + void DeduceNegateSerializers() + { + DeduceSerializers(node, operand); + } + + void DeduceUnaryOperatorSerializers() + { + if (IsNotKnown(node)) + { + var resultSerializer = unaryOperator switch + { + ExpressionType.ArrayLength => Int32Serializer.Instance, + ExpressionType.Convert or ExpressionType.TypeAs => GetConvertSerializer(), + ExpressionType.Not => StandardSerializers.GetSerializer(node.Type), + ExpressionType.Quote => IgnoreNodeSerializer.Create(node.Type), + _ => null + }; + + if (resultSerializer != null) + { + AddNodeSerializer(node, resultSerializer); + } + } + } + + IBsonSerializer GetConvertSerializer() + { + var sourceType = operand.Type; + var targetType = node.Type; + + // handle double conversion (BsonValue)(object)x + if (targetType == typeof(BsonValue) && + operand is UnaryExpression unarySourceExpression && + unarySourceExpression.NodeType == ExpressionType.Convert && + unarySourceExpression.Type == typeof(object)) + { + operand = unarySourceExpression.Operand; + } + + if (IsKnown(operand, out var sourceSerializer)) + { + return GetTargetSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + + static IBsonSerializer GetTargetSerializer(UnaryExpression node, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (targetType == sourceType) + { + return sourceSerializer; + } + + // handle conversion to BsonValue before any others + if (targetType == typeof(BsonValue)) + { + return GetConvertToBsonValueSerializer(node, sourceSerializer); + } + + // from Nullable must be handled before to Nullable + if (IsConvertFromNullableType(sourceType)) + { + return GetConvertFromNullableTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToNullableType(targetType, out var valueType)) + { + var valueSerializer = valueType == targetType ? sourceSerializer : GetTargetSerializer(node, sourceType, valueType, sourceSerializer); + return valueSerializer != null ? GetConvertToNullableTypeSerializer(node, sourceType, targetType, valueSerializer) : null; + } + + // from here on we know there are no longer any Nullable types involved + + if (sourceType == typeof(BsonValue)) + { + return GetConvertFromBsonValueSerializer(node, targetType); + } + + if (IsConvertEnumToUnderlyingType(sourceType, targetType)) + { + return GetConvertEnumToUnderlyingTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertUnderlyingTypeToEnum(sourceType, targetType)) + { + return GetConvertUnderlyingTypeToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertEnumToEnum(sourceType, targetType)) + { + return GetConvertEnumToEnumSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToBaseType(sourceType, targetType)) + { + return GetConvertToBaseTypeSerializer(node, sourceType, targetType, sourceSerializer); + } + + if (IsConvertToDerivedType(sourceType, targetType)) + { + return GetConvertToDerivedTypeSerializer(node, targetType, sourceSerializer); + } + + if (IsNumericConversion(sourceType, targetType)) + { + return GetNumericConversionSerializer(node, sourceType, targetType, sourceSerializer); + } + + return null; + } + + static IBsonSerializer GetConvertFromBsonValueSerializer(UnaryExpression expression, Type targetType) + { + return targetType switch + { + _ when targetType == typeof(string) => StringSerializer.Instance, + _ => throw new ExpressionNotSupportedException(expression, because: $"conversion from BsonValue to {targetType} is not supported") + }; + } + + static IBsonSerializer GetConvertToBaseTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer; + return DowncastingSerializer.Create(targetType, sourceType, derivedTypeSerializer); + } + + static IBsonSerializer GetConvertToDerivedTypeSerializer(UnaryExpression expression, Type targetType, IBsonSerializer sourceSerializer) + { + var derivedTypeSerializer = sourceSerializer.GetDerivedTypeSerializer(targetType); + return derivedTypeSerializer; + } + + static IBsonSerializer GetConvertToBsonValueSerializer(UnaryExpression expression, IBsonSerializer sourceSerializer) + { + return BsonValueSerializer.Instance; + } + + static IBsonSerializer GetConvertEnumToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (!sourceType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "source type is not an enum"); + } + if (!targetType.IsEnum) + { + throw new ExpressionNotSupportedException(expression, because: "target type is not an enum"); + } + + return EnumSerializer.Create(targetType); + } + + static IBsonSerializer GetConvertEnumToUnderlyingTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var enumSerializer = sourceSerializer; + return AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + } + + static IBsonSerializer GetConvertFromNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceSerializer is not INullableSerializer nullableSourceSerializer) + { + throw new ExpressionNotSupportedException(expression, because: $"sourceSerializer type {sourceSerializer.GetType()} does not implement nameof(INullableSerializer)"); + } + + var sourceValueSerializer = nullableSourceSerializer.ValueSerializer; + var sourceValueType = sourceValueSerializer.ValueType; + + if (targetType.IsNullable(out var targetValueType)) + { + var targetValueSerializer = GetTargetSerializer(expression, sourceValueType, targetValueType, sourceValueSerializer); + return NullableSerializer.Create(targetValueSerializer); + } + else + { + return GetTargetSerializer(expression, sourceValueType, targetType, sourceValueSerializer); + } + } + + static IBsonSerializer GetConvertToNullableTypeSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + if (sourceType.IsNullable()) + { + throw new ExpressionNotSupportedException(expression, because: "sourceType is already nullable"); + } + + if (targetType.IsNullable()) + { + return NullableSerializer.Create(sourceSerializer); + } + + throw new ExpressionNotSupportedException(expression, because: "targetType is not nullable"); + } + + static IBsonSerializer GetConvertUnderlyingTypeToEnumSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + IBsonSerializer targetSerializer; + if (sourceSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + { + targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; + } + else + { + targetSerializer = EnumSerializer.Create(targetType); + } + + return targetSerializer; + } + + static IBsonSerializer GetNumericConversionSerializer(UnaryExpression expression, Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + return NumericConversionSerializer.Create(sourceType, targetType, sourceSerializer); + } + + static bool IsConvertEnumToEnum(Type sourceType, Type targetType) + { + return sourceType.IsEnum && targetType.IsEnum; + } + + static bool IsConvertEnumToUnderlyingType(Type sourceType, Type targetType) + { + return + sourceType.IsEnum(out var underlyingType) && + targetType == underlyingType; + } + + static bool IsConvertFromNullableType(Type sourceType) + { + return sourceType.IsNullable(); + } + + static bool IsConvertToBaseType(Type sourceType, Type targetType) + { + return sourceType.IsSubclassOf(targetType) || sourceType.ImplementsInterface(targetType); + } + + static bool IsConvertToDerivedType(Type sourceType, Type targetType) + { + return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface + } + + static bool IsConvertToNullableType(Type targetType, out Type valueType) + { + return targetType.IsNullable(out valueType); + } + + static bool IsConvertUnderlyingTypeToEnum(Type sourceType, Type targetType) + { + return + targetType.IsEnum(out var underlyingType) && + sourceType == underlyingType; + } + + static bool IsNumericConversion(Type sourceType, Type targetType) + { + return sourceType.IsNumeric() && targetType.IsNumeric(); + } + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs new file mode 100644 index 00000000000..1b6c25e1238 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitor.cs @@ -0,0 +1,71 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; +using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal partial class SerializerFinderVisitor : ExpressionVisitor +{ + private bool _isMakingProgress = true; + private readonly SerializerMap _nodeSerializers; + private int _oldNodeSerializersCount = 0; + private readonly ExpressionTranslationOptions _translationOptions; + private bool _useDefaultSerializerForConstants = false; // make as much progress as possible before setting this to true + + public SerializerFinderVisitor(ExpressionTranslationOptions translationOptions, SerializerMap nodeSerializers) + { + _nodeSerializers = nodeSerializers; + _translationOptions = translationOptions; + } + + public bool IsMakingProgress => _isMakingProgress; + + public void EndPass() + { + var newNodeSerializersCount = _nodeSerializers.Count; + if (newNodeSerializersCount == _oldNodeSerializersCount) + { + if (_useDefaultSerializerForConstants) + { + _isMakingProgress = false; + } + else + { + _useDefaultSerializerForConstants = true; + } + } + } + + public void StartPass() + { + _oldNodeSerializersCount = _nodeSerializers.Count; + } + + public override Expression Visit(Expression node) + { + if (IsKnown(node, out var nodeSerializer)) + { + if (nodeSerializer is IIgnoreSubtreeSerializer or IUnknowableSerializer) + { + return node; // don't visit subtree + } + } + + return base.Visit(node); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs new file mode 100644 index 00000000000..8b35de67223 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerMap.cs @@ -0,0 +1,111 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; + +internal interface IReadOnlySerializerMap +{ + IBsonSerializer GetSerializer(Expression node); +} + +internal class SerializerMap : IReadOnlySerializerMap +{ + private readonly Dictionary _map = new(); + + public int Count => _map.Count; + + public void AddSerializer(Expression node, IBsonSerializer serializer) + { + if (serializer.ValueType != node.Type && + node.Type.IsNullable(out var nodeNonNullableType) && + serializer.ValueType.IsNullable(out var serializerNonNullableType) && + serializer is INullableSerializer nullableSerializer) + { + if (nodeNonNullableType.IsEnum(out var targetEnumUnderlyingType) && targetEnumUnderlyingType == serializerNonNullableType) + { + var enumType = nodeNonNullableType; + var underlyingTypeSerializer = nullableSerializer.ValueSerializer; + var enumSerializer = AsUnderlyingTypeEnumSerializer.Create(enumType, underlyingTypeSerializer); + serializer = NullableSerializer.Create(enumSerializer); + } + else if (serializerNonNullableType.IsEnum(out var serializerUnderlyingType) && serializerUnderlyingType == nodeNonNullableType) + { + var enumSerializer = nullableSerializer.ValueSerializer; + var underlyingTypeSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); + serializer = NullableSerializer.Create(underlyingTypeSerializer); + } + } + + if (serializer.ValueType != node.Type) + { + if (node.Type.IsAssignableFrom(serializer.ValueType)) + { + serializer = DowncastingSerializer.Create(baseType: node.Type, derivedType: serializer.ValueType, derivedTypeSerializer: serializer); + } + else if (serializer.ValueType.IsAssignableFrom(node.Type)) + { + serializer = UpcastingSerializer.Create(baseType: serializer.ValueType, derivedType: node.Type, baseTypeSerializer: serializer); + } + else + { + throw new ArgumentException($"Serializer value type {serializer.ValueType} does not match expression value type {node.Type}", nameof(serializer)); + } + } + + if (_map.TryGetValue(node, out var existingSerializer)) + { + throw new ExpressionNotSupportedException( + node, + because: $"there are duplicate known serializers for expression '{node}': {serializer.GetType()} and {existingSerializer.GetType()}"); + } + + _map.Add(node, serializer); + } + + public IBsonSerializer GetSerializer(Expression node) + { + if (_map.TryGetValue(node, out var nodeSerializer)) + { + return nodeSerializer; + } + + throw new ExpressionNotSupportedException(node, because: "unable to determine which serializer to use"); + } + + public bool IsNotKnown(Expression node) + { + return !IsKnown(node); + } + + public bool IsKnown(Expression node) + { + return _map.ContainsKey(node); + } + + public bool IsKnown(Expression node, out IBsonSerializer serializer) + { + serializer = null; + return node != null && _map.TryGetValue(node, out serializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs similarity index 63% rename from src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs rename to src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs index 816e5fc237f..7e0b3d1e75c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializer.cs @@ -20,24 +20,24 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers { - internal interface IEnumUnderlyingTypeSerializer + internal interface IAsEnumUnderlyingTypeSerializer { IBsonSerializer EnumSerializer { get; } } - internal class EnumUnderlyingTypeSerializer : StructSerializerBase, IEnumUnderlyingTypeSerializer + internal class AsEnumUnderlyingTypeSerializer : StructSerializerBase, IAsEnumUnderlyingTypeSerializer where TEnum : Enum - where TEnumUnderlyingType : struct + where TUnderlyingType : struct { // private fields private readonly IBsonSerializer _enumSerializer; // constructors - public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) + public AsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) { - if (typeof(TEnumUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) { - throw new ArgumentException($"{typeof(TEnumUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); } _enumSerializer = Ensure.IsNotNull(enumSerializer, nameof(enumSerializer)); } @@ -46,13 +46,13 @@ public EnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) public IBsonSerializer EnumSerializer => _enumSerializer; // explicitly implemented properties - IBsonSerializer IEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; + IBsonSerializer IAsEnumUnderlyingTypeSerializer.EnumSerializer => EnumSerializer; // public methods - public override TEnumUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + public override TUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { var enumValue = _enumSerializer.Deserialize(context); - return (TEnumUnderlyingType)(object)enumValue; + return (TUnderlyingType)(object)enumValue; } /// @@ -62,28 +62,28 @@ public override bool Equals(object obj) if (object.ReferenceEquals(this, obj)) { return true; } return base.Equals(obj) && - obj is EnumUnderlyingTypeSerializer other && + obj is AsEnumUnderlyingTypeSerializer other && object.Equals(_enumSerializer, other._enumSerializer); } /// public override int GetHashCode() => _enumSerializer.GetHashCode(); - public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnumUnderlyingType value) + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TUnderlyingType value) { var enumValue = (TEnum)(object)value; _enumSerializer.Serialize(context, enumValue); } } - internal static class EnumUnderlyingTypeSerializer + internal static class AsEnumUnderlyingTypeSerializer { public static IBsonSerializer Create(IBsonSerializer enumSerializer) { var enumType = enumSerializer.ValueType; var underlyingType = Enum.GetUnderlyingType(enumType); - var enumUnderlyingTypeSerializerType = typeof(EnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); - return (IBsonSerializer)Activator.CreateInstance(enumUnderlyingTypeSerializerType, enumSerializer); + var asEnumUnderlyingTypeSerializerType = typeof(AsEnumUnderlyingTypeSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(asEnumUnderlyingTypeSerializerType, enumSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs new file mode 100644 index 00000000000..41f673af856 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/AsUnderlyingTypeEnumSerializer.cs @@ -0,0 +1,88 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal interface IAsUnderlyingTypeEnumSerializer + { + IBsonSerializer UnderlyingTypeSerializer { get; } + } + + internal class AsUnderlyingTypeEnumSerializer : SerializerBase, IAsUnderlyingTypeEnumSerializer + where TEnum : Enum + where TUnderlyingType : struct + { + // private fields + private readonly IBsonSerializer _underlyingTypeSerializer; + + // constructors + public AsUnderlyingTypeEnumSerializer(IBsonSerializer underlyingTypeSerializer) + { + if (typeof(TUnderlyingType) != Enum.GetUnderlyingType(typeof(TEnum))) + { + throw new ArgumentException($"{typeof(TUnderlyingType).FullName} is not the underlying type of {typeof(TEnum).FullName}."); + } + _underlyingTypeSerializer = Ensure.IsNotNull(underlyingTypeSerializer, nameof(underlyingTypeSerializer)); + } + + // public properties + public IBsonSerializer UnderlyingTypeSerializer => _underlyingTypeSerializer; + + // explicitly implemented properties + IBsonSerializer IAsUnderlyingTypeEnumSerializer.UnderlyingTypeSerializer => UnderlyingTypeSerializer; + + // public methods + public override TEnum Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var underlyingTypeValue = _underlyingTypeSerializer.Deserialize(context); + return (TEnum)(object)underlyingTypeValue; + } + + /// + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is AsUnderlyingTypeEnumSerializer other && + object.Equals(_underlyingTypeSerializer, other._underlyingTypeSerializer); + } + + /// + public override int GetHashCode() => _underlyingTypeSerializer.GetHashCode(); + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnum value) + { + var underlyingTypeValue = (TUnderlyingType)(object)value; + _underlyingTypeSerializer.Serialize(context, underlyingTypeValue); + } + } + + internal static class AsUnderlyingTypeEnumSerializer + { + public static IBsonSerializer Create(Type enumType, IBsonSerializer underlyingTypeSerializer) + { + var underlyingType = Enum.GetUnderlyingType(enumType); + var asUnderlyingTypeEnumSerializerType = typeof(AsUnderlyingTypeEnumSerializer<,>).MakeGenericType(enumType, underlyingType); + return (IBsonSerializer)Activator.CreateInstance(asUnderlyingTypeEnumSerializerType, underlyingTypeSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs index bfecb1ef9c7..0d56847649e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/DictionarySerializer.cs @@ -45,8 +45,7 @@ public DictionarySerializer( { } - protected override ICollection> CreateAccumulator() - { - return new Dictionary(); - } + protected override ICollection> CreateAccumulator() => new Dictionary(); + + protected override DictionaryFinalizeAccumulator(ICollection> accumulator) => (Dictionary)accumulator; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs new file mode 100644 index 00000000000..87a47747e5f --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/HashSetSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class HashSetSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(HashSetSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class HashSetSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public HashSetSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new HashSet(); + + protected override HashSet FinalizeResult(object accumulator) => (HashSet)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs new file mode 100644 index 00000000000..f03bf327711 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableOrIQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IEnumerableOrIQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIQueryable(out _) ? + IQueryableSerializer.Create(itemSerializer) : + IEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs new file mode 100644 index 00000000000..da44f92e218 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IOrderedEnumerableOrIOrderedQueryableSerializer.cs @@ -0,0 +1,30 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IOrderedEnumerableOrIOrderedQueryableSerializer +{ + public static IBsonSerializer Create(Type enumerableOrQueryableType, IBsonSerializer itemSerializer) + { + return enumerableOrQueryableType.ImplementsIOrderedQueryable(out _) ? + IOrderedQueryableSerializer.Create(itemSerializer) : + IOrderedEnumerableSerializer.Create(itemSerializer); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs index 2be9f49a1b3..b169febe181 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ISetWindowFieldsPartitionSerializer.cs @@ -24,7 +24,7 @@ internal interface ISetWindowFieldsPartitionSerializer IBsonSerializer InputSerializer { get; } } - internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer + internal class ISetWindowFieldsPartitionSerializer : IBsonSerializer>, ISetWindowFieldsPartitionSerializer, IBsonArraySerializer { private readonly IBsonSerializer _inputSerializer; @@ -61,16 +61,20 @@ public void Serialize(BsonSerializationContext context, BsonSerializationArgs ar throw new InvalidOperationException("This serializer is not intended to be used."); } - public void Serialize(BsonSerializationContext context, BsonSerializationArgs args, object value) { throw new InvalidOperationException("This serializer is not intended to be used."); } - object IBsonSerializer.Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) { throw new InvalidOperationException("This serializer is not intended to be used."); } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo itemSerializationInfo) + { + itemSerializationInfo = new BsonSerializationInfo(null, _inputSerializer, _inputSerializer.ValueType); + return true; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs new file mode 100644 index 00000000000..23fb02f7db8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreNodeSerializer.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreNodeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreNodeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal class IgnoreNodeSerializer : SerializerBase +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs new file mode 100644 index 00000000000..5476eb1e747 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IgnoreSubtreeSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class IgnoreSubtreeSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(IgnoreSubtreeSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IIgnoreSubtreeSerializer +{ +} + +internal class IgnoreSubtreeSerializer : SerializerBase, IIgnoreSubtreeSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs new file mode 100644 index 00000000000..2a7044e7116 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/ListSerializer.cs @@ -0,0 +1,42 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class ListSerializer +{ + public static IBsonSerializer Create(IBsonSerializer itemSerializer) + { + var serializerType = typeof(ListSerializer<>).MakeGenericType(itemSerializer.ValueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer); + } +} + +internal class ListSerializer : EnumerableInterfaceImplementerSerializerBase, T> +{ + public ListSerializer(IBsonSerializer itemSerializer) + : base(itemSerializer) + { + } + + protected override object CreateAccumulator() => new List(); + + protected override List FinalizeResult(object accumulator) => (List)accumulator; +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs new file mode 100644 index 00000000000..c09e78a713c --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/NumericConversionSerializer.cs @@ -0,0 +1,77 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class NumericConversionSerializer +{ + public static IBsonSerializer Create(Type sourceType, Type targetType, IBsonSerializer sourceSerializer) + { + var serializerType = typeof(NumericConversionSerializer<,>).MakeGenericType(sourceType, targetType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, sourceSerializer); + } +} + +internal class NumericConversionSerializer : SerializerBase, IHasRepresentationSerializer +{ + private readonly BsonType _representation; + private readonly IBsonSerializer _sourceSerializer; + + public BsonType Representation => _representation; + + public NumericConversionSerializer(IBsonSerializer sourceSerializer) + { + if (sourceSerializer is not IHasRepresentationSerializer hasRepresentationSerializer) + { + throw new NotSupportedException($"Serializer class {sourceSerializer.GetType().Name} does not implement IHasRepresentationSerializer."); + } + + _sourceSerializer = sourceSerializer; + _representation = hasRepresentationSerializer.Representation; + } + + public override TTarget Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var sourceValue = _sourceSerializer.Deserialize(context); + return (TTarget)Convert(typeof(TSource), typeof(TTarget), sourceValue); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TTarget value) + { + var sourceValue = Convert(typeof(TTarget), typeof(TSource), value); + _sourceSerializer.Serialize(context, args, sourceValue); + } + + private object Convert(Type sourceType, Type targetType, object value) + { + return (Type.GetTypeCode(sourceType), Type.GetTypeCode(targetType)) switch + { + (TypeCode.Decimal, TypeCode.Double) => (object)(double)(decimal)value, + (TypeCode.Double, TypeCode.Decimal) => (object)(decimal)(double)value, + (TypeCode.Int16, TypeCode.Int32) => (object)(int)(short)value, + (TypeCode.Int16, TypeCode.Int64) => (object)(long)(short)value, + (TypeCode.Int32, TypeCode.Int16) => (object)(short)(int)value, + (TypeCode.Int32, TypeCode.Int64) => (object)(long)(int)value, + (TypeCode.Int64, TypeCode.Int16) => (object)(short)(long)value, + (TypeCode.Int64, TypeCode.Int32) => (object)(int)(long)value, + _ => throw new NotSupportedException($"Cannot convert {sourceType} to {targetType}."), + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs new file mode 100644 index 00000000000..beb65eee63e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/PolymorphicArraySerializer.cs @@ -0,0 +1,98 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal interface IPolymorphicArraySerializer +{ + IBsonSerializer GetItemSerializer(int index); +} + +internal static class PolymorphicArraySerializer +{ + public static IBsonSerializer Create(Type itemType, IEnumerable itemSerializers) + { + var serializerType = typeof(PolymorphicArraySerializer<>).MakeGenericType(itemType); + return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializers); + } +} + +internal sealed class PolymorphicArraySerializer : SerializerBase, IPolymorphicArraySerializer +{ + private readonly IReadOnlyList _itemSerializers; + + public PolymorphicArraySerializer(IEnumerable itemSerializers) + { + var itemSerializersArray = itemSerializers.ToArray(); + foreach (var itemSerializer in itemSerializersArray) + { + if (!typeof(TItem).IsAssignableFrom(itemSerializer.ValueType)) + { + throw new ArgumentException($"Serializer class {itemSerializer.ValueType} value type is not assignable to item type {typeof(TItem).Name}"); + } + } + + _itemSerializers = itemSerializersArray; + } + + public override TItem[] Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + + reader.ReadStartArray(); + var i = 0; + var array = new TItem[_itemSerializers.Count]; + while (reader.ReadBsonType() != BsonType.EndOfDocument) + { + if (i < array.Length) + { + array[i] = (TItem)_itemSerializers[i].Deserialize(context); + i++; + } + } + if (i != array.Length) + { + throw new BsonSerializationException($"Expected {array.Length} array items but found {i}."); + } + reader.ReadEndArray(); + + return array; + } + + IBsonSerializer IPolymorphicArraySerializer.GetItemSerializer(int index) => _itemSerializers[index]; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TItem[] value) + { + if (value.Length != _itemSerializers.Count) + { + throw new BsonSerializationException($"Expected array value to have {_itemSerializers.Count} items but found {value.Length}."); + } + + var writer = context.Writer; + writer.WriteStartArray(); + for (var i = 0; i < value.Length; i++) + { + _itemSerializers[i].Serialize(context, args, value[i]); + } + writer.WriteEndArray(); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs new file mode 100644 index 00000000000..762b2839ee8 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/TupleOrValueTupleSerializer.cs @@ -0,0 +1,35 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class TupleOrValueTupleSerializer +{ + public static IBsonSerializer Create(Type tupleType, IEnumerable itemSerializers) + { + return tupleType.Name switch + { + _ when tupleType.IsTuple() => TupleSerializer.Create(itemSerializers), + _ when tupleType.IsValueTuple() => ValueTupleSerializer.Create(itemSerializers), + _ => throw new ArgumentException($"Unexpected tuple type: {tupleType.Name}", nameof(tupleType)) + }; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs new file mode 100644 index 00000000000..e3e6583408b --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UnknowableSerializer.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers; + +internal static class UnknowableSerializer +{ + public static IBsonSerializer Create(Type valueType) + { + var serializerType = typeof(UnknowableSerializer<>).MakeGenericType(valueType); + return (IBsonSerializer)Activator.CreateInstance(serializerType); + } +} + +internal interface IUnknowableSerializer +{ +} + +internal class UnknowableSerializer : SerializerBase, IUnknowableSerializer +{ +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs new file mode 100644 index 00000000000..e2843cb8602 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/UpcastingSerializer.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers +{ + internal static class UpcastingSerializer + { + public static IBsonSerializer Create( + Type baseType, + Type derivedType, + IBsonSerializer baseTypeSerializer) + { + var upcastingSerializerType = typeof(UpcastingSerializer<,>).MakeGenericType(baseType, derivedType); + return (IBsonSerializer)Activator.CreateInstance(upcastingSerializerType, baseTypeSerializer); + } + } + + internal sealed class UpcastingSerializer : SerializerBase, IBsonArraySerializer, IBsonDocumentSerializer + where TDerived : TBase + { + private readonly IBsonSerializer _baseTypeSerializer; + + public UpcastingSerializer(IBsonSerializer baseTypeSerializer) + { + _baseTypeSerializer = baseTypeSerializer ?? throw new ArgumentNullException(nameof(baseTypeSerializer)); + } + + public Type BaseType => typeof(TBase); + + public IBsonSerializer BaseTypeSerializer => _baseTypeSerializer; + + public Type DerivedType => typeof(TDerived); + + public override TDerived Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + return (TDerived)_baseTypeSerializer.Deserialize(context); + } + + public override bool Equals(object obj) + { + if (object.ReferenceEquals(obj, null)) { return false; } + if (object.ReferenceEquals(this, obj)) { return true; } + return + base.Equals(obj) && + obj is UpcastingSerializer other && + object.Equals(_baseTypeSerializer, other._baseTypeSerializer); + } + + public override int GetHashCode() => 0; + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TDerived value) + { + _baseTypeSerializer.Serialize(context, value); + } + + public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonArraySerializer arraySerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonArraySerializer."); + } + + return arraySerializer.TryGetItemSerializationInfo(out serializationInfo); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + if (_baseTypeSerializer is not IBsonDocumentSerializer documentSerializer) + { + throw new NotSupportedException($"The class {_baseTypeSerializer.GetType().FullName} does not implement IBsonDocumentSerializer."); + } + + return documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs index f3bb40aaf3a..c66f84b213e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/WrappedValueSerializer.cs @@ -98,6 +98,20 @@ public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationI public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) { + if (_valueSerializer is IBsonDocumentSerializer documentSerializer) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberName, out serializationInfo)) + { + serializationInfo = BsonSerializationInfo.CreateWithPath( + [_fieldName, serializationInfo.ElementName], + serializationInfo.Serializer, + serializationInfo.NominalType); + return true; + } + + return false; + } + throw new InvalidOperationException(); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs index 818a92fab7a..3462d1bcf3e 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ArrayIndexExpressionToAggregationExpressionTranslator.cs @@ -14,8 +14,11 @@ */ using System.Linq.Expressions; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -30,7 +33,8 @@ public static TranslatedExpression Translate(TranslationContext context, BinaryE var indexExpression = expression.Right; var indexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, indexExpression); var ast = AstExpression.ArrayElemAt(arrayTranslation.Ast, indexTranslation.Ast); - var itemSerializer = ArraySerializerHelper.GetItemSerializer(arrayTranslation.Serializer); + var arraySerializer = arrayTranslation.Serializer; + var itemSerializer = arraySerializer.GetItemSerializer(indexExpression, arrayExpression); return new TranslatedExpression(expression, ast, itemSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ClientSideProjectionRewriter.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ClientSideProjectionRewriter.cs index 939b61aaea9..35bab95e682 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ClientSideProjectionRewriter.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ClientSideProjectionRewriter.cs @@ -28,22 +28,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg internal class ClientSideProjectionRewriter: ExpressionVisitor { #region static - private readonly static MethodInfo[] __orderByMethods = - [ - EnumerableMethod.OrderBy, - EnumerableMethod.OrderByDescending, - QueryableMethod.OrderBy, - QueryableMethod.OrderByDescending - ]; - - private readonly static MethodInfo[] __thenByMethods = - [ - EnumerableMethod.ThenBy, - EnumerableMethod.ThenByDescending, - QueryableMethod.ThenBy, - QueryableMethod.ThenByDescending - ]; - public static (TranslatedExpression[], LambdaExpression) RewriteProjection(TranslationContext context, LambdaExpression projectionLambda, IBsonSerializer sourceSerializer) { var rootParameter = projectionLambda.Parameters.Single(); @@ -108,7 +92,7 @@ public override Expression Visit(Expression node) protected override Expression VisitMethodCall(MethodCallExpression node) { // don't split OrderBy/ThenBy across the client/server boundary - if (node.Method.IsOneOf(__thenByMethods)) + if (node.Method.IsOneOf(EnumerableOrQueryableMethod.ThenByOverloads)) { return VisitThenBy(node); } @@ -126,13 +110,13 @@ private Expression VisitThenBy(MethodCallExpression node) { var sourceMethod = sourceMethodCallExpression.Method; - if (sourceMethod.IsOneOf(__thenByMethods)) + if (sourceMethod.IsOneOf(EnumerableOrQueryableMethod.ThenByOverloads)) { var rewrittenSourceExpression = VisitThenBy(sourceMethodCallExpression); return node.Update(node.Object, [rewrittenSourceExpression, keySelectorExpression]); } - if (sourceMethod.IsOneOf(__orderByMethods)) + if (sourceMethod.IsOneOf(EnumerableOrQueryableMethod.OrderByOverloads)) { var rewrittenSourceExpression = VisitOrderBy(sourceMethodCallExpression); return node.Update(node.Object, [rewrittenSourceExpression, keySelectorExpression]); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index 7487627213d..33a0f678737 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,12 +23,11 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static TranslatedExpression Translate(ConstantExpression constantExpression) + public static TranslatedExpression Translate(TranslationContext context, ConstantExpression constantExpression) { - var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + var constantSerializer = context.NodeSerializers.GetSerializer(constantExpression); return Translate(constantExpression, constantSerializer); - } + } public static TranslatedExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 532e10c1609..90cf9d8c45d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -214,7 +214,7 @@ private static TranslatedExpression TranslateConvertEnumToEnum(UnaryExpression e private static TranslatedExpression TranslateConvertEnumToUnderlyingType(UnaryExpression expression, Type sourceType, Type targetType, TranslatedExpression sourceTranslation) { var enumSerializer = sourceTranslation.Serializer; - var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + var targetSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, targetSerializer); } @@ -265,7 +265,7 @@ private static TranslatedExpression TranslateConvertUnderlyingTypeToEnum(UnaryEx var valueSerializer = sourceTranslation.Serializer; IBsonSerializer targetSerializer; - if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (valueSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index 9f019682a63..077decb5f4c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -75,7 +75,7 @@ public static TranslatedExpression TranslateWithoutUnwrapping(TranslationContext case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 20f7e81312c..758fcd0efa4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -13,16 +13,14 @@ * limitations under the License. */ -using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -44,168 +42,63 @@ public static TranslatedExpression Translate( NewExpression newExpression, IReadOnlyList bindings) { + var nodeSerializer = context.NodeSerializers.GetSerializer(expression); var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; - var computedFields = new List(); - var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); - if (constructorInfo != null && creatorMap != null) + var computedFields = new List(); + if (constructorInfo != null && constructorArguments.Count > 0) { - var constructorParameters = constructorInfo.GetParameters(); - var creatorMapParameters = creatorMap.Arguments?.ToArray(); - if (constructorParameters.Length > 0) + var matchingMemberSerializationInfos = nodeSerializer.GetMatchingMemberSerializationInfosForConstructorParameters(expression, constructorInfo); + + for (var i = 0; i < constructorArguments.Count; i++) { - if (creatorMapParameters == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); - } - if (creatorMapParameters.Length != constructorParameters.Length) - { - throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); - } + var argument = constructorArguments[i]; + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argument); + var matchingMemberSerializationInfo = matchingMemberSerializationInfos[i]; - for (var i = 0; i < creatorMapParameters.Length; i++) + if (!argumentTranslation.Serializer.CanBeAssignedTo(matchingMemberSerializationInfo.Serializer)) { - var creatorMapParameter = creatorMapParameters[i]; - var constructorArgumentExpression = constructorArguments[i]; - var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); - var constructorArgumentType = constructorArgumentExpression.Type; - var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); - var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); - EnsureDefaultValue(memberMap); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + throw new ExpressionNotSupportedException(argument, expression, because: "constructor argument serializer is not compatible with matching member serializer"); } - } - } - - foreach (var binding in bindings) - { - var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); - var valueExpression = memberAssignment.Expression; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); - } - - var ast = AstExpression.ComputedDocument(computedFields); - classMap.Freeze(); - var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); - var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); - - return new TranslatedExpression(expression, ast, serializer); - } - - private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) - { - BsonClassMap baseClassMap = null; - if (classType.BaseType != null) - { - baseClassMap = CreateClassMap(classType.BaseType, null, out _); - } - - var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); - var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); - var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); - if (constructorInfo != null) - { - creatorMap = classMap.MapConstructor(constructorInfo); - } - else - { - creatorMap = null; - } - - classMap.AutoMap(); - classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here - - return classMap; - } - - private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) - { - var memberType = memberMap.MemberType; - var memberSerializer = memberMap.GetSerializer(); - var sourceType = sourceSerializer.ValueType; - if (memberType != sourceType && - memberType.ImplementsIEnumerable(out var memberItemType) && - sourceType.ImplementsIEnumerable(out var sourceItemType) && - sourceItemType == memberItemType && - sourceSerializer is IBsonArraySerializer sourceArraySerializer && - sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && - memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) - { - var sourceItemSerializer = sourceItemSerializationInfo.Serializer; - return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); - } - - return sourceSerializer; - } - - private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) - { - var declaringClassMap = classMap; - while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) - { - declaringClassMap = declaringClassMap.BaseClassMap; - - if (declaringClassMap == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + var computedField = AstExpression.ComputedField(matchingMemberSerializationInfo.ElementName, argumentTranslation.Ast); + computedFields.Add(computedField); } } - foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + if (bindings.Count > 0) { - if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + if (nodeSerializer is not IBsonDocumentSerializer documentSerializer) { - return memberMap; + throw new ExpressionNotSupportedException(expression, because: $"serializer type {nodeSerializer.GetType()} does not implement IBsonDocumentSerializer"); } - } - return declaringClassMap.MapMember(creatorMapParameter); + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; - static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) - { - var memberInfo = memberMap.MemberInfo; - return - memberInfo.MemberType == creatorMapParameter.MemberType && - memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); - } - } + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"member {member.Name} was not found"); + } - private static void EnsureDefaultValue(BsonMemberMap memberMap) - { - if (memberMap.IsDefaultValueSpecified) - { - return; - } + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; - memberMap.SetDefaultValue(defaultValue); - } + if (!valueTranslation.Serializer.CanBeAssignedTo(memberSerializationInfo.Serializer)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"value serializer is not compatible with serializer for member {member.Name}"); + } - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) - { - foreach (var memberMap in classMap.DeclaredMemberMaps) - { - if (memberMap.MemberName == memberName) - { - return memberMap; + var computedField = AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast); + computedFields.Add(computedField); } } - if (classMap.BaseClassMap != null) - { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); - } - - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + var ast = AstExpression.ComputedDocument(computedFields); + return new TranslatedExpression(expression, ast, nodeSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs index 36d0e03a6b4..56dc9607212 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs @@ -23,23 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AbsMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __absMethods = - { - MathMethod.AbsDecimal, - MathMethod.AbsDouble, - MathMethod.AbsInt16, - MathMethod.AbsInt32, - MathMethod.AbsInt64, - MathMethod.AbsSByte, - MathMethod.AbsSingle - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__absMethods)) + if (method.IsOneOf(MathMethod.AbsOverloads)) { var valueExpression = arguments[0]; var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs index 8abb03ac872..959df291eb7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs @@ -24,49 +24,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AggregateMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __aggregateMethods = - { - EnumerableMethod.AggregateWithFunc, - EnumerableMethod.AggregateWithSeedAndFunc, - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithFunc, - QueryableMethod.AggregateWithSeedAndFunc, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - - private static readonly MethodInfo[] __aggregateWithoutSeedMethods = - { - EnumerableMethod.AggregateWithFunc, - QueryableMethod.AggregateWithFunc - }; - - private static readonly MethodInfo[] __aggregateWithSeedMethods = - { - EnumerableMethod.AggregateWithSeedAndFunc, - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithSeedAndFunc, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - - private static readonly MethodInfo[] __aggregateWithSeedFuncAndResultSelectorMethods = - { - EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithSeedFuncAndResultSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__aggregateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - if (method.IsOneOf(__aggregateWithoutSeedMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithFunc)) { var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var funcParameters = funcLambda.Parameters; @@ -95,7 +65,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, itemSerializer); } - else if (method.IsOneOf(__aggregateWithSeedMethods)) + else if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedOverloads)) { var seedExpression = arguments[1]; var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression); @@ -116,7 +86,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @in: funcTranslation.Ast); var serializer = accumulatorSerializer; - if (method.IsOneOf(__aggregateWithSeedFuncAndResultSelectorMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AggregateWithSeedFuncAndResultSelector)) { var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs index 290f49185a0..f10a7bcf416 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AllMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AllMethodToAggregationExpressionTranslator { - private readonly static MethodInfo[] __allMethods = - { - EnumerableMethod.All, - QueryableMethod.All - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__allMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.All)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs index 5841d67f823..f89d84896e8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AnyMethodToAggregationExpressionTranslator.cs @@ -24,19 +24,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AnyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __anyMethods = - { - EnumerableMethod.Any, - QueryableMethod.Any - }; - - private static readonly MethodInfo[] __anyWithPredicateMethods = - { - EnumerableMethod.AnyWithPredicate, - QueryableMethod.AnyWithPredicate, - ArrayMethod.Exists - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; @@ -46,13 +33,13 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - if (method.IsOneOf(__anyMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Any)) { var ast = AstExpression.Gt(AstExpression.Size(sourceTranslation.Ast), 0); return new TranslatedExpression(expression, ast, new BooleanSerializer()); } - if (method.IsOneOf(__anyWithPredicateMethods) || ListMethod.IsExistsMethod(method)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AnyWithPredicate) || method.Is(ArrayMethod.Exists) || ListMethod.IsExistsMethod(method)) { var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, method.IsStatic ? arguments[1] : arguments[0]); var predicateParameter = predicateLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs index 5a7a2942f70..ecee001f1c8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs @@ -24,26 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AppendOrPrependMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __appendOrPrependMethods = - { - EnumerableMethod.Append, - EnumerableMethod.Prepend, - QueryableMethod.Append, - QueryableMethod.Prepend - }; - - private static readonly MethodInfo[] __appendMethods = - { - EnumerableMethod.Append, - QueryableMethod.Append - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__appendOrPrependMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AppendOrPrepend)) { var sourceExpression = arguments[0]; var elementExpression = arguments[1]; @@ -68,7 +54,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } } - var ast = method.IsOneOf(__appendMethods) ? + var ast = method.IsOneOf(EnumerableOrQueryableMethod.Append) ? AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) : AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast); var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs index f2849ef812c..35b0f07e624 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslator.cs @@ -26,87 +26,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AverageMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __averageMethods = - { - EnumerableMethod.AverageDecimal, - EnumerableMethod.AverageDecimalWithSelector, - EnumerableMethod.AverageDouble, - EnumerableMethod.AverageDoubleWithSelector, - EnumerableMethod.AverageInt32, - EnumerableMethod.AverageInt32WithSelector, - EnumerableMethod.AverageInt64, - EnumerableMethod.AverageInt64WithSelector, - EnumerableMethod.AverageNullableDecimal, - EnumerableMethod.AverageNullableDecimalWithSelector, - EnumerableMethod.AverageNullableDouble, - EnumerableMethod.AverageNullableDoubleWithSelector, - EnumerableMethod.AverageNullableInt32, - EnumerableMethod.AverageNullableInt32WithSelector, - EnumerableMethod.AverageNullableInt64, - EnumerableMethod.AverageNullableInt64WithSelector, - EnumerableMethod.AverageNullableSingle, - EnumerableMethod.AverageNullableSingleWithSelector, - EnumerableMethod.AverageSingle, - EnumerableMethod.AverageSingleWithSelector, - QueryableMethod.AverageDecimal, - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDouble, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimal, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDouble, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingle, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingle, - QueryableMethod.AverageSingleWithSelector - }; - - private static readonly MethodInfo[] __averageWithSelectorMethods = - { - EnumerableMethod.AverageDecimalWithSelector, - EnumerableMethod.AverageDoubleWithSelector, - EnumerableMethod.AverageInt32WithSelector, - EnumerableMethod.AverageInt64WithSelector, - EnumerableMethod.AverageNullableDecimalWithSelector, - EnumerableMethod.AverageNullableDoubleWithSelector, - EnumerableMethod.AverageNullableInt32WithSelector, - EnumerableMethod.AverageNullableInt64WithSelector, - EnumerableMethod.AverageNullableSingleWithSelector, - EnumerableMethod.AverageSingleWithSelector, - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingleWithSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__averageMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AverageOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); AstExpression ast; - if (method.IsOneOf(__averageWithSelectorMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.AverageWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var selectorParameter = selectorLambda.Parameters[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs index 79f68c311ce..91aedb65465 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class CompareMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __stringCompareMethods = - [ - StringMethod.StaticCompare, - StringMethod.StaticCompareWithIgnoreCase - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.IsOneOf(__stringCompareMethods)) + if (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.IsOneOf(StringMethod.CompareOverloads)) { Expression value1Expression; Expression value2Expression; @@ -54,7 +48,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var value2Translation = ExpressionToAggregationExpressionTranslator.Translate(context, value2Expression); AstExpression ast; - if (method.Is(StringMethod.StaticCompareWithIgnoreCase)) + if (method.Is(StringMethod.CompareWithIgnoreCase)) { var ignoreCaseExpression = arguments[2]; var ignoreCase = ignoreCaseExpression.GetConstantValue(containingExpression: expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs index 7a4d64a3ff0..c6611ed9044 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs @@ -31,7 +31,8 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.Translate(context, expression); } - if (IsEnumerableContainsMethod(expression, out var sourceExpression, out var valueExpression)) + if (EnumerableMethod.IsContainsMethod(expression, out var sourceExpression, out var valueExpression) && + !expression.Method.Is(StringMethod.ContainsWithChar)) { return TranslateEnumerableContains(context, expression, sourceExpression, valueExpression); } @@ -40,39 +41,6 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } // private methods - private static bool IsEnumerableContainsMethod(MethodCallExpression expression, out Expression sourceExpression, out Expression valueExpression) - { - var method = expression.Method; - var arguments = expression.Arguments; - - if (method.IsOneOf(EnumerableMethod.Contains, QueryableMethod.Contains)) - { - sourceExpression = arguments[0]; - valueExpression = arguments[1]; - return true; - } - - if (!method.IsStatic && method.ReturnType == typeof(bool) && method.Name == "Contains" && arguments.Count == 1) - { - sourceExpression = expression.Object; - valueExpression = arguments[0]; - - if (sourceExpression.Type.TryGetIEnumerableGenericInterface(out var ienumerableInterface)) - { - var itemType = ienumerableInterface.GetGenericArguments()[0]; - if (itemType == valueExpression.Type) - { - // string.Contains(char) is not translated like other Contains methods because string is not represented as an array - return sourceExpression.Type != typeof(string) && valueExpression.Type != typeof(char); - } - } - } - - sourceExpression = null; - valueExpression = null; - return false; - } - private static TranslatedExpression TranslateEnumerableContains(TranslationContext context, MethodCallExpression expression, Expression sourceExpression, Expression valueExpression) { var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs index 9f6844b3031..d4283b67f23 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConvertMethodToAggregationExpressionTranslator.cs @@ -42,8 +42,9 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var valueExpression = arguments[0]; var optionsExpression = arguments[1]; - var (toBsonType, toSerializer) = TranslateToType(expression, toType); + var toBsonType = GetResultRepresentation(expression, toType); var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var toSerializer = context.NodeSerializers.GetSerializer(expression); var (subType, byteOrder, format, onErrorAst, onNullAst) = TranslateOptions(context, expression, optionsExpression, toSerializer); var ast = AstExpression.Convert(valueTranslation.Ast, toBsonType.Render(), subType, byteOrder, format, onErrorAst, onNullAst); @@ -143,39 +144,39 @@ IBsonSerializer toSerializer return (subType, byteOrder, format, onErrorTranslation?.Ast, onNullTranslation?.Ast); } - private static (BsonType ToBsonType, IBsonSerializer ToSerializer) TranslateToType(Expression expression, Type toType) + private static BsonType GetResultRepresentation(Expression expression, Type toType) { var isNullable = toType.IsNullable(); var valueType = isNullable ? Nullable.GetUnderlyingType(toType) : toType; - var (bsonType, valueSerializer) = (ValueTuple)(Type.GetTypeCode(valueType) switch + var representation = Type.GetTypeCode(valueType) switch { - TypeCode.Boolean => (BsonType.Boolean, BooleanSerializer.Instance), - TypeCode.Byte => (BsonType.Int32, ByteSerializer.Instance), - TypeCode.Char => (BsonType.String, StringSerializer.Instance), - TypeCode.DateTime => (BsonType.DateTime, DateTimeSerializer.Instance), - TypeCode.Decimal => (BsonType.Decimal128, DecimalSerializer.Instance), - TypeCode.Double => (BsonType.Double, DoubleSerializer.Instance), - TypeCode.Int16 => (BsonType.Int32, Int16Serializer.Instance), - TypeCode.Int32 => (BsonType.Int32, Int32Serializer.Instance), - TypeCode.Int64 => (BsonType.Int64, Int64Serializer.Instance), - TypeCode.SByte => (BsonType.Int32, SByteSerializer.Instance), - TypeCode.Single => (BsonType.Double, SingleSerializer.Instance), - TypeCode.String => (BsonType.String, StringSerializer.Instance), - TypeCode.UInt16 => (BsonType.Int32, UInt16Serializer.Instance), - TypeCode.UInt32 => (BsonType.Int64, Int32Serializer.Instance), - TypeCode.UInt64 => (BsonType.Decimal128, UInt64Serializer.Instance), - - _ when valueType == typeof(byte[]) => (BsonType.Binary, ByteArraySerializer.Instance), - _ when valueType == typeof(BsonBinaryData) => (BsonType.Binary, BsonBinaryDataSerializer.Instance), - _ when valueType == typeof(Decimal128) => (BsonType.Decimal128, Decimal128Serializer.Instance), - _ when valueType == typeof(Guid) => (BsonType.Binary, GuidSerializer.StandardInstance), - _ when valueType == typeof(ObjectId) => (BsonType.ObjectId, ObjectIdSerializer.Instance), + TypeCode.Boolean => BsonType.Boolean, + TypeCode.Byte => BsonType.Int32, + TypeCode.Char => BsonType.String, + TypeCode.DateTime => BsonType.DateTime, + TypeCode.Decimal => BsonType.Decimal128, + TypeCode.Double => BsonType.Double, + TypeCode.Int16 => BsonType.Int32, + TypeCode.Int32 => BsonType.Int32, + TypeCode.Int64 => BsonType.Int64, + TypeCode.SByte => BsonType.Int32, + TypeCode.Single => BsonType.Double, + TypeCode.String => BsonType.String, + TypeCode.UInt16 => BsonType.Int32, + TypeCode.UInt32 => BsonType.Int64, + TypeCode.UInt64 => BsonType.Decimal128, + + _ when valueType == typeof(byte[]) => BsonType.Binary, + _ when valueType == typeof(BsonBinaryData) => BsonType.Binary, + _ when valueType == typeof(Decimal128) => BsonType.Decimal128, + _ when valueType == typeof(Guid) => BsonType.Binary, + _ when valueType == typeof(ObjectId) => BsonType.ObjectId, _ => throw new ExpressionNotSupportedException(expression, because: $"{toType} is not a valid TTo for Convert") - }); + }; - return (bsonType, isNullable ? NullableSerializer.Create(valueSerializer) : valueSerializer); + return representation; } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs index 73801af66f7..462a0655a44 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs @@ -27,45 +27,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class CountMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __countMethods; - private static readonly MethodInfo[] __countWithPredicateMethods; - - static CountMethodToAggregationExpressionTranslator() - { - __countMethods = new[] - { - EnumerableMethod.Count, - EnumerableMethod.CountWithPredicate, - EnumerableMethod.LongCount, - EnumerableMethod.LongCountWithPredicate, - QueryableMethod.Count, - QueryableMethod.CountWithPredicate, - QueryableMethod.LongCount, - QueryableMethod.LongCountWithPredicate - }; - - __countWithPredicateMethods = new[] - { - EnumerableMethod.CountWithPredicate, - EnumerableMethod.LongCountWithPredicate, - QueryableMethod.CountWithPredicate, - QueryableMethod.LongCountWithPredicate - }; - } - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__countMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.CountOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); AstExpression ast; - if (method.IsOneOf(__countWithPredicateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.CountWithPredicateOverloads)) { if (sourceExpression.Type == typeof(string)) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs index 30ab3c04fbe..8931519e0d0 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs @@ -28,36 +28,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class CreateMethodToAggregationExpressionTranslator { - private static MethodInfo[] __tupleCreateMethods = new[] - { - TupleMethod.Create1, - TupleMethod.Create2, - TupleMethod.Create3, - TupleMethod.Create4, - TupleMethod.Create5, - TupleMethod.Create6, - TupleMethod.Create7, - TupleMethod.Create8 - }; - - private static MethodInfo[] __valueTupleCreateMethods = new[] - { - ValueTupleMethod.Create1, - ValueTupleMethod.Create2, - ValueTupleMethod.Create3, - ValueTupleMethod.Create4, - ValueTupleMethod.Create5, - ValueTupleMethod.Create6, - ValueTupleMethod.Create7, - ValueTupleMethod.Create8 - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__tupleCreateMethods) || method.IsOneOf(__valueTupleCreateMethods)) + if (method.IsOneOf(TupleOrValueTupleMethod.CreateOverloads)) { var tupleType = method.ReturnType; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs index 49e5c99a641..beacd741ab4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateFromStringMethodToAggregationExpressionTranslator.cs @@ -25,33 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DateFromStringMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __dateFromStringMethods = - { - MqlMethod.DateFromString, - MqlMethod.DateFromStringWithFormat, - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - - private static readonly MethodInfo[] __withFormatMethods = - { - MqlMethod.DateFromStringWithFormat, - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - - private static readonly MethodInfo[] __withTimezoneMethods = - { - MqlMethod.DateFromStringWithFormatAndTimezone, - MqlMethod.DateFromStringWithFormatAndTimezoneAndOnErrorAndOnNull - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__dateFromStringMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringOverloads)) { var dateStringExpression = arguments[0]; var dateStringTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dateStringExpression); @@ -59,7 +38,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC IBsonSerializer resultSerializer = DateTimeSerializer.Instance; AstExpression format = null; - if (method.IsOneOf(__withFormatMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringWithFormatOverloads)) { var formatExpression = arguments[1]; var formatTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, formatExpression); @@ -67,7 +46,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression timezoneAst = null; - if (method.IsOneOf(__withTimezoneMethods)) + if (method.IsOneOf(MqlMethod.DateFromStringWithTimezoneOverloads)) { var timezoneExpression = arguments[2]; var timezoneTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, timezoneExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs index e3448e36864..993477d0dfb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs @@ -29,81 +29,9 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DateTimeAddOrSubtractMethodToAggregationExpressionTranslator { - private static MethodInfo[] __dateTimeAddOrSubtractMethods = new[] - { - DateTimeMethod.Add, - DateTimeMethod.AddDays, - DateTimeMethod.AddDaysWithTimezone, - DateTimeMethod.AddHours, - DateTimeMethod.AddHoursWithTimezone, - DateTimeMethod.AddMilliseconds, - DateTimeMethod.AddMillisecondsWithTimezone, - DateTimeMethod.AddMinutes, - DateTimeMethod.AddMinutesWithTimezone, - DateTimeMethod.AddMonths, - DateTimeMethod.AddMonthsWithTimezone, - DateTimeMethod.AddQuarters, - DateTimeMethod.AddQuartersWithTimezone, - DateTimeMethod.AddSeconds, - DateTimeMethod.AddSecondsWithTimezone, - DateTimeMethod.AddTicks, - DateTimeMethod.AddWeeks, - DateTimeMethod.AddWeeksWithTimezone, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.AddWithUnit, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.AddYears, - DateTimeMethod.AddYearsWithTimezone, - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithTimeSpanMethods = new[] - { - DateTimeMethod.Add, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithUnitMethods = new[] - { - DateTimeMethod.AddWithUnit, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeAddOrSubtractWithTimezoneMethods = new[] - { - DateTimeMethod.AddDaysWithTimezone, - DateTimeMethod.AddHoursWithTimezone, - DateTimeMethod.AddMillisecondsWithTimezone, - DateTimeMethod.AddMinutesWithTimezone, - DateTimeMethod.AddMonthsWithTimezone, - DateTimeMethod.AddQuartersWithTimezone, - DateTimeMethod.AddSecondsWithTimezone, - DateTimeMethod.AddWeeksWithTimezone, - DateTimeMethod.AddWithTimezone, - DateTimeMethod.AddWithUnitAndTimezone, - DateTimeMethod.AddYearsWithTimezone, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - - private static MethodInfo[] __dateTimeSubtractMethods = new[] - { - DateTimeMethod.SubtractWithTimeSpan, - DateTimeMethod.SubtractWithTimeSpanAndTimezone, - DateTimeMethod.SubtractWithUnit, - DateTimeMethod.SubtractWithUnitAndTimezone - }; - public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__dateTimeAddOrSubtractMethods); + return expression.Method.IsOneOf(DateTimeMethod.AddOrSubtractOverloads); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -111,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__dateTimeAddOrSubtractMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractOverloads)) { Expression thisExpression, valueExpression; if (method.IsStatic) @@ -128,7 +56,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var thisTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, thisExpression); AstExpression unit, amount; - if (method.IsOneOf(__dateTimeAddOrSubtractWithTimeSpanMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithTimeSpanOverloads)) { if (valueExpression is ConstantExpression constantValueExpression) { @@ -161,7 +89,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC }; } } - else if (method.IsOneOf(__dateTimeAddOrSubtractWithUnitMethods)) + else if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithUnitOverloads)) { var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); var valueAst = ConvertHelper.RemoveWideningConvert(valueTranslation); @@ -192,14 +120,14 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression timezone = null; - if (method.IsOneOf(__dateTimeAddOrSubtractWithTimezoneMethods)) + if (method.IsOneOf(DateTimeMethod.AddOrSubtractWithTimezoneOverloads)) { var timezoneExpression = arguments.Last(); var timezoneTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, timezoneExpression); timezone = timezoneTranslation.Ast; } - var ast = method.IsOneOf(__dateTimeSubtractMethods) ? + var ast = method.IsOneOf(DateTimeMethod.SubtractReturningDateTimeOverloads) ? AstExpression.DateSubtract(thisTranslation.Ast, unit, amount, timezone) : AstExpression.DateAdd(thisTranslation.Ast, unit, amount, timezone); var serializer = DateTimeSerializer.UtcInstance; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeSubtractWithDateTimeMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeSubtractWithDateTimeMethodToAggregationExpressionTranslator.cs index 2631cc193a7..72b90ea5115 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeSubtractWithDateTimeMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeSubtractWithDateTimeMethodToAggregationExpressionTranslator.cs @@ -29,29 +29,9 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DateTimeSubtractWithDateTimeMethodToAggregationExpressionTranslator { - private readonly static MethodInfo[] __dateTimeSubtractWithDateTimeMethods = - { - DateTimeMethod.SubtractWithDateTime, - DateTimeMethod.SubtractWithDateTimeAndTimezone, - DateTimeMethod.SubtractWithDateTimeAndUnit, - DateTimeMethod.SubtractWithDateTimeAndUnitAndTimezone - }; - - private readonly static MethodInfo[] __dateTimeSubtractWithTimezoneMethods = - { - DateTimeMethod.SubtractWithDateTimeAndTimezone, - DateTimeMethod.SubtractWithDateTimeAndUnitAndTimezone - }; - - private readonly static MethodInfo[] __dateTimeSubtractWithUnitMethods = - { - DateTimeMethod.SubtractWithDateTimeAndUnit, - DateTimeMethod.SubtractWithDateTimeAndUnitAndTimezone - }; - public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__dateTimeSubtractWithDateTimeMethods); + return expression.Method.IsOneOf(DateTimeMethod.SubtractWithDateTimeOverloads); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -59,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__dateTimeSubtractWithDateTimeMethods)) + if (method.IsOneOf(DateTimeMethod.SubtractWithDateTimeOverloads)) { Expression thisExpression, valueExpression; if (method.IsStatic) @@ -78,7 +58,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC AstExpression unit, startOfWeek; IBsonSerializer serializer; - if (method.IsOneOf(__dateTimeSubtractWithUnitMethods)) + if (method.IsOneOf(DateTimeMethod.SubtractWithUnitOverloads)) { var unitExpression = arguments[2]; var unitConstant = unitExpression.GetConstantValue(containingExpression: expression); @@ -101,7 +81,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression timezone = null; - if (method.IsOneOf(__dateTimeSubtractWithTimezoneMethods)) + if (method.IsOneOf(DateTimeMethod.SubtractWithTimezoneOverloads)) { var timezoneExpression = arguments.Last(); var timezoneTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, timezoneExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs index ce74b191fa0..706b488a3a3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs @@ -25,26 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DefaultIfEmptyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __defaultIfEmptyMethods = - { - EnumerableMethod.DefaultIfEmpty, - EnumerableMethod.DefaultIfEmptyWithDefaultValue, - QueryableMethod.DefaultIfEmpty, - QueryableMethod.DefaultIfEmptyWithDefaultValue, - }; - - private static readonly MethodInfo[] __defaultIfEmptyWithDefaultValueMethods = - { - EnumerableMethod.DefaultIfEmptyWithDefaultValue, - QueryableMethod.DefaultIfEmptyWithDefaultValue, - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__defaultIfEmptyMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.DefaultIfEmptyOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -53,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast); AstExpression defaultValueAst; - if (method.IsOneOf(__defaultIfEmptyWithDefaultValueMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.DefaultIfEmptyWithDefaultValue)) { var defaultValueExpression = arguments[1]; var defaultValueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, defaultValueExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs index 0486cca5e0b..7b573c4a4b1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class DistinctMethodToAggregationExpressionTranslator { - private readonly static MethodInfo[] __distinctMethods = - { - EnumerableMethod.Distinct, - QueryableMethod.Distinct - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__distinctMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Distinct)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs index e6e9cf24e1d..d9bc8509335 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs @@ -23,26 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ElementAtMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __elementAtMethods = - { - EnumerableMethod.ElementAt, - EnumerableMethod.ElementAtOrDefault, - QueryableMethod.ElementAt, - QueryableMethod.ElementAtOrDefault - }; - - private static readonly MethodInfo[] __elementAtOrDefaultMethods = - { - EnumerableMethod.ElementAtOrDefault, - QueryableMethod.ElementAtOrDefault - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__elementAtMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ElementAtOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -53,7 +39,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var indexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, indexExpression); AstExpression ast; - if (method.IsOneOf(__elementAtOrDefaultMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ElementAtOrDefault)) { var defaultValue = itemSerializer.ValueType.GetDefaultValue(); var serializedDefaultValue = SerializationHelper.SerializeValue(itemSerializer, defaultValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs index a42526b03ef..45f22d9c09a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs @@ -24,21 +24,15 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class EnumerableConcatMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __concatMethods = - { - EnumerableMethod.Concat, - QueryableMethod.Concat - }; - public static bool CanTranslate(MethodCallExpression expression) - => expression.Method.IsOneOf(__concatMethods); + => expression.Method.IsOneOf(EnumerableOrQueryableMethod.Concat); public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__concatMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Concat)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs index a539d5e750d..ddc528f6d5c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExceptMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ExceptMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __exceptMethods = - { - EnumerableMethod.Except, - QueryableMethod.Except - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__exceptMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Except)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs index 96199df8207..811a2ab4346 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs @@ -25,68 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class FirstOrLastMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __firstOrLastMethods = - { - EnumerableMethod.First, - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.Last, - EnumerableMethod.LastWithPredicate, - EnumerableMethod.LastOrDefault, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.First, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.Last, - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefault, - QueryableMethod.LastOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __firstMethods = - { - EnumerableMethod.First, - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.First, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __orDefaultMethods = - { - EnumerableMethod.FirstOrDefault, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.LastOrDefault, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.LastOrDefault, - QueryableMethod.LastOrDefaultWithPredicate - }; - - private static readonly MethodInfo[] __withPredicateMethods = - { - EnumerableMethod.FirstWithPredicate, - EnumerableMethod.FirstOrDefaultWithPredicate, - EnumerableMethod.LastWithPredicate, - EnumerableMethod.LastOrDefaultWithPredicate, - QueryableMethod.FirstWithPredicate, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefaultWithPredicate - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__firstOrLastMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -95,9 +39,9 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sourceAst = sourceTranslation.Ast; var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var isFirstMethod = method.IsOneOf(__firstMethods); + var isFirstMethod = method.IsOneOf(EnumerableOrQueryableMethod.FirstOverloads); - if (method.IsOneOf(__withPredicateMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrLastWithPredicateOverloads)) { var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var parameterExpression = predicateLambda.Parameters.Single(); @@ -122,7 +66,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } AstExpression ast; - if (method.IsOneOf(__orDefaultMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.FirstOrDefaultOverloads, EnumerableOrQueryableMethod.LastOrDefaultOverloads)) { var defaultValue = itemSerializer.ValueType.GetDefaultValue(); var serializedDefaultValue = SerializationHelper.SerializeValue(itemSerializer, defaultValue); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfAnyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfAnyMethodToAggregationExpressionTranslator.cs index e21ebb9d28b..d8303404772 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfAnyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfAnyMethodToAggregationExpressionTranslator.cs @@ -28,19 +28,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class IndexOfAnyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __indexOfAnyMethods = - { - StringMethod.IndexOfAny, - StringMethod.IndexOfAnyWithStartIndex, - StringMethod.IndexOfAnyWithStartIndexAndCount, - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__indexOfAnyMethods)) + if (method.IsOneOf(StringMethod.IndexOfAnyOverloads)) { var (stringVar, stringAst) = TranslateObject(expression.Object); var anyOf = TranslateAnyOf(arguments); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs index 8fe5f57f89a..e6392e3c8af 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs @@ -27,56 +27,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class IndexOfMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __indexOfMethods = - { - StringMethod.IndexOfWithChar, - StringMethod.IndexOfBytesWithValue, - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithString, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithStartIndexMethods = - { - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithCountMethods = - { - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfWithStringComparisonMethods = - { - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - private static readonly MethodInfo[] __indexOfBytesMethods = - { - StringMethod.IndexOfBytesWithValue, - StringMethod.IndexOfBytesWithValueAndStartIndex, - StringMethod.IndexOfBytesWithValueAndStartIndexAndCount - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { if (IsStringIndexOfMethod(expression, out var objectExpression, out var valueExpression, out var startIndexExpression, out var countExpression, out var comparisonTypeExpression)) @@ -100,7 +50,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var endAst = CreateEndAst(startIndexTranslation?.Ast, countTranslation?.Ast); AstExpression ast; - if (expression.Method.IsOneOf(__indexOfBytesMethods) || ordinal) + if (expression.Method.IsOneOf(StringMethod.IndexOfBytesOverloads) || ordinal) { ast = AstExpression.IndexOfBytes(objectTranslation.Ast, valueTranslation.Ast, startIndexTranslation?.Ast, endAst); } @@ -167,14 +117,14 @@ private static bool IsStringIndexOfMethod( var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__indexOfMethods)) + if (method.IsOneOf(StringMethod.IndexOfOverloads)) { - if (method.IsOneOf(__indexOfBytesMethods)) + if (method.IsOneOf(StringMethod.IndexOfBytesOverloads)) { instanceExpression = arguments[0]; valueExpression = arguments[1]; - startIndexExpression = method.IsOneOf(__indexOfWithStartIndexMethods) ? arguments[2] : null; - countExpression = method.IsOneOf(__indexOfWithCountMethods) ? arguments[3] : null; + startIndexExpression = method.IsOneOf(StringMethod.IndexOfWithStartIndexOverloads) ? arguments[2] : null; + countExpression = method.IsOneOf(StringMethod.IndexOfWithCountOverloads) ? arguments[3] : null; comparisonTypeExpression = null; return true; } @@ -182,9 +132,9 @@ private static bool IsStringIndexOfMethod( { instanceExpression = expression.Object; valueExpression = arguments[0]; - startIndexExpression = method.IsOneOf(__indexOfWithStartIndexMethods) ? arguments[1] : null; - countExpression = method.IsOneOf(__indexOfWithCountMethods) ? arguments[2] : null; - comparisonTypeExpression = method.IsOneOf(__indexOfWithStringComparisonMethods) ? arguments.Last() : null; + startIndexExpression = method.IsOneOf(StringMethod.IndexOfWithStartIndexOverloads) ? arguments[1] : null; + countExpression = method.IsOneOf(StringMethod.IndexOfWithCountOverloads) ? arguments[2] : null; + comparisonTypeExpression = method.IsOneOf(StringMethod.IndexOfWithStringComparisonOverloads) ? arguments.Last() : null; return true; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs index c5519f5547d..b17146ef30c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IntersectMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class IntersectMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __intersectMethods = - { - EnumerableMethod.Intersect, - QueryableMethod.Interset - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__intersectMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Intersect)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IsMissingMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IsMissingMethodToAggregationExpressionTranslator.cs index 6c71494dbe9..579711462a4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IsMissingMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IsMissingMethodToAggregationExpressionTranslator.cs @@ -27,19 +27,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class IsMissingMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __isMissingMethods = - { + private static readonly IReadOnlyMethodInfoSet __translatableOverloads = MethodInfoSet.Create( + [ MqlMethod.Exists, MqlMethod.IsMissing, MqlMethod.IsNullOrMissing, - }; + ]); public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__isMissingMethods)) + if (method.IsOneOf(__translatableOverloads)) { var fieldExpression = arguments[0]; var fieldTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, fieldExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MaxOrMinMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MaxOrMinMethodToAggregationExpressionTranslator.cs index 1019fb10228..40d3f58704b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MaxOrMinMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MaxOrMinMethodToAggregationExpressionTranslator.cs @@ -24,64 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class MaxOrMinMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __maxOrMinMethods = - { - EnumerableMethod.Max, - EnumerableMethod.MaxDecimal, - EnumerableMethod.MaxDecimalWithSelector, - EnumerableMethod.MaxDouble, - EnumerableMethod.MaxDoubleWithSelector, - EnumerableMethod.MaxInt32, - EnumerableMethod.MaxInt32WithSelector, - EnumerableMethod.MaxInt64, - EnumerableMethod.MaxInt64WithSelector, - EnumerableMethod.MaxNullableDecimal, - EnumerableMethod.MaxNullableDecimalWithSelector, - EnumerableMethod.MaxNullableDouble, - EnumerableMethod.MaxNullableDoubleWithSelector, - EnumerableMethod.MaxNullableInt32, - EnumerableMethod.MaxNullableInt32WithSelector, - EnumerableMethod.MaxNullableInt64, - EnumerableMethod.MaxNullableInt64WithSelector, - EnumerableMethod.MaxNullableSingle, - EnumerableMethod.MaxNullableSingleWithSelector, - EnumerableMethod.MaxSingle, - EnumerableMethod.MaxSingleWithSelector, - EnumerableMethod.MaxWithSelector, - EnumerableMethod.Min, - EnumerableMethod.MinDecimal, - EnumerableMethod.MinDecimalWithSelector, - EnumerableMethod.MinDouble, - EnumerableMethod.MinDoubleWithSelector, - EnumerableMethod.MinInt32, - EnumerableMethod.MinInt32WithSelector, - EnumerableMethod.MinInt64, - EnumerableMethod.MinInt64WithSelector, - EnumerableMethod.MinNullableDecimal, - EnumerableMethod.MinNullableDecimalWithSelector, - EnumerableMethod.MinNullableDouble, - EnumerableMethod.MinNullableDoubleWithSelector, - EnumerableMethod.MinNullableInt32, - EnumerableMethod.MinNullableInt32WithSelector, - EnumerableMethod.MinNullableInt64, - EnumerableMethod.MinNullableInt64WithSelector, - EnumerableMethod.MinNullableSingle, - EnumerableMethod.MinNullableSingleWithSelector, - EnumerableMethod.MinSingle, - EnumerableMethod.MinSingleWithSelector, - EnumerableMethod.MinWithSelector, - QueryableMethod.Max, - QueryableMethod.MaxWithSelector, - QueryableMethod.Min, - QueryableMethod.MinWithSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__maxOrMinMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.MaxOrMinOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs index 0baa8709c1d..7572c2dd83f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/MedianMethodToAggregationExpressionTranslator.cs @@ -24,50 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal class MedianMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __medianMethods = - [ - MongoEnumerableMethod.MedianDecimal, - MongoEnumerableMethod.MedianDecimalWithSelector, - MongoEnumerableMethod.MedianDouble, - MongoEnumerableMethod.MedianDoubleWithSelector, - MongoEnumerableMethod.MedianInt32, - MongoEnumerableMethod.MedianInt32WithSelector, - MongoEnumerableMethod.MedianInt64, - MongoEnumerableMethod.MedianInt64WithSelector, - MongoEnumerableMethod.MedianNullableDecimal, - MongoEnumerableMethod.MedianNullableDecimalWithSelector, - MongoEnumerableMethod.MedianNullableDouble, - MongoEnumerableMethod.MedianNullableDoubleWithSelector, - MongoEnumerableMethod.MedianNullableInt32, - MongoEnumerableMethod.MedianNullableInt32WithSelector, - MongoEnumerableMethod.MedianNullableInt64, - MongoEnumerableMethod.MedianNullableInt64WithSelector, - MongoEnumerableMethod.MedianNullableSingle, - MongoEnumerableMethod.MedianNullableSingleWithSelector, - MongoEnumerableMethod.MedianSingle, - MongoEnumerableMethod.MedianSingleWithSelector - ]; - - private static readonly MethodInfo[] __medianWithSelectorMethods = - [ - MongoEnumerableMethod.MedianDecimalWithSelector, - MongoEnumerableMethod.MedianDoubleWithSelector, - MongoEnumerableMethod.MedianInt32WithSelector, - MongoEnumerableMethod.MedianInt64WithSelector, - MongoEnumerableMethod.MedianNullableDecimalWithSelector, - MongoEnumerableMethod.MedianNullableDoubleWithSelector, - MongoEnumerableMethod.MedianNullableInt32WithSelector, - MongoEnumerableMethod.MedianNullableInt64WithSelector, - MongoEnumerableMethod.MedianNullableSingleWithSelector, - MongoEnumerableMethod.MedianSingleWithSelector - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__medianMethods)) + if (method.IsOneOf(MongoEnumerableMethod.MedianOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -75,7 +37,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var inputAst = sourceTranslation.Ast; - if (method.IsOneOf(__medianWithSelectorMethods)) + if (method.IsOneOf(MongoEnumerableMethod.MedianWithSelectorOverloads)) { var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); @@ -104,4 +66,4 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } } -} \ No newline at end of file +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs index 5ce6bbe01f1..4b5dc9a57f9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs @@ -27,18 +27,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class OfTypeMethodToAggregationExpressionTranslator { - private static MethodInfo[] __ofTypeMethods = - { - EnumerableMethod.OfType, - QueryableMethod.OfType - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__ofTypeMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.OfType)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OrderByMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OrderByMethodToAggregationExpressionTranslator.cs index a91bd16d189..2bbc5a2eff4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OrderByMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OrderByMethodToAggregationExpressionTranslator.cs @@ -27,44 +27,16 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class OrderByMethodToAggregationExpressionTranslator { - private static MethodInfo[] __translatableMethods = - { - EnumerableMethod.OrderBy, - EnumerableMethod.OrderByDescending, - EnumerableMethod.ThenBy, - EnumerableMethod.ThenByDescending, - QueryableMethod.OrderBy, - QueryableMethod.OrderByDescending, - QueryableMethod.ThenBy, - QueryableMethod.ThenByDescending - }; - - private static MethodInfo[] __orderByMethods = - { - EnumerableMethod.OrderBy, - EnumerableMethod.OrderByDescending, - QueryableMethod.OrderBy, - QueryableMethod.OrderByDescending - }; - - private static MethodInfo[] __thenByMethods = - { - EnumerableMethod.ThenBy, - EnumerableMethod.ThenByDescending, - QueryableMethod.ThenBy, - QueryableMethod.ThenByDescending - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__translatableMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.OrderByOrThenByOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); - if (method.IsOneOf(__thenByMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ThenByOverloads)) { NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsOrderedQueryableSource(expression, sourceTranslation); } @@ -80,7 +52,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC if (IsIdentityLambda(keySelectorLambda)) { - if (method.IsOneOf(__orderByMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.OrderByOverloads)) { var ast = AstExpression.SortArray(sourceTranslation.Ast, order); return new TranslatedExpression(expression, ast, orderedEnumerableSerializer); @@ -92,13 +64,13 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sortFieldPath = keySelectorLambda.TranslateToDottedFieldName(context, itemSerializer); var sortField = AstSort.Field(sortFieldPath, order); - if (method.IsOneOf(__orderByMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.OrderByOverloads)) { var ast = AstExpression.SortArray(sourceTranslation.Ast, sortField); return new TranslatedExpression(expression, ast, orderedEnumerableSerializer); } - if (method.IsOneOf(__thenByMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ThenByOverloads)) { if (sourceTranslation.Ast is AstSortArrayExpression originalAst) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs index 216d89f1c49..005b5e7b80d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PercentileMethodToAggregationExpressionTranslator.cs @@ -24,50 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal class PercentileMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __percentileMethods = - [ - MongoEnumerableMethod.PercentileDecimal, - MongoEnumerableMethod.PercentileDecimalWithSelector, - MongoEnumerableMethod.PercentileDouble, - MongoEnumerableMethod.PercentileDoubleWithSelector, - MongoEnumerableMethod.PercentileInt32, - MongoEnumerableMethod.PercentileInt32WithSelector, - MongoEnumerableMethod.PercentileInt64, - MongoEnumerableMethod.PercentileInt64WithSelector, - MongoEnumerableMethod.PercentileNullableDecimal, - MongoEnumerableMethod.PercentileNullableDecimalWithSelector, - MongoEnumerableMethod.PercentileNullableDouble, - MongoEnumerableMethod.PercentileNullableDoubleWithSelector, - MongoEnumerableMethod.PercentileNullableInt32, - MongoEnumerableMethod.PercentileNullableInt32WithSelector, - MongoEnumerableMethod.PercentileNullableInt64, - MongoEnumerableMethod.PercentileNullableInt64WithSelector, - MongoEnumerableMethod.PercentileNullableSingle, - MongoEnumerableMethod.PercentileNullableSingleWithSelector, - MongoEnumerableMethod.PercentileSingle, - MongoEnumerableMethod.PercentileSingleWithSelector - ]; - - private static readonly MethodInfo[] __percentileWithSelectorMethods = - [ - MongoEnumerableMethod.PercentileDecimalWithSelector, - MongoEnumerableMethod.PercentileDoubleWithSelector, - MongoEnumerableMethod.PercentileInt32WithSelector, - MongoEnumerableMethod.PercentileInt64WithSelector, - MongoEnumerableMethod.PercentileNullableDecimalWithSelector, - MongoEnumerableMethod.PercentileNullableDoubleWithSelector, - MongoEnumerableMethod.PercentileNullableInt32WithSelector, - MongoEnumerableMethod.PercentileNullableInt64WithSelector, - MongoEnumerableMethod.PercentileNullableSingleWithSelector, - MongoEnumerableMethod.PercentileSingleWithSelector - ]; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__percentileMethods)) + if (method.IsOneOf(MongoEnumerableMethod.PercentileOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -75,7 +37,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var inputAst = sourceTranslation.Ast; - if (method.IsOneOf(__percentileWithSelectorMethods)) + if (method.IsOneOf(MongoEnumerableMethod.PercentileWithSelectorOverloads)) { var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); @@ -107,4 +69,4 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC throw new ExpressionNotSupportedException(expression); } } -} \ No newline at end of file +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PickMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PickMethodToAggregationExpressionTranslator.cs index ae21565dc68..46c4e3495f7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PickMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PickMethodToAggregationExpressionTranslator.cs @@ -31,87 +31,25 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class PickMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __pickMethods = new[] - { - EnumerableMethod.Bottom, - EnumerableMethod.BottomN, - EnumerableMethod.BottomNWithComputedN, - EnumerableMethod.FirstN, - EnumerableMethod.FirstNWithComputedN, - EnumerableMethod.LastN, - EnumerableMethod.LastNWithComputedN, - EnumerableMethod.MaxN, - EnumerableMethod.MaxNWithComputedN, - EnumerableMethod.MinN, - EnumerableMethod.MinNWithComputedN, - EnumerableMethod.Top, - EnumerableMethod.TopN, - EnumerableMethod.TopNWithComputedN - }; - - private static readonly MethodInfo[] __withNMethods = new[] - { - EnumerableMethod.BottomN, - EnumerableMethod.FirstN, - EnumerableMethod.LastN, - EnumerableMethod.MaxN, - EnumerableMethod.MinN, - EnumerableMethod.TopN - }; - - private static readonly MethodInfo[] __withComputedNMethods = new[] - { - EnumerableMethod.BottomNWithComputedN, - EnumerableMethod.FirstNWithComputedN, - EnumerableMethod.LastNWithComputedN, - EnumerableMethod.MaxNWithComputedN, - EnumerableMethod.MinNWithComputedN, - EnumerableMethod.TopNWithComputedN - }; - - private static readonly MethodInfo[] __withSortByMethods = new[] - { - EnumerableMethod.Bottom, - EnumerableMethod.BottomN, - EnumerableMethod.BottomNWithComputedN, - EnumerableMethod.Top, - EnumerableMethod.TopN, - EnumerableMethod.TopNWithComputedN - }; - - private static readonly MethodInfo[] __accumulatorOnlyMethods = new[] - { - EnumerableMethod.Bottom, - EnumerableMethod.BottomN, - EnumerableMethod.BottomNWithComputedN, - EnumerableMethod.FirstNWithComputedN, - EnumerableMethod.LastNWithComputedN, - EnumerableMethod.MaxNWithComputedN, - EnumerableMethod.MinNWithComputedN, - EnumerableMethod.Top, - EnumerableMethod.TopN, - EnumerableMethod.TopNWithComputedN - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments.ToArray(); - if (method.IsOneOf(__pickMethods)) + if (method.IsOneOf(EnumerableMethod.PickOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - if (method.IsOneOf(__accumulatorOnlyMethods) && !IsGroupingSource(sourceTranslation.Ast)) + if (method.IsOneOf(EnumerableMethod.PickOverloadsThatCanOnlyBeUsedAsGroupByAccumulators) && !IsGroupingSource(sourceTranslation.Ast)) { throw new ExpressionNotSupportedException(expression, because: $"{method.Name} can only be used as an accumulator with GroupBy"); } AstSortFields sortBy = null; - if (method.IsOneOf(__withSortByMethods)) + if (method.IsOneOf(EnumerableMethod.PickWithSortByOverloads)) { var sortByExpression = arguments[1]; var sortByDefinition = GetSortByDefinition(sortByExpression, expression); @@ -126,10 +64,10 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC TranslatedExpression nTranslation = null; IBsonSerializer resultSerializer; - if (method.IsOneOf(__withNMethods, __withComputedNMethods)) + if (method.IsOneOf(EnumerableMethod.PickWithNOverloads, EnumerableMethod.PickWithComputedNOverloads)) { var nExpression = arguments.Last(); - if (method.IsOneOf(__withNMethods)) + if (method.IsOneOf(EnumerableMethod.PickWithNOverloads)) { nTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, nExpression); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PowMethodToAggregationExpressionTranslator.cs index 321d0917404..2d32bcd5011 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PowMethodToAggregationExpressionTranslator.cs @@ -28,7 +28,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(MathMethod.Pow)) + if (method.Is(MathMethod.Pow)) { var xExpression = arguments[0]; var yExpression = arguments[1]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ReverseMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ReverseMethodToAggregationExpressionTranslator.cs index c82f5f7ec2f..846831884a9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ReverseMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ReverseMethodToAggregationExpressionTranslator.cs @@ -14,7 +14,6 @@ */ using System.Linq.Expressions; -using System.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -24,18 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ReverseMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __reverseMethods = - { - EnumerableMethod.Reverse, - QueryableMethod.Reverse - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__reverseMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.ReverseOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RoundMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RoundMethodToAggregationExpressionTranslator.cs index e06fdd30e7f..bba61470227 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RoundMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RoundMethodToAggregationExpressionTranslator.cs @@ -23,26 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class RoundMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __roundMethods = - { - MathMethod.RoundWithDecimal, - MathMethod.RoundWithDecimalAndDecimals, - MathMethod.RoundWithDouble, - MathMethod.RoundWithDoubleAndDigits - }; - - private static readonly MethodInfo[] __roundWithPlaceMethods = - { - MathMethod.RoundWithDecimalAndDecimals, - MathMethod.RoundWithDoubleAndDigits - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__roundMethods)) + if (method.IsOneOf(MathMethod.RoundOverloads)) { var argumentExpression = arguments[0]; var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression); @@ -50,7 +36,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var argumentAst = ConvertHelper.RemoveWideningConvert(argumentTranslation); AstExpression ast; - if (method.IsOneOf(__roundWithPlaceMethods)) + if (method.IsOneOf(MathMethod.RoundWithPlaceOverloads)) { var placeExpression = arguments[1]; var placeTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, placeExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs index 89b67968c24..1a020ceeb74 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs @@ -25,18 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SelectManyMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __selectManyMethods = - { - EnumerableMethod.SelectMany, - QueryableMethod.SelectMany - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__selectManyMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelector)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectMethodToAggregationExpressionTranslator.cs index 27210bea1c1..eb62efc0395 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SelectMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __selectMethods = - { - EnumerableMethod.Select, - QueryableMethod.Select - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__selectMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Select)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs index 3cacd4c462c..51679bddf22 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SetEqualsMethodToAggregationExpressionTranslator.cs @@ -17,6 +17,7 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -24,8 +25,11 @@ internal static class SetEqualsMethodToAggregationExpressionTranslator { public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { - if (IsSetEqualsMethod(expression, out var objectExpression, out var otherExpression)) + if (ISetMethod.IsSetEqualsMethod(expression.Method)) { + var objectExpression = expression.Object; + var otherExpression = expression.Arguments[0]; + var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression); var otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression); var ast = AstExpression.SetEquals(objectTranslation.Ast, otherTranslation.Ast); @@ -34,34 +38,5 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } throw new ExpressionNotSupportedException(expression); } - - private static bool IsSetEqualsMethod(MethodCallExpression expression, out Expression objectExpression, out Expression otherExpression) - { - var method = expression.Method; - var arguments = expression.Arguments; - - if (!method.IsStatic && - method.ReturnType == typeof(bool) && - method.Name == "SetEquals" && - arguments.Count == 1) - { - objectExpression = expression.Object; - otherExpression = arguments[0]; - if (objectExpression.Type.TryGetIEnumerableGenericInterface(out var objectEnumerableInterface) && - otherExpression.Type.TryGetIEnumerableGenericInterface(out var otherEnumerableInterface)) - { - var objectItemType = objectEnumerableInterface.GetGenericArguments()[0]; - var otherItemType = otherEnumerableInterface.GetGenericArguments()[0]; - if (objectItemType == otherItemType) - { - return true; - } - } - } - - objectExpression = null; - otherExpression = null; - return false; - } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipOrTakeMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipOrTakeMethodToAggregationExpressionTranslator.cs index 6004bbf8bd4..42742327cd2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipOrTakeMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipOrTakeMethodToAggregationExpressionTranslator.cs @@ -24,32 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SkipOrTakeMethodToAggregationExpressionTranslator { - private static MethodInfo[] __skipOrTakeMethods = - { - EnumerableMethod.Skip, - EnumerableMethod.Take, - QueryableMethod.Skip, - QueryableMethod.Take, - }; - - private static MethodInfo[] __skipMethods = - { - EnumerableMethod.Skip, - QueryableMethod.Skip - }; - - private static MethodInfo[] __takeMethods = - { - EnumerableMethod.Take, - QueryableMethod.Take - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__skipOrTakeMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipOrTakeOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -63,8 +43,8 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var ast = method switch { - _ when method.IsOneOf(__skipMethods) => AstExpression.Slice(sourceTranslation.Ast, countAst, int.MaxValue), - _ when method.IsOneOf(__takeMethods) => AstExpression.Slice(sourceTranslation.Ast, countAst), + _ when method.IsOneOf(EnumerableOrQueryableMethod.SkipOverloads) => AstExpression.Slice(sourceTranslation.Ast, countAst, int.MaxValue), + _ when method.IsOneOf(EnumerableOrQueryableMethod.TakeOverloads) => AstExpression.Slice(sourceTranslation.Ast, countAst), _ => throw new ExpressionNotSupportedException(expression) }; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipWhileOrTakeWhileMethodToAggreggationExpressoinTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipWhileOrTakeWhileMethodToAggreggationExpressoinTranslator.cs index d9648b7dba5..221b5262c6c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipWhileOrTakeWhileMethodToAggreggationExpressoinTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SkipWhileOrTakeWhileMethodToAggreggationExpressoinTranslator.cs @@ -27,32 +27,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SkipWhileOrTakeWhileMethodToAggregationExpressionTranslator { - private static MethodInfo[] __skipWhileOrTakeWhileMethods = - { - EnumerableMethod.SkipWhile, - EnumerableMethod.TakeWhile, - QueryableMethod.SkipWhile, - QueryableMethod.TakeWhile - }; - - private static MethodInfo[] __skipWhileMethods = - { - EnumerableMethod.SkipWhile, - QueryableMethod.SkipWhile - }; - - private static MethodInfo[] __takeWhileMethods = - { - EnumerableMethod.TakeWhile, - QueryableMethod.TakeWhile - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__skipWhileOrTakeWhileMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.SkipWhileOrTakeWhile)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -88,8 +68,8 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var sliceAst = method switch { - _ when method.IsOneOf(__skipWhileMethods) => AstExpression.Slice(sourceAst, whileCountField, int.MaxValue), - _ when method.IsOneOf(__takeWhileMethods) => AstExpression.Slice(sourceAst, whileCountField), + _ when method.IsOneOf(EnumerableOrQueryableMethod.SkipWhile) => AstExpression.Slice(sourceAst, whileCountField, int.MaxValue), + _ when method.IsOneOf(EnumerableOrQueryableMethod.TakeWhile) => AstExpression.Slice(sourceAst, whileCountField), _ => throw new ExpressionNotSupportedException(expression) }; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SplitMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SplitMethodToAggregationExpressionTranslator.cs index 4ae9b69abb2..326518f20b4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SplitMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SplitMethodToAggregationExpressionTranslator.cs @@ -26,57 +26,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SplitMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __splitMethods = new[] - { - StringMethod.SplitWithChars, - StringMethod.SplitWithCharsAndCount, - StringMethod.SplitWithCharsAndCountAndOptions, - StringMethod.SplitWithCharsAndOptions, - StringMethod.SplitWithStringsAndCountAndOptions, - StringMethod.SplitWithStringsAndOptions - }; - - private static readonly MethodInfo[] __splitWithCharsMethods = new[] - { - StringMethod.SplitWithChars, - StringMethod.SplitWithCharsAndCount, - StringMethod.SplitWithCharsAndCountAndOptions, - StringMethod.SplitWithCharsAndOptions - }; - - private static readonly MethodInfo[] __splitWithCountMethods = new[] - { - StringMethod.SplitWithCharsAndCount, - StringMethod.SplitWithCharsAndCountAndOptions, - StringMethod.SplitWithStringsAndCountAndOptions, - }; - - private static readonly MethodInfo[] __splitWithOptionsMethods = new[] - { - StringMethod.SplitWithCharsAndCountAndOptions, - StringMethod.SplitWithCharsAndOptions, - StringMethod.SplitWithStringsAndCountAndOptions, - StringMethod.SplitWithStringsAndOptions - }; - - private static readonly MethodInfo[] __splitWithStringsMethods = new[] - { - StringMethod.SplitWithStringsAndCountAndOptions, - StringMethod.SplitWithStringsAndOptions - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__splitMethods)) + if (method.IsOneOf(StringMethod.SplitOverloads)) { var stringExpression = expression.Object; var stringTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, stringExpression); string delimiter; - if (method.IsOneOf(__splitWithCharsMethods)) + if (method.IsOneOf(StringMethod.SplitWithCharsOverloads)) { var separatorsExpression = arguments[0]; var separatorChars = separatorsExpression.GetConstantValue(containingExpression: expression); @@ -86,7 +47,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC } delimiter = new string(separatorChars[0], 1); } - else if (method.IsOneOf(__splitWithStringsMethods)) + else if (method.IsOneOf(StringMethod.SplitWithStringsOverloads)) { var separatorsExpression = arguments[0]; var separatorStrings = separatorsExpression.GetConstantValue(containingExpression: expression); @@ -104,7 +65,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var ast = AstExpression.Split(stringTranslation.Ast, delimiter); var options = StringSplitOptions.None; - if (method.IsOneOf(__splitWithOptionsMethods)) + if (method.IsOneOf(StringMethod.SplitWithOptionsOverloads)) { var optionsExpression = arguments.Last(); options = optionsExpression.GetConstantValue(containingExpression: expression); @@ -117,7 +78,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @as: "item"); } - if (method.IsOneOf(__splitWithCountMethods)) + if (method.IsOneOf(StringMethod.SplitWithCountOverloads)) { var countExpression = arguments[1]; var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.cs index c423feb37ff..0a3237c9b69 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator.cs @@ -31,14 +31,14 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __startsWithContainsOrEndsWithMethods; - private static readonly MethodInfo[] __withComparisonTypeMethods; - private static readonly MethodInfo[] __withIgnoreCaseAndCultureMethods; + private static readonly IReadOnlyMethodInfoSet __translatableOverloads; + private static readonly IReadOnlyMethodInfoSet __withComparisonTypeOverloads; + private static readonly IReadOnlyMethodInfoSet __withIgnoreCaseAndCultureOverloads; static StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator() { - __startsWithContainsOrEndsWithMethods = new[] - { + __translatableOverloads = MethodInfoSet.Create( + [ StringMethod.StartsWithWithChar, StringMethod.StartsWithWithString, StringMethod.StartsWithWithStringAndComparisonType, @@ -51,28 +51,28 @@ static StartsWithContainsOrEndsWithMethodToAggregationExpressionTranslator() StringMethod.EndsWithWithString, StringMethod.EndsWithWithStringAndComparisonType, StringMethod.EndsWithWithStringAndIgnoreCaseAndCulture - }; + ]); - __withComparisonTypeMethods = new[] - { + __withComparisonTypeOverloads = MethodInfoSet.Create( + [ StringMethod.StartsWithWithStringAndComparisonType, StringMethod.ContainsWithCharAndComparisonType, StringMethod.ContainsWithStringAndComparisonType, StringMethod.EndsWithWithStringAndComparisonType - }; + ]); - __withIgnoreCaseAndCultureMethods = new[] - { + __withIgnoreCaseAndCultureOverloads = MethodInfoSet.Create( + [ StringMethod.StartsWithWithStringAndIgnoreCaseAndCulture, StringMethod.EndsWithWithStringAndIgnoreCaseAndCulture - }; + ]); } public static bool CanTranslate(MethodCallExpression expression) { var method = expression.Method; - if (method.IsOneOf(__startsWithContainsOrEndsWithMethods)) + if (method.IsOneOf(__translatableOverloads)) { return true; } @@ -216,7 +216,7 @@ bool GetIgnoreCaseFromIgnoreCaseAndCulture(Expression ignoreCaseExpression, Expr bool IsWithComparisonTypeMethod(MethodInfo method) { - if (method.IsOneOf(__withComparisonTypeMethods)) + if (method.IsOneOf(__withComparisonTypeOverloads)) { return true; } @@ -226,7 +226,7 @@ bool IsWithComparisonTypeMethod(MethodInfo method) bool IsWithIgnoreCaseAndCultureMethod(MethodInfo method) { - if (method.IsOneOf(__withIgnoreCaseAndCultureMethods)) + if (method.IsOneOf(__withIgnoreCaseAndCultureOverloads)) { return true; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StringConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StringConcatMethodToAggregationExpressionTranslator.cs index e6912d421ad..5796e16dba3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StringConcatMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/StringConcatMethodToAggregationExpressionTranslator.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; @@ -28,18 +29,6 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class StringConcatMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __stringConcatMethods = new[] - { - StringMethod.ConcatWith1Object, - StringMethod.ConcatWith2Objects, - StringMethod.ConcatWith3Objects, - StringMethod.ConcatWithObjectArray, - StringMethod.ConcatWith2Strings, - StringMethod.ConcatWith3Strings, - StringMethod.ConcatWith4Strings, - StringMethod.ConcatWithStringArray - }; - public static bool CanTranslate(BinaryExpression expression, out MethodInfo method, out ReadOnlyCollection arguments) { if (expression.NodeType == ExpressionType.Add && @@ -58,7 +47,7 @@ public static bool CanTranslate(BinaryExpression expression, out MethodInfo meth public static bool CanTranslate(MethodCallExpression expression, out MethodInfo method, out ReadOnlyCollection arguments) { - if (expression.Method.IsOneOf(__stringConcatMethods)) + if (expression.Method.IsOneOf(StringMethod.ConcatOverloads)) { method = expression.Method; arguments = expression.Arguments; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SumMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SumMethodToAggregationExpressionTranslator.cs index 2f1fe72a073..5d47337de13 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SumMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SumMethodToAggregationExpressionTranslator.cs @@ -24,56 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class SumMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __sumMethods = - { - EnumerableMethod.SumDecimal, - EnumerableMethod.SumDecimalWithSelector, - EnumerableMethod.SumDouble, - EnumerableMethod.SumDoubleWithSelector, - EnumerableMethod.SumInt32, - EnumerableMethod.SumInt32WithSelector, - EnumerableMethod.SumInt64, - EnumerableMethod.SumInt64WithSelector, - EnumerableMethod.SumNullableDecimal, - EnumerableMethod.SumNullableDecimalWithSelector, - EnumerableMethod.SumNullableDouble, - EnumerableMethod.SumNullableDoubleWithSelector, - EnumerableMethod.SumNullableInt32, - EnumerableMethod.SumNullableInt32WithSelector, - EnumerableMethod.SumNullableInt64, - EnumerableMethod.SumNullableInt64WithSelector, - EnumerableMethod.SumNullableSingle, - EnumerableMethod.SumNullableSingleWithSelector, - EnumerableMethod.SumSingle, - EnumerableMethod.SumSingleWithSelector, - QueryableMethod.SumDecimal, - QueryableMethod.SumDecimalWithSelector, - QueryableMethod.SumDouble, - QueryableMethod.SumDoubleWithSelector, - QueryableMethod.SumInt32, - QueryableMethod.SumInt32WithSelector, - QueryableMethod.SumInt64, - QueryableMethod.SumInt64WithSelector, - QueryableMethod.SumNullableDecimal, - QueryableMethod.SumNullableDecimalWithSelector, - QueryableMethod.SumNullableDouble, - QueryableMethod.SumNullableDoubleWithSelector, - QueryableMethod.SumNullableInt32, - QueryableMethod.SumNullableInt32WithSelector, - QueryableMethod.SumNullableInt64, - QueryableMethod.SumNullableInt64WithSelector, - QueryableMethod.SumNullableSingle, - QueryableMethod.SumNullableSingleWithSelector, - QueryableMethod.SumSingle, - QueryableMethod.SumSingleWithSelector - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__sumMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.SumOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs index f95a2361fdc..263bd9ac6a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToListMethodToAggregationExpressionTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -37,10 +38,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var listItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var listItemType = listItemSerializer.ValueType; - var listType = typeof(List<>).MakeGenericType(listItemType); - var listSerializerType = typeof(EnumerableInterfaceImplementerSerializer<,>).MakeGenericType(listType, listItemType); - var listSerializer = (IBsonSerializer)Activator.CreateInstance(listSerializerType, listItemSerializer); + var listSerializer = ListSerializer.Create(listItemSerializer); return new TranslatedExpression(expression, sourceTranslation.Ast, listSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToStringMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToStringMethodToAggregationExpressionTranslator.cs index 79a62e7f055..6516c86016f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToStringMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ToStringMethodToAggregationExpressionTranslator.cs @@ -25,23 +25,23 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ToStringMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __dateTimeToStringMethods = new[] - { + private static readonly IReadOnlyMethodInfoSet __dateTimeToStringOverloads = MethodInfoSet.Create( + [ DateTimeMethod.ToStringWithFormat, DateTimeMethod.ToStringWithFormatAndTimezone, NullableDateTimeMethod.ToStringWithFormatAndTimezoneAndOnNull, - }; + ]); - private static readonly MethodInfo[] __dateTimeToStringMethodsWithTimezone = new[] - { + private static readonly IReadOnlyMethodInfoSet __dateTimeToStringWithTimezoneOverloads = MethodInfoSet.Create( + [ DateTimeMethod.ToStringWithFormatAndTimezone, NullableDateTimeMethod.ToStringWithFormatAndTimezoneAndOnNull, - }; + ]); - private static readonly MethodInfo[] __dateTimeToStringMethodsWithOnNull = new[] - { + private static readonly IReadOnlyMethodInfoSet __dateTimeToStringWithOnNullOverloads = MethodInfoSet.Create( + [ NullableDateTimeMethod.ToStringWithFormatAndTimezoneAndOnNull, - }; + ]); public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { @@ -53,7 +53,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return TranslateInstanceToStringMethodWithNoArguments(context, expression); } - if (method.IsOneOf(__dateTimeToStringMethods)) + if (method.IsOneOf(__dateTimeToStringOverloads)) { return TranslateDateTimeToStringMethod(context, expression, method, arguments); } @@ -86,7 +86,7 @@ private static TranslatedExpression TranslateDateTimeToStringMethod(TranslationC } AstExpression timezoneAst = null; - if (method.IsOneOf(__dateTimeToStringMethodsWithTimezone)) + if (method.IsOneOf(__dateTimeToStringWithTimezoneOverloads)) { var timezoneExpression = arguments[2]; if (!(timezoneExpression is ConstantExpression constantExpression) || constantExpression.Value != null) @@ -97,7 +97,7 @@ private static TranslatedExpression TranslateDateTimeToStringMethod(TranslationC } AstExpression onNullAst = null; - if (method.IsOneOf(__dateTimeToStringMethodsWithOnNull)) + if (method.IsOneOf(__dateTimeToStringWithOnNullOverloads)) { var onNullExpression = arguments[3]; var onNullTranslataion = ExpressionToAggregationExpressionTranslator.Translate(context, onNullExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrigMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrigMethodToAggregationExpressionTranslator.cs index 9c6b85dc21a..0df7cb2872d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrigMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrigMethodToAggregationExpressionTranslator.cs @@ -26,18 +26,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class TrigMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __binaryTrigMethods; - private static readonly MethodInfo[] __unaryTrigMethods; + private static readonly IReadOnlyMethodInfoSet __binaryTrigMethods; + private static readonly IReadOnlyMethodInfoSet __unaryTrigMethods; static TrigMethodToAggregationExpressionTranslator() { - __binaryTrigMethods = new[] - { + __binaryTrigMethods = MethodInfoSet.Create( + [ MathMethod.Atan2 - }; + ]); - __unaryTrigMethods = new[] - { + __unaryTrigMethods = MethodInfoSet.Create( + [ MathMethod.Acos, MathMethod.Acosh, MathMethod.Asin, @@ -52,7 +52,7 @@ static TrigMethodToAggregationExpressionTranslator() MathMethod.Tanh, MongoDBMathMethod.DegreesToRadians, MongoDBMathMethod.RadiansToDegrees - }; + ]); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrimMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrimMethodToAggregationExpressionTranslator.cs index 351dd98b5ca..9faee98d1bb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrimMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TrimMethodToAggregationExpressionTranslator.cs @@ -25,25 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class TrimMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __trimMethods; - - static TrimMethodToAggregationExpressionTranslator() - { - __trimMethods = new[] - { - StringMethod.Trim, - StringMethod.TrimEnd, - StringMethod.TrimStart, - StringMethod.TrimWithChars - }; - } - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__trimMethods)) + if (method.IsOneOf(StringMethod.TrimOverloads)) { var objectExpression = expression.Object; var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/UnionMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/UnionMethodToAggregationExpressionTranslator.cs index 9773217232c..5d2093ea3a9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/UnionMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/UnionMethodToAggregationExpressionTranslator.cs @@ -24,18 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class UnionMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __unionMethods = - { - EnumerableMethod.Union, - QueryableMethod.Union - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__unionMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Union)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs index 250c8658210..e7329d75e68 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs @@ -24,19 +24,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class WhereMethodToAggregationExpressionTranslator { - private static MethodInfo[] __whereMethods = - { + private static readonly IReadOnlyMethodInfoSet __whereOverloads = MethodInfoSet.Create( + [ EnumerableMethod.Where, MongoEnumerableMethod.WhereWithLimit, QueryableMethod.Where - }; + ]); public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__whereMethods)) + if (method.IsOneOf(__whereOverloads)) { var sourceExpression = arguments[0]; var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); @@ -70,7 +70,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC @as: predicateSymbol.Var.Name, limitTranslation?.Ast); - var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); + var resultSerializer = context.NodeSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, resultSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs index f45cffc3e49..3fafbb25169 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WindowMethodToAggregationExpressionTranslator.cs @@ -29,138 +29,16 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class WindowMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __windowMethods = - { - WindowMethod.AddToSet, - WindowMethod.AverageWithDecimal, - WindowMethod.AverageWithDouble, - WindowMethod.AverageWithInt32, - WindowMethod.AverageWithInt64, - WindowMethod.AverageWithNullableDecimal, - WindowMethod.AverageWithNullableDouble, - WindowMethod.AverageWithNullableInt32, - WindowMethod.AverageWithNullableInt64, - WindowMethod.AverageWithNullableSingle, - WindowMethod.AverageWithSingle, - WindowMethod.Count, - WindowMethod.CovariancePopulationWithDecimals, - WindowMethod.CovariancePopulationWithDoubles, - WindowMethod.CovariancePopulationWithInt32s, - WindowMethod.CovariancePopulationWithInt64s, - WindowMethod.CovariancePopulationWithNullableDecimals, - WindowMethod.CovariancePopulationWithNullableDoubles, - WindowMethod.CovariancePopulationWithNullableInt32s, - WindowMethod.CovariancePopulationWithNullableInt64s, - WindowMethod.CovariancePopulationWithNullableSingles, - WindowMethod.CovariancePopulationWithSingles, - WindowMethod.CovarianceSampleWithDecimals, - WindowMethod.CovarianceSampleWithDoubles, - WindowMethod.CovarianceSampleWithInt32s, - WindowMethod.CovarianceSampleWithInt64s, - WindowMethod.CovarianceSampleWithNullableDecimals, - WindowMethod.CovarianceSampleWithNullableDoubles, - WindowMethod.CovarianceSampleWithNullableInt32s, - WindowMethod.CovarianceSampleWithNullableInt64s, - WindowMethod.CovarianceSampleWithNullableSingles, - WindowMethod.CovarianceSampleWithSingles, - WindowMethod.DenseRank, - WindowMethod.DerivativeWithDecimal, - WindowMethod.DerivativeWithDecimalAndUnit, - WindowMethod.DerivativeWithDouble, - WindowMethod.DerivativeWithDoubleAndUnit, - WindowMethod.DerivativeWithInt32, - WindowMethod.DerivativeWithInt32AndUnit, - WindowMethod.DerivativeWithInt64, - WindowMethod.DerivativeWithInt64AndUnit, - WindowMethod.DerivativeWithSingle, - WindowMethod.DerivativeWithSingleAndUnit, - WindowMethod.DocumentNumber, - WindowMethod.ExponentialMovingAverageWithDecimal, - WindowMethod.ExponentialMovingAverageWithDouble, - WindowMethod.ExponentialMovingAverageWithInt32, - WindowMethod.ExponentialMovingAverageWithInt64, - WindowMethod.ExponentialMovingAverageWithSingle, - WindowMethod.First, - WindowMethod.IntegralWithDecimal, - WindowMethod.IntegralWithDecimalAndUnit, - WindowMethod.IntegralWithDouble, - WindowMethod.IntegralWithDoubleAndUnit, - WindowMethod.IntegralWithInt32, - WindowMethod.IntegralWithInt32AndUnit, - WindowMethod.IntegralWithInt64, - WindowMethod.IntegralWithInt64AndUnit, - WindowMethod.IntegralWithSingle, - WindowMethod.IntegralWithSingleAndUnit, - WindowMethod.Last, - WindowMethod.Locf, - WindowMethod.Max, - WindowMethod.MedianWithDecimal, - WindowMethod.MedianWithDouble, - WindowMethod.MedianWithInt32, - WindowMethod.MedianWithInt64, - WindowMethod.MedianWithNullableDecimal, - WindowMethod.MedianWithNullableDouble, - WindowMethod.MedianWithNullableInt32, - WindowMethod.MedianWithNullableInt64, - WindowMethod.MedianWithNullableSingle, - WindowMethod.MedianWithSingle, - WindowMethod.Min, - WindowMethod.PercentileWithDecimal, - WindowMethod.PercentileWithDouble, - WindowMethod.PercentileWithInt32, - WindowMethod.PercentileWithInt64, - WindowMethod.PercentileWithNullableDecimal, - WindowMethod.PercentileWithNullableDouble, - WindowMethod.PercentileWithNullableInt32, - WindowMethod.PercentileWithNullableInt64, - WindowMethod.PercentileWithNullableSingle, - WindowMethod.PercentileWithSingle, - WindowMethod.Push, - WindowMethod.Rank, - WindowMethod.Shift, - WindowMethod.ShiftWithDefaultValue, - WindowMethod.StandardDeviationPopulationWithDecimal, - WindowMethod.StandardDeviationPopulationWithDouble, - WindowMethod.StandardDeviationPopulationWithInt32, - WindowMethod.StandardDeviationPopulationWithInt64, - WindowMethod.StandardDeviationPopulationWithNullableDecimal, - WindowMethod.StandardDeviationPopulationWithNullableDouble, - WindowMethod.StandardDeviationPopulationWithNullableInt32, - WindowMethod.StandardDeviationPopulationWithNullableInt64, - WindowMethod.StandardDeviationPopulationWithNullableSingle, - WindowMethod.StandardDeviationPopulationWithSingle, - WindowMethod.StandardDeviationSampleWithDecimal, - WindowMethod.StandardDeviationSampleWithDouble, - WindowMethod.StandardDeviationSampleWithInt32, - WindowMethod.StandardDeviationSampleWithInt64, - WindowMethod.StandardDeviationSampleWithNullableDecimal, - WindowMethod.StandardDeviationSampleWithNullableDouble, - WindowMethod.StandardDeviationSampleWithNullableInt32, - WindowMethod.StandardDeviationSampleWithNullableInt64, - WindowMethod.StandardDeviationSampleWithNullableSingle, - WindowMethod.StandardDeviationSampleWithSingle, - WindowMethod.SumWithDecimal, - WindowMethod.SumWithDouble, - WindowMethod.SumWithInt32, - WindowMethod.SumWithInt64, - WindowMethod.SumWithNullableDecimal, - WindowMethod.SumWithNullableDouble, - WindowMethod.SumWithNullableInt32, - WindowMethod.SumWithNullableInt64, - WindowMethod.SumWithNullableSingle, - WindowMethod.SumWithSingle - }; - - private static readonly MethodInfo[] __nullaryMethods = - { + private static readonly IReadOnlyMethodInfoSet __nullaryOverloads = MethodInfoSet.Create( + [ WindowMethod.Count, WindowMethod.DenseRank, WindowMethod.DocumentNumber, WindowMethod.Rank - }; + ]); - private static readonly MethodInfo[] __unaryMethods = - { + private static readonly IReadOnlyMethodInfoSet __unaryOverloads = MethodInfoSet.Create( + [ WindowMethod.AddToSet, WindowMethod.AverageWithDecimal, WindowMethod.AverageWithDouble, @@ -208,10 +86,10 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.SumWithNullableInt64, WindowMethod.SumWithNullableSingle, WindowMethod.SumWithSingle - }; + ]); - private static readonly MethodInfo[] __binaryMethods = - { + private static readonly IReadOnlyMethodInfoSet __binaryOverloads = MethodInfoSet.Create( + [ WindowMethod.CovariancePopulationWithDecimals, WindowMethod.CovariancePopulationWithDoubles, WindowMethod.CovariancePopulationWithInt32s, @@ -232,10 +110,10 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.CovarianceSampleWithNullableInt64s, WindowMethod.CovarianceSampleWithNullableSingles, WindowMethod.CovarianceSampleWithSingles - }; + ]); - private static readonly MethodInfo[] __derivativeOrIntegralMethods = - { + private static readonly IReadOnlyMethodInfoSet __derivativeOrIntegralOverloads = MethodInfoSet.Create( + [ WindowMethod.DerivativeWithDecimal, WindowMethod.DerivativeWithDecimalAndUnit, WindowMethod.DerivativeWithDouble, @@ -256,24 +134,24 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.IntegralWithInt64AndUnit, WindowMethod.IntegralWithSingle, WindowMethod.IntegralWithSingleAndUnit - }; + ]); - private static readonly MethodInfo[] __exponentialMovingAverageMethods = - { + private static readonly IReadOnlyMethodInfoSet __exponentialMovingAverageOverloads = MethodInfoSet.Create( + [ WindowMethod.ExponentialMovingAverageWithDecimal, WindowMethod.ExponentialMovingAverageWithDouble, WindowMethod.ExponentialMovingAverageWithInt32, WindowMethod.ExponentialMovingAverageWithInt64, WindowMethod.ExponentialMovingAverageWithSingle - }; + ]); - private static readonly MethodInfo[] __shiftMethods = - { + private static readonly IReadOnlyMethodInfoSet __shiftOverloads = MethodInfoSet.Create( + [ WindowMethod.Shift, WindowMethod.ShiftWithDefaultValue - }; + ]); - private static readonly MethodInfo[] __quantileMethods = + private static readonly IReadOnlyMethodInfoSet __quantileOverloads = MethodInfoSet.Create( [ WindowMethod.MedianWithDecimal, WindowMethod.MedianWithDouble, @@ -295,11 +173,11 @@ internal static class WindowMethodToAggregationExpressionTranslator WindowMethod.PercentileWithNullableInt64, WindowMethod.PercentileWithNullableSingle, WindowMethod.PercentileWithSingle - ]; + ]); public static bool CanTranslate(MethodCallExpression expression) { - return expression.Method.IsOneOf(__windowMethods); + return IsWindowMethod(expression.Method); } public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) @@ -308,7 +186,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC var parameters = method.GetParameters(); var arguments = expression.Arguments.ToArray(); - if (method.IsOneOf(__windowMethods)) + if (IsWindowMethod(method)) { var partitionExpression = arguments[0]; var partitionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, partitionExpression); @@ -321,7 +199,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC window = TranslateWindow(context, expression, windowExpression, inputSerializer); } - if (method.IsOneOf(__nullaryMethods)) + if (method.IsOneOf(__nullaryOverloads)) { var @operator = GetNullaryWindowOperator(method); var ast = AstExpression.NullaryWindowExpression(@operator, window); @@ -335,7 +213,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC selectorTranslation = TranslateSelector(context, selectorLambda, inputSerializer); } - if (method.IsOneOf(__unaryMethods)) + if (method.IsOneOf(__unaryOverloads)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); var @operator = GetUnaryWindowOperator(method); @@ -344,7 +222,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (method.IsOneOf(__binaryMethods)) + if (method.IsOneOf(__binaryOverloads)) { var selector1Lambda = GetArgument(parameters, "selector1", arguments); var selector2Lambda = GetArgument(parameters, "selector2", arguments); @@ -357,7 +235,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (method.IsOneOf(__derivativeOrIntegralMethods)) + if (method.IsOneOf(__derivativeOrIntegralOverloads)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); WindowTimeUnit? unit = default; @@ -372,7 +250,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (method.IsOneOf(__exponentialMovingAverageMethods)) + if (method.IsOneOf(__exponentialMovingAverageOverloads)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); var weightingExpression = arguments[2]; @@ -383,7 +261,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (method.IsOneOf(__quantileMethods)) + if (method.IsOneOf(__quantileOverloads)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); AstExpression ast; @@ -404,7 +282,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC return new TranslatedExpression(expression, ast, serializer); } - if (method.IsOneOf(__shiftMethods)) + if (method.IsOneOf(__shiftOverloads)) { ThrowIfSelectorTranslationIsNull(selectorTranslation); var byExpression = arguments[2]; @@ -623,5 +501,10 @@ private static IBsonSerializer GetSortBySerializerGeneric( return renderedField.FieldSerializer; } + + private static bool IsWindowMethod(MethodInfo method) + { + return method.DeclaringType == typeof(ISetWindowFieldsPartitionExtensions); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ZipMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ZipMethodToAggregationExpressionTranslator.cs index 259af4e59c6..a275e40e0c4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ZipMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ZipMethodToAggregationExpressionTranslator.cs @@ -25,18 +25,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ZipMethodToAggregationExpressionTranslator { - private static readonly MethodInfo[] __zipMethods = - { - EnumerableMethod.Zip, - QueryableMethod.Zip - }; - public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__zipMethods)) + if (method.IsOneOf(EnumerableOrQueryableMethod.Zip)) { var firstExpression = arguments[0]; var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c5eba340536..fedb16e0893 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -19,6 +19,8 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -27,28 +29,14 @@ internal static class NewArrayInitExpressionToAggregationExpressionTranslator public static TranslatedExpression Translate(TranslationContext context, NewArrayExpression expression) { var items = new List(); - IBsonSerializer itemSerializer = null; foreach (var itemExpression in expression.Expressions) { var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; - - // make sure all items are serialized using the same serializer - if (!itemTranslation.Serializer.Equals(itemSerializer)) - { - throw new ExpressionNotSupportedException(expression, because: "all items in the array must be serialized using the same serializer"); - } } - var ast = AstExpression.ComputedArray(items); - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); - + var arraySerializer = context.NodeSerializers.GetSerializer(expression); return new TranslatedExpression(expression, ast, arraySerializer); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs index aee174ac38d..af7b324c2f3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -39,34 +39,21 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer); - IBsonSerializer keySerializer; - IBsonSerializer valueSerializer; AstExpression collectionTranslationAst; - if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer) + if (itemSerializer.IsKeyValuePairSerializer(out var keyElementName, out var valueElementName, out var keySerializer, out var valueSerializer)) { - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member"); - } - keySerializer = keyMemberSerializationInfo.Serializer; - - if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) - { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member"); - } - valueSerializer = valueMemberSerializationInfo.Serializer; - - if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + if (keyElementName == "k" && valueElementName == "v") { collectionTranslationAst = collectionTranslation.Ast; } else { + // map keyElementName and valueElementName to "k" and "v" var pairVar = AstExpression.Var("pair"); var computedDocumentAst = AstExpression.ComputedDocument([ - AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), - AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueElementName)) ]); collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs index cfe4f67f6a8..fbc84096074 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewKeyValuePairExpressionToAggregationExpressionTranslator.cs @@ -13,11 +13,9 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; -using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs index 3063460c00b..68acaa92950 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewListExpressionToAggregationExpressionTranslator.cs @@ -34,7 +34,7 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr { var argument = arguments[0]; var argumentType = argument.Type; - if (argumentType.IsConstructedGenericType && argumentType.GetGenericTypeDefinition().Implements(typeof(IEnumerable<>))) + if (argumentType.IsConstructedGenericType && argumentType.GetGenericTypeDefinition().ImplementsInterface(typeof(IEnumerable<>))) { var enumerableInterface = argumentType.GetIEnumerableGenericInterface(); var argumentItemType = enumerableInterface.GetGenericArguments()[0]; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs index 66235aa71ca..a94d7fc3817 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewTupleExpressionToAggregationExpressionTranslator.cs @@ -13,13 +13,11 @@ * limitations under the License. */ -using System; -using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson.Serialization; -using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -49,16 +47,11 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr } var ast = AstExpression.ComputedArray(items); - var tupleSerializer = CreateTupleSerializer(tupleType, itemSerializers); + var tupleSerializer = TupleOrValueTupleSerializer.Create(tupleType, itemSerializers); return new TranslatedExpression(expression, ast, tupleSerializer); } throw new ExpressionNotSupportedException(expression); } - - private static IBsonSerializer CreateTupleSerializer(Type tupleType, IEnumerable itemSerializers) - { - return tupleType.IsTuple() ? TupleSerializer.Create(itemSerializers) : ValueTupleSerializer.Create(itemSerializers); - } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs index 692b3600ddd..486ae382721 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NotExpressionToAggregationExpressionTranslator.cs @@ -24,6 +24,7 @@ public static TranslatedExpression Translate(TranslationContext context, UnaryEx { if (expression.NodeType == ExpressionType.Not) { + // TODO: check operand representation var operandTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, expression.Operand); var ast = expression.Type == typeof(bool) ? AstExpression.Not(operandTranslation.Ast) : AstExpression.BitNot(operandTranslation.Ast); return new TranslatedExpression(expression, ast, operandTranslation.Serializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AnyMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AnyMethodToExecutableQueryTranslator.cs index 09262bfc20f..3c751b2210d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AnyMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AnyMethodToExecutableQueryTranslator.cs @@ -34,27 +34,27 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecut internal static class AnyMethodToExecutableQueryTranslator { // private static fields - private static readonly MethodInfo[] __anyMethods; - private static readonly MethodInfo[] __anyWithPredicateMethods; + private static readonly IReadOnlyMethodInfoSet __anyOverloads; + private static readonly IReadOnlyMethodInfoSet __anyWithPredicateOverloads; private static readonly IExecutableQueryFinalizer __finalizer = new AnyFinalizer(); private static readonly IBsonSerializer __outputSerializer = new WrappedValueSerializer("_v", BsonNullSerializer.Instance); // static constructors static AnyMethodToExecutableQueryTranslator() { - __anyMethods = new[] - { + __anyOverloads = MethodInfoSet.Create( + [ QueryableMethod.Any, QueryableMethod.AnyWithPredicate, MongoQueryableMethod.AnyAsync, MongoQueryableMethod.AnyWithPredicateAsync - }; + ]); - __anyWithPredicateMethods = new[] - { + __anyWithPredicateOverloads = MethodInfoSet.Create( + [ QueryableMethod.AnyWithPredicate, MongoQueryableMethod.AnyWithPredicateAsync - }; + ]); } // public static methods @@ -63,12 +63,12 @@ public static ExecutableQuery Translate(MongoQueryPr var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__anyMethods)) + if (method.IsOneOf(__anyOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__anyWithPredicateMethods)) + if (method.IsOneOf(__anyWithPredicateOverloads)) { ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method, "with a predicate"); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs index fbcf6542d88..32b582edab2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs @@ -34,81 +34,25 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecut internal static class AverageMethodToExecutableQueryTranslator { // private static fields - private static readonly MethodInfo[] __averageMethods; - private static readonly MethodInfo[] __averageWithSelectorMethods; + private static readonly IReadOnlyMethodInfoSet __averageOverloads; + private static readonly IReadOnlyMethodInfoSet __averageWithSelectorOverloads; private static readonly IExecutableQueryFinalizer __singleFinalizer = new SingleFinalizer(); private static readonly IExecutableQueryFinalizer __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer(); // static constructor static AverageMethodToExecutableQueryTranslator() { - __averageMethods = new[] - { - QueryableMethod.AverageDecimal, - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDouble, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimal, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDouble, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingle, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingle, - QueryableMethod.AverageSingleWithSelector, - MongoQueryableMethod.AverageDecimalAsync, - MongoQueryableMethod.AverageDecimalWithSelectorAsync, - MongoQueryableMethod.AverageDoubleAsync, - MongoQueryableMethod.AverageDoubleWithSelectorAsync, - MongoQueryableMethod.AverageInt32Async, - MongoQueryableMethod.AverageInt32WithSelectorAsync, - MongoQueryableMethod.AverageInt64Async, - MongoQueryableMethod.AverageInt64WithSelectorAsync, - MongoQueryableMethod.AverageNullableDecimalAsync, - MongoQueryableMethod.AverageNullableDecimalWithSelectorAsync, - MongoQueryableMethod.AverageNullableDoubleAsync, - MongoQueryableMethod.AverageNullableDoubleWithSelectorAsync, - MongoQueryableMethod.AverageNullableInt32Async, - MongoQueryableMethod.AverageNullableInt32WithSelectorAsync, - MongoQueryableMethod.AverageNullableInt64Async, - MongoQueryableMethod.AverageNullableInt64WithSelectorAsync, - MongoQueryableMethod.AverageNullableSingleAsync, - MongoQueryableMethod.AverageNullableSingleWithSelectorAsync, - MongoQueryableMethod.AverageSingleAsync, - MongoQueryableMethod.AverageSingleWithSelectorAsync - }; + __averageOverloads = MethodInfoSet.Create( + [ + QueryableMethod.AverageOverloads, + MongoQueryableMethod.AverageOverloads + ]); - __averageWithSelectorMethods = new[] - { - QueryableMethod.AverageDecimalWithSelector, - QueryableMethod.AverageDoubleWithSelector, - QueryableMethod.AverageInt32WithSelector, - QueryableMethod.AverageInt64WithSelector, - QueryableMethod.AverageNullableDecimalWithSelector, - QueryableMethod.AverageNullableDoubleWithSelector, - QueryableMethod.AverageNullableInt32WithSelector, - QueryableMethod.AverageNullableInt64WithSelector, - QueryableMethod.AverageNullableSingleWithSelector, - QueryableMethod.AverageSingleWithSelector, - MongoQueryableMethod.AverageDecimalWithSelectorAsync, - MongoQueryableMethod.AverageDoubleWithSelectorAsync, - MongoQueryableMethod.AverageInt32WithSelectorAsync, - MongoQueryableMethod.AverageInt64WithSelectorAsync, - MongoQueryableMethod.AverageNullableDecimalWithSelectorAsync, - MongoQueryableMethod.AverageNullableDoubleWithSelectorAsync, - MongoQueryableMethod.AverageNullableInt32WithSelectorAsync, - MongoQueryableMethod.AverageNullableInt64WithSelectorAsync, - MongoQueryableMethod.AverageNullableSingleWithSelectorAsync, - MongoQueryableMethod.AverageSingleWithSelectorAsync - }; + __averageWithSelectorOverloads = MethodInfoSet.Create( + [ + QueryableMethod.AverageWithSelectorOverloads, + MongoQueryableMethod.AverageWithSelectorOverloads + ]); } // public static methods @@ -117,7 +61,7 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__averageMethods)) + if (method.IsOneOf(__averageOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -125,7 +69,7 @@ public static ExecutableQuery Translate(MongoQuer var sourceSerializer = pipeline.OutputSerializer; AstExpression valueExpression; - if (method.IsOneOf(__averageWithSelectorMethods)) + if (method.IsOneOf(__averageWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/CountMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/CountMethodToExecutableQueryTranslator.cs index 5fb83697387..803579f5a06 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/CountMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/CountMethodToExecutableQueryTranslator.cs @@ -30,27 +30,25 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecut internal static class CountMethodToExecutableQueryTranslator { // private static fields - private static readonly MethodInfo[] __countMethods; - private static readonly MethodInfo[] __countWithPredicateMethods; + private static readonly IReadOnlyMethodInfoSet __countOverloads; + private static readonly IReadOnlyMethodInfoSet __countWithPredicateOverloads; private static readonly IExecutableQueryFinalizer _finalizer = new SingleOrDefaultFinalizer(); private static readonly IBsonSerializer __wrappedInt32Serializer = new WrappedValueSerializer("_v", new Int32Serializer()); // static constructor static CountMethodToExecutableQueryTranslator() { - __countMethods = new[] - { - QueryableMethod.Count, - QueryableMethod.CountWithPredicate, - MongoQueryableMethod.CountAsync, - MongoQueryableMethod.CountWithPredicateAsync - }; + __countOverloads = MethodInfoSet.Create( + [ + QueryableMethod.CountOverloads, + MongoQueryableMethod.CountOverloads + ]); - __countWithPredicateMethods = new[] - { + __countWithPredicateOverloads = MethodInfoSet.Create( + [ QueryableMethod.CountWithPredicate, MongoQueryableMethod.CountWithPredicateAsync - }; + ]); } // public static methods @@ -59,12 +57,12 @@ public static ExecutableQuery Translate(MongoQueryPro var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__countMethods)) + if (method.IsOneOf(__countOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__countWithPredicateMethods)) + if (method.IsOneOf(__countWithPredicateOverloads)) { ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method, "with a predicate"); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index a6a89b7639f..00a7780eb66 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -13,6 +13,8 @@ * limitations under the License. */ +using System; +using System.Linq; using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; @@ -31,7 +33,7 @@ public static ExecutableQuery> Translate TranslateScalar { // private static fields private static readonly IExecutableQueryFinalizer __firstFinalizer = new FirstFinalizer(); - private static readonly MethodInfo[] __firstMethods; - private static readonly MethodInfo[] __firstWithPredicateMethods; + private static readonly IReadOnlyMethodInfoSet __firstOverloads; + private static readonly IReadOnlyMethodInfoSet __firstWithPredicateOverloads; private static readonly IExecutableQueryFinalizer __firstOrDefaultFinalizer = new FirstOrDefaultFinalizer(); // static constructor static FirstMethodToExecutableQueryTranslator() { - __firstMethods = new[] - { - QueryableMethod.First, - QueryableMethod.FirstOrDefault, - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.FirstWithPredicate, - MongoQueryableMethod.FirstAsync, - MongoQueryableMethod.FirstOrDefaultAsync, - MongoQueryableMethod.FirstOrDefaultWithPredicateAsync, - MongoQueryableMethod.FirstWithPredicateAsync - }; + __firstOverloads = MethodInfoSet.Create( + [ + QueryableMethod.FirstOverloads, + MongoQueryableMethod.FirstOverloads + ]); - __firstWithPredicateMethods = new[] - { - QueryableMethod.FirstOrDefaultWithPredicate, - QueryableMethod.FirstWithPredicate, - MongoQueryableMethod.FirstOrDefaultWithPredicateAsync, - MongoQueryableMethod.FirstWithPredicateAsync - }; + __firstWithPredicateOverloads = MethodInfoSet.Create( + [ + QueryableMethod.FirstWithPredicateOverloads, + MongoQueryableMethod.FirstWithPredicateOverloads + ]); } // public static methods @@ -62,12 +54,12 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__firstMethods)) + if (method.IsOneOf(__firstOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__firstWithPredicateMethods)) + if (method.IsOneOf(__firstWithPredicateOverloads)) { ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method, "with a predicate"); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LastMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LastMethodToExecutableQueryTranslator.cs index 70964e72ecf..16a06a4dd1a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LastMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LastMethodToExecutableQueryTranslator.cs @@ -29,41 +29,21 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecut internal static class LastMethodToExecutableQueryTranslator { // private static fields - private static readonly MethodInfo[] __lastMethods; - private static readonly MethodInfo[] __lastWithPredicateMethods; private static readonly IExecutableQueryFinalizer __singleFinalizer = new SingleFinalizer(); private static readonly IExecutableQueryFinalizer __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer(); - // static constructor - static LastMethodToExecutableQueryTranslator() - { - __lastMethods = new[] - { - QueryableMethod.Last, - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefault, - QueryableMethod.LastOrDefaultWithPredicate - }; - - __lastWithPredicateMethods = new[] - { - QueryableMethod.LastWithPredicate, - QueryableMethod.LastOrDefaultWithPredicate - }; - } - // public methods public static ExecutableQuery Translate(MongoQueryProvider provider, TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__lastMethods)) + if (method.IsOneOf(QueryableMethod.LastOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__lastWithPredicateMethods)) + if (method.IsOneOf(QueryableMethod.LastWithPredicateOverloads)) { ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method, "with a predicate"); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LongCountMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LongCountMethodToExecutableQueryTranslator.cs index ded473d720f..9cac7141f6c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LongCountMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/LongCountMethodToExecutableQueryTranslator.cs @@ -30,27 +30,25 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecut internal static class LongCountMethodToExecutableQueryTranslator { // private static fields - private static readonly MethodInfo[] __longCountMethods; - private static readonly MethodInfo[] __longCountWithPredicateMethods; + private static readonly IReadOnlyMethodInfoSet __longCountOverloads; + private static readonly IReadOnlyMethodInfoSet __longCountWithPredicateOverloads; private static readonly IExecutableQueryFinalizer _finalizer = new SingleOrDefaultFinalizer(); private static readonly IBsonSerializer __wrappedInt64Serializer = new WrappedValueSerializer("_v", new Int64Serializer()); // static constructor static LongCountMethodToExecutableQueryTranslator() { - __longCountMethods = new[] - { - QueryableMethod.LongCount, - QueryableMethod.LongCountWithPredicate, - MongoQueryableMethod.LongCountAsync, - MongoQueryableMethod.LongCountWithPredicateAsync - }; + __longCountOverloads = MethodInfoSet.Create( + [ + QueryableMethod.LongCountOverloads, + MongoQueryableMethod.LongCountOverloads + ]); - __longCountWithPredicateMethods = new[] - { + __longCountWithPredicateOverloads = MethodInfoSet.Create( + [ QueryableMethod.LongCountWithPredicate, MongoQueryableMethod.LongCountWithPredicateAsync - }; + ]); } // public static methods @@ -59,12 +57,12 @@ public static ExecutableQuery Translate(MongoQueryPr var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__longCountMethods)) + if (method.IsOneOf(__longCountOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__longCountWithPredicateMethods)) + if (method.IsOneOf(__longCountWithPredicateOverloads)) { var predicateLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var predicateFilter = ExpressionToFilterTranslator.TranslateLambda(context, predicateLambda, parameterSerializer: pipeline.OutputSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MaxMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MaxMethodToExecutableQueryTranslator.cs index a41162f7543..0210c4109a2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MaxMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MaxMethodToExecutableQueryTranslator.cs @@ -14,7 +14,6 @@ */ using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -32,25 +31,23 @@ internal static class MaxMethodToExecutableQueryTranslator { // private static fields private static readonly IExecutableQueryFinalizer __finalizer = new SingleFinalizer(); - private static readonly MethodInfo[] __maxMethods; - private static readonly MethodInfo[] __maxWithSelectorMethods; + private static readonly IReadOnlyMethodInfoSet __maxOverloads; + private static readonly IReadOnlyMethodInfoSet __maxWithSelectorOverloads; // static constructor static MaxMethodToExecutableQueryTranslator() { - __maxMethods = new[] - { - QueryableMethod.Max, - QueryableMethod.MaxWithSelector, - MongoQueryableMethod.MaxAsync, - MongoQueryableMethod.MaxWithSelectorAsync, - }; + __maxOverloads = MethodInfoSet.Create( + [ + QueryableMethod.MaxOverloads, + MongoQueryableMethod.MaxOverloads, + ]); - __maxWithSelectorMethods = new[] - { + __maxWithSelectorOverloads = MethodInfoSet.Create( + [ QueryableMethod.MaxWithSelector, MongoQueryableMethod.MaxWithSelectorAsync, - }; + ]); } // public static methods @@ -59,7 +56,7 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__maxMethods)) + if (method.IsOneOf(__maxOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -68,7 +65,7 @@ public static ExecutableQuery Translate(MongoQuer var sourceSerializer = pipeline.OutputSerializer; AstExpression valueAst; IBsonSerializer valueSerializer; - if (method.IsOneOf(__maxWithSelectorMethods)) + if (method.IsOneOf(__maxWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MinMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MinMethodToExecutableQueryTranslator.cs index b220be68af8..bfd66cdd573 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MinMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/MinMethodToExecutableQueryTranslator.cs @@ -14,7 +14,6 @@ */ using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -32,25 +31,23 @@ internal static class MinMethodToExecutableQueryTranslator { // private static fields private static readonly IExecutableQueryFinalizer __finalizer = new SingleFinalizer(); - private static readonly MethodInfo[] __minMethods; - private static readonly MethodInfo[] __minWithSelectorMethods; + private static readonly IReadOnlyMethodInfoSet __minOverloads; + private static readonly IReadOnlyMethodInfoSet __minWithSelectorOverloads; // static constructor static MinMethodToExecutableQueryTranslator() { - __minMethods = new[] - { - QueryableMethod.Min, - QueryableMethod.MinWithSelector, - MongoQueryableMethod.MinAsync, - MongoQueryableMethod.MinWithSelectorAsync, - }; + __minOverloads = MethodInfoSet.Create( + [ + QueryableMethod.MinOverloads, + MongoQueryableMethod.MinOverloads, + ]); - __minWithSelectorMethods = new[] - { + __minWithSelectorOverloads = MethodInfoSet.Create( + [ QueryableMethod.MinWithSelector, MongoQueryableMethod.MinWithSelectorAsync, - }; + ]); } // public static methods @@ -59,7 +56,7 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__minMethods)) + if (method.IsOneOf(__minOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -68,7 +65,7 @@ public static ExecutableQuery Translate(MongoQuer var sourceSerializer = pipeline.OutputSerializer; AstExpression valueAst; IBsonSerializer valueSerializer; - if (method.IsOneOf(__minWithSelectorMethods)) + if (method.IsOneOf(__minWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SingleMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SingleMethodToExecutableQueryTranslator.cs index eba7bfe0b07..2f081d1a1b2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SingleMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SingleMethodToExecutableQueryTranslator.cs @@ -28,41 +28,31 @@ internal static class SingleMethodToExecutableQueryTranslator { // private static fields private static readonly IExecutableQueryFinalizer __singleFinalizer = new SingleFinalizer(); - private static readonly MethodInfo[] __singleMethods; - private static readonly MethodInfo[] __singleOrDefaultMethods; - private static readonly MethodInfo[] __singleWithPredicateMethods; + private static readonly IReadOnlyMethodInfoSet __singleOverloads; + private static readonly IReadOnlyMethodInfoSet __singleOrDefaultOverloads; + private static readonly IReadOnlyMethodInfoSet __singleWithPredicateOverloads; private static readonly IExecutableQueryFinalizer __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer(); // static constructor static SingleMethodToExecutableQueryTranslator() { - __singleMethods = new[] - { - QueryableMethod.Single, - QueryableMethod.SingleOrDefault, - QueryableMethod.SingleOrDefaultWithPredicate, - QueryableMethod.SingleWithPredicate, - MongoQueryableMethod.SingleAsync, - MongoQueryableMethod.SingleOrDefaultAsync, - MongoQueryableMethod.SingleOrDefaultWithPredicateAsync, - MongoQueryableMethod.SingleWithPredicateAsync - }; + __singleOverloads = MethodInfoSet.Create( + [ + QueryableMethod.SingleOverloads, + MongoQueryableMethod.SingleOverloads + ]); - __singleWithPredicateMethods = new[] - { - QueryableMethod.SingleOrDefaultWithPredicate, - QueryableMethod.SingleWithPredicate, - MongoQueryableMethod.SingleOrDefaultWithPredicateAsync, - MongoQueryableMethod.SingleWithPredicateAsync - }; + __singleWithPredicateOverloads = MethodInfoSet.Create( + [ + QueryableMethod.SingleWithPredicateOverloads, + MongoQueryableMethod.SingleWithPredicateOverloads + ]); - __singleOrDefaultMethods = new[] - { - QueryableMethod.SingleOrDefault, - QueryableMethod.SingleOrDefaultWithPredicate, - MongoQueryableMethod.SingleOrDefaultAsync, - MongoQueryableMethod.SingleOrDefaultWithPredicateAsync - }; + __singleOrDefaultOverloads = MethodInfoSet.Create( + [ + QueryableMethod.SingleOrDefaultOverloads, + MongoQueryableMethod.SingleOrDefaultOverloads + ]); } // public static methods @@ -71,12 +61,12 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__singleMethods)) + if (method.IsOneOf(__singleOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); - if (method.IsOneOf(__singleWithPredicateMethods)) + if (method.IsOneOf(__singleWithPredicateOverloads)) { ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method, "with a predicate"); @@ -92,7 +82,7 @@ public static ExecutableQuery Translate(MongoQuer AstStage.Limit(2), pipeline.OutputSerializer); - var finalizer = method.IsOneOf(__singleOrDefaultMethods) ? __singleOrDefaultFinalizer : __singleFinalizer; + var finalizer = method.IsOneOf(__singleOrDefaultOverloads) ? __singleOrDefaultFinalizer : __singleFinalizer; return ExecutableQuery.Create( provider, diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/StandardDeviationMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/StandardDeviationMethodToExecutableQueryTranslator.cs index 94a82bbd750..c0ed7ceb116 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/StandardDeviationMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/StandardDeviationMethodToExecutableQueryTranslator.cs @@ -14,7 +14,6 @@ */ using System.Linq.Expressions; -using System.Reflection; using MongoDB.Bson; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -33,230 +32,6 @@ internal static class StandardDeviationMethodToExecutableQueryTranslator __singleFinalizer = new SingleFinalizer(); private static IExecutableQueryFinalizer __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer(); - private static readonly MethodInfo[] __standardDeviationMethods; - private static readonly MethodInfo[] __standardDeviationNullableMethods; - private static readonly MethodInfo[] __standardDeviationPopulationMethods; - private static readonly MethodInfo[] __standardDeviationWithSelectorMethods; - - // static constructor - static StandardDeviationMethodToExecutableQueryTranslator() - { - __standardDeviationMethods = new[] - { - MongoQueryableMethod.StandardDeviationPopulationDecimal, - MongoQueryableMethod.StandardDeviationPopulationDecimalAsync, - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationDouble, - MongoQueryableMethod.StandardDeviationPopulationDoubleAsync, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt32, - MongoQueryableMethod.StandardDeviationPopulationInt32Async, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt64, - MongoQueryableMethod.StandardDeviationPopulationInt64Async, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimal, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDouble, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingle, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationSingle, - MongoQueryableMethod.StandardDeviationPopulationSingleAsync, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleDecimal, - MongoQueryableMethod.StandardDeviationSampleDecimalAsync, - MongoQueryableMethod.StandardDeviationSampleDecimalWithSelector, - MongoQueryableMethod.StandardDeviationSampleDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleDouble, - MongoQueryableMethod.StandardDeviationSampleDoubleAsync, - MongoQueryableMethod.StandardDeviationSampleDoubleWithSelector, - MongoQueryableMethod.StandardDeviationSampleDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleInt32, - MongoQueryableMethod.StandardDeviationSampleInt32Async, - MongoQueryableMethod.StandardDeviationSampleInt32WithSelector, - MongoQueryableMethod.StandardDeviationSampleInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleInt64, - MongoQueryableMethod.StandardDeviationSampleInt64Async, - MongoQueryableMethod.StandardDeviationSampleInt64WithSelector, - MongoQueryableMethod.StandardDeviationSampleInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDecimal, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDouble, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt32, - MongoQueryableMethod.StandardDeviationSampleNullableInt32Async, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt64, - MongoQueryableMethod.StandardDeviationSampleNullableInt64Async, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableSingle, - MongoQueryableMethod.StandardDeviationSampleNullableSingleAsync, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleSingle, - MongoQueryableMethod.StandardDeviationSampleSingleAsync, - MongoQueryableMethod.StandardDeviationSampleSingleWithSelector, - MongoQueryableMethod.StandardDeviationSampleSingleWithSelectorAsync - }; - - __standardDeviationNullableMethods = new[] - { - MongoQueryableMethod.StandardDeviationPopulationNullableDecimal, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDouble, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingle, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDecimal, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDouble, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt32, - MongoQueryableMethod.StandardDeviationSampleNullableInt32Async, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt64, - MongoQueryableMethod.StandardDeviationSampleNullableInt64Async, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableSingle, - MongoQueryableMethod.StandardDeviationSampleNullableSingleAsync, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelectorAsync - }; - - __standardDeviationPopulationMethods = new[] - { - MongoQueryableMethod.StandardDeviationPopulationDecimal, - MongoQueryableMethod.StandardDeviationPopulationDecimalAsync, - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationDouble, - MongoQueryableMethod.StandardDeviationPopulationDoubleAsync, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt32, - MongoQueryableMethod.StandardDeviationPopulationInt32Async, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt64, - MongoQueryableMethod.StandardDeviationPopulationInt64Async, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimal, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDouble, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64Async, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingle, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationSingle, - MongoQueryableMethod.StandardDeviationPopulationSingleAsync, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelectorAsync - }; - - __standardDeviationWithSelectorMethods = new[] - { - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelector, - MongoQueryableMethod.StandardDeviationPopulationSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleDecimalWithSelector, - MongoQueryableMethod.StandardDeviationSampleDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleDoubleWithSelector, - MongoQueryableMethod.StandardDeviationSampleDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleInt32WithSelector, - MongoQueryableMethod.StandardDeviationSampleInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleInt64WithSelector, - MongoQueryableMethod.StandardDeviationSampleInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDecimalWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableDoubleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt32WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableInt64WithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelector, - MongoQueryableMethod.StandardDeviationSampleNullableSingleWithSelectorAsync, - MongoQueryableMethod.StandardDeviationSampleSingleWithSelector, - MongoQueryableMethod.StandardDeviationSampleSingleWithSelectorAsync - }; - } // public static methods public static ExecutableQuery Translate(MongoQueryProvider provider, TranslationContext context, MethodCallExpression expression) @@ -264,16 +39,16 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__standardDeviationMethods)) + if (method.IsOneOf(MongoQueryableMethod.StandardDeviationOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method); var sourceSerializer = pipeline.OutputSerializer; - var stdDevOperator = method.IsOneOf(__standardDeviationPopulationMethods) ? AstUnaryAccumulatorOperator.StdDevPop : AstUnaryAccumulatorOperator.StdDevSamp; + var stdDevOperator = method.IsOneOf(MongoQueryableMethod.StandardDeviationPopulationOverloads) ? AstUnaryAccumulatorOperator.StdDevPop : AstUnaryAccumulatorOperator.StdDevSamp; AstExpression valueAst; - if (method.IsOneOf(__standardDeviationWithSelectorMethods)) + if (method.IsOneOf(MongoQueryableMethod.StandardDeviationWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true); @@ -294,7 +69,7 @@ public static ExecutableQuery Translate(MongoQuer AstStage.Project(AstProject.ExcludeId()), outputWrappedValueSerializer); - var finalizer = method.IsOneOf(__standardDeviationNullableMethods) ? __singleOrDefaultFinalizer : __singleFinalizer; + var finalizer = method.IsOneOf(MongoQueryableMethod.StandardDeviationNullableOverloads) ? __singleOrDefaultFinalizer : __singleFinalizer; return ExecutableQuery.Create( provider, diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SumMethodToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SumMethodToExecutableQueryTranslator.cs index 0e44e945fd7..ac19fede0c6 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SumMethodToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/SumMethodToExecutableQueryTranslator.cs @@ -33,79 +33,23 @@ internal static class SumMethodToExecutableQueryTranslator { // private static fields private static readonly IExecutableQueryFinalizer __finalizer = new SingleOrDefaultFinalizer(); - private static readonly MethodInfo[] __sumMethods; - private static readonly MethodInfo[] __sumWithSelectorMethods; + private static readonly IReadOnlyMethodInfoSet __sumOverloads; + private static readonly IReadOnlyMethodInfoSet __sumWithSelectorOverloads; // static constructor static SumMethodToExecutableQueryTranslator() { - __sumMethods = new[] - { - QueryableMethod.SumDecimal, - QueryableMethod.SumDecimalWithSelector, - QueryableMethod.SumDouble, - QueryableMethod.SumDoubleWithSelector, - QueryableMethod.SumInt32, - QueryableMethod.SumInt32WithSelector, - QueryableMethod.SumInt64, - QueryableMethod.SumInt64WithSelector, - QueryableMethod.SumNullableDecimal, - QueryableMethod.SumNullableDecimalWithSelector, - QueryableMethod.SumNullableDouble, - QueryableMethod.SumNullableDoubleWithSelector, - QueryableMethod.SumNullableInt32, - QueryableMethod.SumNullableInt32WithSelector, - QueryableMethod.SumNullableInt64, - QueryableMethod.SumNullableInt64WithSelector, - QueryableMethod.SumNullableSingle, - QueryableMethod.SumNullableSingleWithSelector, - QueryableMethod.SumSingle, - QueryableMethod.SumSingleWithSelector, - MongoQueryableMethod.SumDecimalAsync, - MongoQueryableMethod.SumDecimalWithSelectorAsync, - MongoQueryableMethod.SumDoubleAsync, - MongoQueryableMethod.SumDoubleWithSelectorAsync, - MongoQueryableMethod.SumInt32Async, - MongoQueryableMethod.SumInt32WithSelectorAsync, - MongoQueryableMethod.SumInt64Async, - MongoQueryableMethod.SumInt64WithSelectorAsync, - MongoQueryableMethod.SumNullableDecimalAsync, - MongoQueryableMethod.SumNullableDecimalWithSelectorAsync, - MongoQueryableMethod.SumNullableDoubleAsync, - MongoQueryableMethod.SumNullableDoubleWithSelectorAsync, - MongoQueryableMethod.SumNullableInt32Async, - MongoQueryableMethod.SumNullableInt32WithSelectorAsync, - MongoQueryableMethod.SumNullableInt64Async, - MongoQueryableMethod.SumNullableInt64WithSelectorAsync, - MongoQueryableMethod.SumNullableSingleAsync, - MongoQueryableMethod.SumNullableSingleWithSelectorAsync, - MongoQueryableMethod.SumSingleAsync, - MongoQueryableMethod.SumSingleWithSelectorAsync - }; + __sumOverloads = MethodInfoSet.Create( + [ + QueryableMethod.SumOverloads, + MongoQueryableMethod.SumOverloads + ]); - __sumWithSelectorMethods = new[] - { - QueryableMethod.SumDecimalWithSelector, - QueryableMethod.SumDoubleWithSelector, - QueryableMethod.SumInt32WithSelector, - QueryableMethod.SumInt64WithSelector, - QueryableMethod.SumNullableDecimalWithSelector, - QueryableMethod.SumNullableDoubleWithSelector, - QueryableMethod.SumNullableInt32WithSelector, - QueryableMethod.SumNullableInt64WithSelector, - QueryableMethod.SumNullableSingleWithSelector, - QueryableMethod.SumSingleWithSelector, - MongoQueryableMethod.SumDecimalWithSelectorAsync, - MongoQueryableMethod.SumDoubleWithSelectorAsync, - MongoQueryableMethod.SumInt32WithSelectorAsync, - MongoQueryableMethod.SumInt64WithSelectorAsync, - MongoQueryableMethod.SumNullableDecimalWithSelectorAsync, - MongoQueryableMethod.SumNullableDoubleWithSelectorAsync, - MongoQueryableMethod.SumNullableInt32WithSelectorAsync, - MongoQueryableMethod.SumNullableInt64WithSelectorAsync, - MongoQueryableMethod.SumNullableSingleWithSelectorAsync, - MongoQueryableMethod.SumSingleWithSelectorAsync - }; + __sumWithSelectorOverloads = MethodInfoSet.Create( + [ + QueryableMethod.SumWithSelectorOverloads, + MongoQueryableMethod.SumWithSelectorOverloads + ]); } // public static methods @@ -114,7 +58,7 @@ public static ExecutableQuery Translate(MongoQuer var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__sumMethods)) + if (method.IsOneOf(__sumOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -122,7 +66,7 @@ public static ExecutableQuery Translate(MongoQuer var sourceSerializer = pipeline.OutputSerializer; AstExpression valueAst; - if (method.IsOneOf(__sumWithSelectorMethods)) + if (method.IsOneOf(__sumWithSelectorOverloads)) { var selectorLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var selectorTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, selectorLambda, sourceSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/CompareComparisonExpressionToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/CompareComparisonExpressionToFilterTranslator.cs index 2c67478acc9..aa516a67978 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/CompareComparisonExpressionToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/CompareComparisonExpressionToFilterTranslator.cs @@ -28,7 +28,7 @@ public static bool CanTranslate(Expression leftExpression) return leftExpression is MethodCallExpression leftMethodCallExpression && leftMethodCallExpression.Method is var method && - (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.Is(StringMethod.StaticCompareWithIgnoreCase)); + (method.IsStaticCompareMethod() || method.IsInstanceCompareToMethod() || method.Is(StringMethod.CompareWithIgnoreCase)); } // caller is responsible for ensuring constant is on the right @@ -59,7 +59,7 @@ public static AstFilter Translate( innerValueExpression = compareArguments[0]; } - if (compareMethod.Is(StringMethod.StaticCompareWithIgnoreCase)) + if (compareMethod.Is(StringMethod.CompareWithIgnoreCase)) { var ignoreCaseExpression = compareArguments[2]; var ignoreCase = ignoreCaseExpression.GetConstantValue(containingExpression: compareMethodCallExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/StringExpressionToRegexFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/StringExpressionToRegexFilterTranslator.cs index ec24bd2db43..8500250db84 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/StringExpressionToRegexFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/StringExpressionToRegexFilterTranslator.cs @@ -34,129 +34,41 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilter internal static class StringExpressionToRegexFilterTranslator { // private static fields - private static readonly MethodInfo[] __indexOfAnyMethods; - private static readonly MethodInfo[] __indexOfMethods; - private static readonly MethodInfo[] __indexOfWithCharMethods; - private static readonly MethodInfo[] __indexOfWithComparisonTypeMethods; - private static readonly MethodInfo[] __indexOfWithCountMethods; - private static readonly MethodInfo[] __indexOfWithStartIndexMethods; - private static readonly MethodInfo[] __indexOfWithStringMethods; - private static readonly MethodInfo[] __modifierMethods; - private static readonly MethodInfo[] __translatableMethods; - private static readonly MethodInfo[] __withComparisonTypeMethods; - private static readonly MethodInfo[] __withIgnoreCaseAndCultureMethods; + private static readonly IReadOnlyMethodInfoSet __modifierOverloads; + private static readonly IReadOnlyMethodInfoSet __translatableOverloads; + private static readonly IReadOnlyMethodInfoSet __withComparisonTypeOverloads; + private static readonly IReadOnlyMethodInfoSet __withIgnoreCaseAndCultureOverloads; // static constructor static StringExpressionToRegexFilterTranslator() { - __indexOfAnyMethods = new[] - { - StringMethod.IndexOfAny, - StringMethod.IndexOfAnyWithStartIndex, - StringMethod.IndexOfAnyWithStartIndexAndCount, - }; - - __indexOfMethods = new[] - { - StringMethod.IndexOfAny, - StringMethod.IndexOfAnyWithStartIndex, - StringMethod.IndexOfAnyWithStartIndexAndCount, - StringMethod.IndexOfWithChar, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithString, - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - __indexOfWithCharMethods = new[] - { - StringMethod.IndexOfWithChar, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - }; - - __indexOfWithComparisonTypeMethods = new[] - { - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - __indexOfWithCountMethods = new[] - { - StringMethod.IndexOfAnyWithStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - __indexOfWithStartIndexMethods = new[] - { - StringMethod.IndexOfAnyWithStartIndex, - StringMethod.IndexOfAnyWithStartIndexAndCount, - StringMethod.IndexOfWithCharAndStartIndex, - StringMethod.IndexOfWithCharAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - __indexOfWithStringMethods = new[] - { - StringMethod.IndexOfWithString, - StringMethod.IndexOfWithStringAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndex, - StringMethod.IndexOfWithStringAndStartIndexAndComparisonType, - StringMethod.IndexOfWithStringAndStartIndexAndCount, - StringMethod.IndexOfWithStringAndStartIndexAndCountAndComparisonType - }; - - __modifierMethods = new[] - { - StringMethod.ToLower, - StringMethod.ToLowerInvariant, - StringMethod.ToUpper, - StringMethod.ToUpperInvariant, - StringMethod.Trim, - StringMethod.TrimEnd, - StringMethod.TrimStart, - StringMethod.TrimWithChars - }; - - __translatableMethods = new[] - { - StringMethod.ContainsWithChar, - StringMethod.ContainsWithCharAndComparisonType, - StringMethod.ContainsWithString, - StringMethod.ContainsWithStringAndComparisonType, - StringMethod.EndsWithWithChar, - StringMethod.EndsWithWithString, - StringMethod.EndsWithWithStringAndComparisonType, - StringMethod.EndsWithWithStringAndIgnoreCaseAndCulture, - StringMethod.StartsWithWithChar, - StringMethod.StartsWithWithString, - StringMethod.StartsWithWithStringAndComparisonType, - StringMethod.StartsWithWithStringAndIgnoreCaseAndCulture - }; - - __withComparisonTypeMethods = new[] - { + __modifierOverloads = MethodInfoSet.Create( + [ + StringMethod.ToLowerOverloads, + StringMethod.ToUpperOverloads, + StringMethod.TrimOverloads + ]); + + __translatableOverloads = MethodInfoSet.Create( + [ + StringMethod.ContainsOverloads, + StringMethod.EndsWithOverloads, + StringMethod.StartsWithOverloads + ]); + + __withComparisonTypeOverloads = MethodInfoSet.Create( + [ StringMethod.ContainsWithCharAndComparisonType, StringMethod.ContainsWithStringAndComparisonType, StringMethod.EndsWithWithStringAndComparisonType, StringMethod.StartsWithWithStringAndComparisonType, - }; + ]); - __withIgnoreCaseAndCultureMethods = new[] - { + __withIgnoreCaseAndCultureOverloads = MethodInfoSet.Create( + [ StringMethod.EndsWithWithStringAndIgnoreCaseAndCulture, StringMethod.StartsWithWithStringAndIgnoreCaseAndCulture - }; + ]); } // public static methods @@ -166,7 +78,7 @@ public static bool CanTranslate(Expression expression) { var method = methodCallExpression.Method; - if (method.IsOneOf(__translatableMethods)) + if (method.IsOneOf(__translatableOverloads)) { return true; } @@ -368,7 +280,7 @@ private static bool IsStringIndexOfComparison(Expression leftExpression) { return leftExpression is MethodCallExpression leftMethodCallExpression && - leftMethodCallExpression.Method.IsOneOf(__indexOfMethods); + leftMethodCallExpression.Method.IsOneOf(StringMethod.IndexOfOverloads); } private static bool IsStringLengthComparison(Expression leftExpression) @@ -381,7 +293,7 @@ leftMemberExpression.Member is PropertyInfo propertyInfo && private static bool IsWithComparisonTypeMethod(MethodInfo method) { - if (method.IsOneOf(__withComparisonTypeMethods)) + if (method.IsOneOf(__withComparisonTypeOverloads)) { return true; } @@ -445,7 +357,7 @@ private static AstFilter TranslateGetCharsComparison(TranslationContext context, private static (AstFilterField, Modifiers) TranslateField(TranslationContext context, Expression expression, Expression fieldExpression) { if (fieldExpression is MethodCallExpression fieldMethodCallExpression && - fieldMethodCallExpression.Method.IsOneOf(__modifierMethods)) + fieldMethodCallExpression.Method.IsOneOf(__modifierOverloads)) { var (field, modifiers) = TranslateField(context, expression, fieldMethodCallExpression.Object); modifiers = TranslateModifier(modifiers, fieldMethodCallExpression); @@ -527,7 +439,7 @@ private static AstFilter TranslateStartsWithOrContainsOrEndsWith(TranslationCont { modifiers = TranslateComparisonType(modifiers, expression, arguments[1]); } - if (method.IsOneOf(__withIgnoreCaseAndCultureMethods)) + if (method.IsOneOf(__withIgnoreCaseAndCultureOverloads)) { modifiers = TranslateIgnoreCase(modifiers, expression, arguments[1]); modifiers = TranslateCulture(modifiers, expression, arguments[2]); @@ -613,7 +525,7 @@ private static AstFilter TranslateStringIndexOfComparison(TranslationContext con var (field, modifiers) = TranslateField(context, expression, fieldExpression); var startIndex = 0; - if (method.IsOneOf(__indexOfWithStartIndexMethods)) + if (method.IsOneOf(StringMethod.IndexOfWithStartIndexOverloads)) { var startIndexExpression = arguments[1]; startIndex = startIndexExpression.GetConstantValue(containingExpression: expression); @@ -624,7 +536,7 @@ private static AstFilter TranslateStringIndexOfComparison(TranslationContext con } var count = (int?)null; - if (method.IsOneOf(__indexOfWithCountMethods)) + if (method.IsOneOf(StringMethod.IndexOfWithCountOverloads)) { var countExpression = arguments[2]; count = countExpression.GetConstantValue(containingExpression: expression); @@ -636,16 +548,16 @@ private static AstFilter TranslateStringIndexOfComparison(TranslationContext con var comparand = rightExpression.GetConstantValue(containingExpression: expression); - if (method.IsOneOf(__indexOfAnyMethods, __indexOfWithCharMethods)) + if (method.IsOneOf(StringMethod.IndexOfAnyOverloads, StringMethod.IndexOfWithCharOverloads)) { char[] anyOf; - if (method.IsOneOf(__indexOfAnyMethods)) + if (method.IsOneOf(StringMethod.IndexOfAnyOverloads)) { var anyOfExpression = arguments[0]; anyOf = anyOfExpression.GetConstantValue(containingExpression: expression); } else - if (method.IsOneOf(__indexOfWithCharMethods)) + if (method.IsOneOf(StringMethod.IndexOfWithCharOverloads)) { var valueExpression = arguments[0]; var value = valueExpression.GetConstantValue(containingExpression: expression); @@ -684,13 +596,13 @@ private static AstFilter TranslateStringIndexOfComparison(TranslationContext con return CreateFilter(expression, field, modifiers, comparisonOperator, pattern); } - if (method.IsOneOf(__indexOfWithStringMethods)) + if (method.IsOneOf(StringMethod.IndexOfWithStringOverloads)) { var valueExpression = arguments[0]; var value = valueExpression.GetConstantValue(containingExpression: expression); var escapedValue = Regex.Escape(value); - if (method.IsOneOf(__indexOfWithComparisonTypeMethods)) + if (method.IsOneOf(StringMethod.IndexOfWithComparisonTypeOverloads)) { var comparisonTypeExpression = arguments.Last(); modifiers = TranslateComparisonType(modifiers, expression, comparisonTypeExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs index cabda2421f0..3ae6b57d249 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AllWithContainsInPredicateMethodToFilterTranslator.cs @@ -64,7 +64,7 @@ public static AstFilter Translate(TranslationContext context, Expression arrayFi private static bool IsContainsParameterExpression(Expression predicateBody, ParameterExpression predicateParameter, out Expression innerSourceExpression) { if (predicateBody is MethodCallExpression methodCallExpression && - IsContainsMethodCall(methodCallExpression, out var sourceExpression, out var valueExpression) && + EnumerableMethod.IsContainsMethod(methodCallExpression, out var sourceExpression, out var valueExpression) && valueExpression == predicateParameter) { innerSourceExpression = sourceExpression; @@ -73,49 +73,6 @@ private static bool IsContainsParameterExpression(Expression predicateBody, Para innerSourceExpression = null; return false; - - static bool IsContainsMethodCall(MethodCallExpression methodCallExpression, out Expression sourceExpression, out Expression valueExpression) - { - var method = methodCallExpression.Method; - var arguments = methodCallExpression.Arguments; - - if (method.Name == "Contains" && method.ReturnType == typeof(bool)) - { - if (method.IsStatic && arguments.Count == 2) - { - sourceExpression = arguments[0]; - valueExpression = arguments[1]; - if (ValueTypeIsElementTypeOfSourceType(valueExpression, sourceExpression)) - { - return true; - } - } - else if (!method.IsStatic && arguments.Count == 1) - { - sourceExpression = methodCallExpression.Object; - valueExpression = arguments[0]; - if (ValueTypeIsElementTypeOfSourceType(valueExpression, sourceExpression)) - { - return true; - } - } - } - - sourceExpression = null; - valueExpression = null; - return false; - } - - static bool ValueTypeIsElementTypeOfSourceType(Expression valueExpression, Expression sourceExpression) - { - if (sourceExpression.Type.TryGetIEnumerableGenericInterface(out var ienumerableInterface)) - { - var elementType = ienumerableInterface.GetGenericArguments()[0]; - return elementType.IsAssignableFrom(valueExpression.Type); - } - - return false; - } } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs index 574a93809c6..3dd10aec451 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs @@ -20,6 +20,7 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators @@ -36,30 +37,11 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi var method = expression.Method; var arguments = expression.Arguments; - if (method.IsStatic && - method.Name == "Contains" && - method.ReturnType == typeof(bool) && - arguments.Count == 2) + if (EnumerableMethod.IsContainsMethod(expression, out var fieldExpression, out var itemExpression)) { - var fieldExpression = arguments[0]; var fieldType = fieldExpression.Type; - var itemExpression = arguments[1]; var itemType = itemExpression.Type; - if (TypeImplementsIEnumerable(fieldType, itemType)) - { - return Translate(context, expression, fieldExpression, itemExpression); - } - } - if (!method.IsStatic && - method.Name == "Contains" && - method.ReturnType == typeof(bool) && - arguments.Count == 1) - { - var fieldExpression = expression.Object; - var fieldType = fieldExpression.Type; - var itemExpression = arguments[0]; - var itemType = itemExpression.Type; if (TypeImplementsIEnumerable(fieldType, itemType)) { return Translate(context, expression, fieldExpression, itemExpression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/StringInOrNinMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/StringInOrNinMethodToFilterTranslator.cs index 845b016ca9e..be7e6d42fc2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/StringInOrNinMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/StringInOrNinMethodToFilterTranslator.cs @@ -21,31 +21,20 @@ using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators { internal static class StringInOrNinMethodToFilterTranslator { - private static readonly MethodInfo[] __stringInOrNinMethods = - { - StringMethod.AnyStringInWithEnumerable, - StringMethod.AnyStringInWithParams, - StringMethod.AnyStringNinWithEnumerable, - StringMethod.AnyStringNinWithParams, - StringMethod.StringInWithEnumerable, - StringMethod.StringInWithParams, - StringMethod.StringNinWithEnumerable, - StringMethod.StringNinWithParams, - }; - - private static readonly MethodInfo[] __stringInMethods = - { - StringMethod.AnyStringInWithEnumerable, - StringMethod.AnyStringInWithParams, - StringMethod.StringInWithEnumerable, - StringMethod.StringInWithParams, - }; + private static readonly IReadOnlyMethodInfoSet __translatableOverloads = MethodInfoSet.Create( + [ + StringMethod.AnyStringInOverloads, + StringMethod.AnyStringNinOverloads, + StringMethod.StringInOverloads, + StringMethod.StringNinOverloads + ]); // public static methods public static AstFilter Translate(TranslationContext context, MethodCallExpression expression) @@ -53,7 +42,7 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__stringInOrNinMethods)) + if (method.IsOneOf(__translatableOverloads)) { var fieldExpression = arguments[0]; var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression); @@ -82,7 +71,9 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi serializedValues.Add(serializedValue); } - return method.IsOneOf(__stringInMethods) ? AstFilter.In(fieldTranslation.Ast, serializedValues) : AstFilter.Nin(fieldTranslation.Ast, serializedValues); + return method.IsOneOf(StringMethod.AnyStringInOverloads, StringMethod.StringInOverloads) ? + AstFilter.In(fieldTranslation.Ast, serializedValues) : + AstFilter.Nin(fieldTranslation.Ast, serializedValues); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs index 7a734d7a075..86f0ce6372f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs @@ -148,7 +148,7 @@ private static TranslatedFilterField TranslateConvertEnumToUnderlyingType(Transl enumSerializer = fieldSerializer; } - var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer); + var targetSerializer = AsEnumUnderlyingTypeSerializer.Create(enumSerializer); if (targetType.IsNullable()) { targetSerializer = NullableSerializer.Create(targetSerializer); @@ -186,7 +186,7 @@ private static TranslatedFilterField TranslateConvertUnderlyingTypeToEnum(Transl } IBsonSerializer targetSerializer; - if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) + if (valueSerializer is IAsEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer) { targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs index 03fb1ecb1b7..5aa93e38363 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ConcatMethodToPipelineTranslator.cs @@ -44,7 +44,7 @@ secondProvider.CollectionNamespace is var secondCollectionNamespace && secondCollectionNamespace != null) { var secondCollectionName = secondCollectionNamespace.CollectionName; - var secondContext = TranslationContext.Create(context.TranslationOptions); + var secondContext = TranslationContext.Create(secondQueryable, context.TranslationOptions); var secondPipeline = ExpressionToPipelineTranslator.Translate(secondContext, secondQueryable.Expression); if (secondPipeline.Ast.Stages.Count == 0) { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupByMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupByMethodToPipelineTranslator.cs index ee2b6bf7976..cc6da6c9cf5 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupByMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/GroupByMethodToPipelineTranslator.cs @@ -30,42 +30,13 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToPipeli { internal static class GroupByMethodToPipelineTranslator { - // private static fields - private static readonly MethodInfo[] __groupByMethods; - private static readonly MethodInfo[] __groupByWithElementSelectorMethods; - private static readonly MethodInfo[] __groupByWithResultSelectorMethods; - - // static constructor - static GroupByMethodToPipelineTranslator() - { - __groupByMethods = new[] - { - QueryableMethod.GroupByWithKeySelector, - QueryableMethod.GroupByWithKeySelectorAndElementSelector, - QueryableMethod.GroupByWithKeySelectorAndResultSelector, - QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector - }; - - __groupByWithElementSelectorMethods = new[] - { - QueryableMethod.GroupByWithKeySelectorAndElementSelector, - QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector - }; - - __groupByWithResultSelectorMethods = new[] - { - QueryableMethod.GroupByWithKeySelectorAndResultSelector, - QueryableMethod.GroupByWithKeySelectorElementSelectorAndResultSelector - }; - } - // public static methods public static TranslatedPipeline Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__groupByMethods)) + if (method.IsOneOf(QueryableMethod.GroupByOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -85,7 +56,7 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal fields: AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, elementAst)), groupingSerializer); - if (method.IsOneOf(__groupByWithResultSelectorMethods)) + if (method.IsOneOf(QueryableMethod.GroupByWithResultSelectorOverloads)) { pipeline = TranslateResultSelector(context, pipeline, arguments, keySerializer, elementSerializer); } @@ -104,7 +75,7 @@ private static (AstExpression, IBsonSerializer) TranslateElement( { AstExpression elementAst; IBsonSerializer elementSerializer; - if (method.IsOneOf(__groupByWithElementSelectorMethods)) + if (method.IsOneOf(QueryableMethod.GroupByWithElementSelectorOverloads)) { var elementLambda = ExpressionHelper.UnquoteLambda(arguments[2]); var elementTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, elementLambda, sourceSerializer, asRoot: true); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs index 9bc144c8875..ebffbd19957 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs @@ -31,58 +31,13 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToPipeli { internal static class LookupMethodToPipelineTranslator { - // private static fields - private static readonly MethodInfo[] __lookupMethods = - { - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithDocumentsAndPipeline, - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithFromAndPipeline - }; - - private static readonly MethodInfo[] __lookupMethodsWithDocuments = - { - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithDocumentsAndPipeline - }; - - private static readonly MethodInfo[] __lookupMethodsWithDocumentsAndPipeline = - { - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithDocumentsAndPipeline - }; - - private static readonly MethodInfo[] __lookupMethodsWithFrom = - { - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithFromAndPipeline - }; - - private static readonly MethodInfo[] __lookupMethodsWithFromAndPipeline = - { - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithFromAndPipeline - }; - - private static readonly MethodInfo[] __lookupMethodsWithLocalFieldAndForeignField = - { - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithDocumentsAndLocalFieldAndForeignFieldAndPipeline, - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignField, - MongoQueryableMethod.LookupWithFromAndLocalFieldAndForeignFieldAndPipeline - }; - // public static methods public static TranslatedPipeline Translate(TranslationContext context, MethodCallExpression expression) { var method = expression.Method; var arguments = expression.Arguments; - if (method.IsOneOf(__lookupMethods)) + if (method.IsOneOf(MongoQueryableMethod.LookupOverloads)) { var sourceExpression = arguments[0]; var pipeline = ExpressionToPipelineTranslator.Translate(context, sourceExpression); @@ -97,7 +52,7 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal IMongoCollection foreignCollection = null; string foreignCollectionName = null; IBsonSerializer foreignSerializer = null; - if (method.IsOneOf(__lookupMethodsWithFrom)) + if (method.IsOneOf(MongoQueryableMethod.LookupWithFromOverloads)) { var fromExpression = arguments[1]; foreignCollection = fromExpression.GetConstantValue(expression); @@ -107,13 +62,13 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal TranslatedPipeline lookupPipeline = null; var isCorrelatedSubquery = false; - if (method.IsOneOf(__lookupMethodsWithDocuments)) + if (method.IsOneOf(MongoQueryableMethod.LookupWithDocumentsOverloads)) { var documentsLambda = ExpressionHelper.UnquoteLambda(arguments[1]); var documentsPipeline = TranslateDocuments(context, documentsLambda, localSerializer); var documentSerializer = documentsPipeline.OutputSerializer; - if (method.IsOneOf(__lookupMethodsWithDocumentsAndPipeline)) + if (method.IsOneOf(MongoQueryableMethod.LookupWithDocumentsAndPipelineOverloads)) { var pipelineLambda = ExpressionHelper.UnquoteLambda(arguments.Last()); var localParameter = pipelineLambda.Parameters.First(); @@ -134,7 +89,7 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal string localField = null; string foreignField = null; - if (method.IsOneOf(__lookupMethodsWithLocalFieldAndForeignField)) + if (method.IsOneOf(MongoQueryableMethod.LookupWithLocalFieldAndForeignFieldOverloads)) { var localFieldExpression = ExpressionHelper.UnquoteLambda(arguments[2]); var foreignFieldExpression = ExpressionHelper.UnquoteLambda(arguments[3]); @@ -142,7 +97,7 @@ public static TranslatedPipeline Translate(TranslationContext context, MethodCal foreignField = foreignFieldExpression.TranslateToDottedFieldName(context, foreignSerializer); } - if (method.IsOneOf(__lookupMethodsWithFromAndPipeline)) + if (method.IsOneOf(MongoQueryableMethod.LookupWithFromAndPipelineOverloads)) { var pipelineLamda = ExpressionHelper.UnquoteLambda(arguments.Last()); var localParameter = pipelineLamda.Parameters[0]; @@ -281,7 +236,13 @@ private static TranslatedPipeline TranslateDocumentsPipelineGeneric= 1) + { + var sourceParameter = parameters[0]; + var sourceParameterType = sourceParameter.ParameterType; + if (sourceParameterType.IsConstructedGenericType) + { + sourceParameterType = sourceParameterType.GetGenericTypeDefinition(); + } + + if (sourceParameterType == typeof(IQueryable) || + sourceParameterType == typeof(IQueryable<>) || + sourceParameterType == typeof(IOrderedQueryable) || + sourceParameterType == typeof(IOrderedQueryable<>)) + { + return GetUltimateSource(methodCallExpression.Arguments[0]); + } + } + + throw new ArgumentException($"No ultimate source found: {expression}."); } #endregion // private fields private readonly TranslationContextData _data; + private readonly IReadOnlySerializerMap _nodeSerializers; private readonly NameGenerator _nameGenerator; private readonly SymbolTable _symbolTable; private readonly ExpressionTranslationOptions _translationOptions; private TranslationContext( ExpressionTranslationOptions translationOptions, + IReadOnlySerializerMap nodeSerializers, TranslationContextData data, SymbolTable symbolTable, NameGenerator nameGenerator) { _translationOptions = translationOptions ?? new ExpressionTranslationOptions(); + _nodeSerializers = Ensure.IsNotNull(nodeSerializers, nameof(nodeSerializers)); _data = data; // can be null _symbolTable = Ensure.IsNotNull(symbolTable, nameof(symbolTable)); _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); @@ -54,6 +146,7 @@ private TranslationContext( // public properties public TranslationContextData Data => _data; + public IReadOnlySerializerMap NodeSerializers => _nodeSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -99,6 +192,11 @@ public Symbol CreateSymbolWithVarName(ParameterExpression parameter, string varN return CreateSymbol(parameter, name: parameterName, varName, serializer, isCurrent); } + public IBsonSerializer GetSerializer(Expression parameter) + { + return _nodeSerializers.GetSerializer(parameter); + } + public override string ToString() { return $"{{ SymbolTable : {_symbolTable} }}"; @@ -124,7 +222,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(_translationOptions, _data, symbolTable, _nameGenerator); + return new TranslationContext(_translationOptions, _nodeSerializers, _data, symbolTable, _nameGenerator); } } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 4da85f6bd6d..57ce448f981 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -14,17 +14,16 @@ */ using System; -using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using MongoDB.Bson; using MongoDB.Bson.Serialization; -using MongoDB.Driver; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Linq.Linq3Implementation; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages; using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.SerializerFinders; using MongoDB.Driver.Linq.Linq3Implementation.Translators; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators; @@ -61,7 +60,8 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions, contextData); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: sourceSerializer, translationOptions: translationOptions, data: contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -76,7 +76,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -106,7 +106,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -125,8 +125,8 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: elementSerializer, translationOptions: translationOptions); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); @@ -142,7 +142,8 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: documentSerializer, translationOptions: translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +177,8 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); - var context = TranslationContext.Create(translationOptions); + var parameter = expression.Parameters.Single(); + var context = TranslationContext.Create(expression, initialNode: parameter, initialSerializer: inputSerializer, translationOptions: translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); try @@ -215,8 +217,18 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); + var body = expression.Body; + + var nodeSerializers = new SerializerMap(); + nodeSerializers.AddSerializer(parameter, documentSerializer); + if (body.Type == typeof(TDocument)) + { + nodeSerializers.AddSerializer(body, documentSerializer); + } + SerializerFinder.FindSerializers(expression, translationOptions, nodeSerializers); + + var context = TranslationContext.Create(translationOptions, nodeSerializers); // do not partially evaluate expression var symbol = context.CreateRootSymbol(parameter, documentSerializer); context = context.WithSymbol(symbol); var setStage = ExpressionToSetStageTranslator.Translate(context, documentSerializer, expression); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs index 035bba42f7e..fe4117ce430 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2472Tests.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.Linq; using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.TestHelpers; using Xunit; @@ -79,7 +81,7 @@ public class C private class MyDTO { public DateTime timestamp { get; set; } - public decimal sqrt_calc { get; set; } + [BsonRepresentation(BsonType.Decimal128)] public decimal sqrt_calc { get; set; } } public sealed class ClassFixture : MongoCollectionFixture diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs index 524b72ff602..5177538f54c 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4054Tests.cs @@ -43,6 +43,17 @@ from movieId in person.MovieIds join movie in movies.AsQueryable() on movieId equals movie.Id select new { person, movie }; + // equivalement method call syntax + // var queryable = people.AsQueryable() + // .SelectMany( + // person => person.MovieIds, + // (person, movieId) => new { person = person, movieId = movieId }) + // .Join( + // movies.AsQueryable(), + // transparentIdentifier => transparentIdentifier.movieId, + // movie => movie.Id, + // (transparentIdentifier, movie) => new { person = transparentIdentifier.person, movie = movie }); + var stages = Translate(people, queryable); AssertStages( stages, diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs new file mode 100644 index 00000000000..0224126f84d --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4593Tests.cs @@ -0,0 +1,144 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4593Tests : LinqIntegrationTest +{ + public CSharp4593Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void First_example_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(r => r.Id); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result.Should().Be("a"); + } + + [Fact] + public void First_example_workaround_should_work() + { + var collection = Fixture.Orders; + + var find = collection + .Find(o => o.RateBasisHistoryId == "abc") + .Project(Builders.Projection.Include(o => o.Id)); + + var translatedFilter = TranslateFindFilter(collection, find); + translatedFilter.Should().Be("{ RateBasisHistoryId : 'abc' }"); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ _id : 1 }"); + + var result = find.Single(); + result["_id"].AsString.Should().Be("a"); + } + + [Fact] + public void Second_example_should_work() + { + var collection = Fixture.Entities; + var idsFilter = Builders.Filter.Eq(x => x.Id, 1); + + var aggregate = collection.Aggregate() + .Match(idsFilter) + .Project(e => new + { + _id = e.Id, + CampaignId = e.CampaignId, + Accepted = e.Status.Key == "Accepted" ? 1 : 0, + Rejected = e.Status.Key == "Rejected" ? 1 : 0, + }); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $match : { _id : 1 } }", + """ + { $project : + { + _id : "$_id", + CampaignId : "$CampaignId", + Accepted : { $cond : { if : { $eq : ["$Status.Key", "Accepted"] }, then : 1, else : 0 } }, + Rejected : { $cond : { if : { $eq : ["$Status.Key", "Rejected"] }, then : 1, else : 0 } } + } + } + """); + + var results = aggregate.ToList(); + results.Count.Should().Be(1); + results[0]._id.Should().Be(1); + results[0].CampaignId.Should().Be(11); + results[0].Accepted.Should().Be(1); + results[0].Rejected.Should().Be(0); + } + + public class Order + { + public string Id { get; set; } + public string RateBasisHistoryId { get; set; } + } + + public class Entity + { + public int Id { get; set; } + public int CampaignId { get; set; } + public Status Status { get; set; } + } + + public class Status + { + public string Key { get; set; } + } + + public sealed class ClassFixture : MongoDatabaseFixture + { + public IMongoCollection Orders { get; private set; } + public IMongoCollection Entities { get; private set; } + + protected override void InitializeFixture() + { + Orders = CreateCollection("orders"); + Orders.InsertMany( + [ + new Order { Id = "a", RateBasisHistoryId = "abc" } + ]); + + Entities = CreateCollection("entities"); + Entities.InsertMany( + [ + new Entity { Id = 1, CampaignId = 11, Status = new Status { Key = "Accepted" } }, + new Entity { Id = 2, CampaignId = 22, Status = new Status { Key = "Rejected" } } + ]); + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs index 2164f38e6a0..4225bef829f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4708Tests.cs @@ -355,7 +355,7 @@ public void Where_Document_item_with_int_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }).GetGetMethod(), Expression.Constant(0)), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -379,7 +379,7 @@ public void Where_Document_item_with_int_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(int) }), new Expression[] { Expression.Constant(0) }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -418,7 +418,7 @@ public void Where_Document_item_with_string_using_call_to_get_item_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }).GetGetMethod(), Expression.Constant("a")), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); @@ -442,7 +442,7 @@ public void Where_Document_item_with_string_using_MakeIndex_should_work() Expression.Property(x, typeof(C).GetProperty("Document")), typeof(BsonDocument).GetProperty("Item", new[] { typeof(string) }), new Expression[] { Expression.Constant("a") }), - Expression.Constant(BsonValue.Create(1))); + Expression.Constant(BsonValue.Create(1), typeof(BsonValue))); var parameters = new ParameterExpression[] { x }; var predicate = Expression.Lambda>(body, parameters); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs new file mode 100644 index 00000000000..9f8f49eff4e --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4819Tests.cs @@ -0,0 +1,68 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4819Tests : LinqIntegrationTest +{ + public CSharp4819Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void ReplaceWith_should_use_configured_element_name() + { + var collection = Fixture.Collection; + var stage = PipelineStageDefinitionBuilder + .ReplaceWith((User u) => new User { UserId = u.UserId }); + + var aggregate = collection.Aggregate() + .AppendStage(stage); + + var stages = Translate(collection, aggregate); + AssertStages( + stages, + "{ $replaceWith : { uuid : '$uuid' } }"); + + var result = aggregate.Single(); + result.Id.Should().Be(0); + result.UserId.Should().Be(Guid.Parse("00112233-4455-6677-8899-aabbccddeeff")); + } + + public class User + { + public int Id { get; set; } + [BsonElement("uuid")] + [BsonGuidRepresentation(GuidRepresentation.Standard)] + public Guid UserId { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new User { Id = 1, UserId = Guid.Parse("00112233-4455-6677-8899-aabbccddeeff") } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs new file mode 100644 index 00000000000..18be97f693c --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4820Tests.cs @@ -0,0 +1,114 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4820Tests : LinqIntegrationTest +{ + public CSharp4820Tests(ClassFixture fixture) + : base(fixture) + { + } + + static CSharp4820Tests() + { + BsonClassMap.RegisterClassMap(cm => + { + cm.AutoMap(); + var readonlyCollectionMemberMap = cm.GetMemberMap(x => x.ReadOnlyCollection); + var readOnlyCollectionSerializer = readonlyCollectionMemberMap.GetSerializer(); + var bracketingCollectionSerializer = ((IChildSerializerConfigurable)readOnlyCollectionSerializer).WithChildSerializer(new StringBracketingSerializer()); + readonlyCollectionMemberMap.SetSerializer(bracketingCollectionSerializer); + }); + } + + [Fact] + public void Update_Set_with_List_should_work() + { + var values = new List() { "abc", "def" }; + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_should_throw() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + [Fact] + public void Update_Set_with_Enumerable_ToList_should_work() + { + var values = new[] { "abc", "def" }.Select(x => x); + var update = Builders.Update.Set(x => x.ReadOnlyCollection, values.ToList()); + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var rendered = (BsonDocument)update.Render(new (documentSerializer, serializerRegistry)); + + rendered.Should().Be("{ $set : { ReadOnlyCollection : ['[abc]', '[def]'] } }"); + } + + public class C + { + public int Id { get; set; } + public IReadOnlyCollection ReadOnlyCollection { get; set; } + } + + + private class StringBracketingSerializer : SerializerBase + { + public override string Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var bracketedValue = StringSerializer.Instance.Deserialize(context, args); + return bracketedValue.Substring(1, bracketedValue.Length - 2); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, string value) + { + var bracketedValue = "[" + value + "]"; + StringSerializer.Instance.Serialize(context, bracketedValue); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => null; + // [ + // new C { } + // ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs index 791ce3bcd75..e82194ef6cc 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4957Tests.cs @@ -84,7 +84,7 @@ public void New_array_with_two_items_should_work() [Theory] [ParameterAttributeData] - public void New_array_with_two_items_with_different_serializers_should_throw( + public void New_array_with_two_items_with_different_serializers_should_work( [Values(false, true)] bool enableClientSideProjections) { RequireServer.Check().Supports(Feature.FindProjectionExpressions); @@ -94,21 +94,11 @@ public void New_array_with_two_items_with_different_serializers_should_throw( var queryable = collection.AsQueryable(translationOptions) .Select(x => new[] { x.X, x.Y }); - if (enableClientSideProjections) - { - var stages = Translate(collection, queryable, out var outputSerializer); - AssertStages(stages, "{ $project : { _snippets : ['$X', '$Y'], _id : 0 } }"); - outputSerializer.Should().BeAssignableTo(); - - var result = queryable.Single(); - result.Should().Equal(1, 2); - } - else - { - var exception = Record.Exception(() => Translate(collection, queryable)); - exception.Should().BeOfType(); - exception.Message.Should().Contain("all items in the array must be serialized using the same serializer"); - } + var stages = Translate(collection, queryable, out var outputSerializer); + AssertStages(stages, "{ $project : { _v : ['$X', '$Y'], _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(1, 2); } public class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs new file mode 100644 index 00000000000..a93e1b4f387 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4967Tests.cs @@ -0,0 +1,75 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson.Serialization; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp4967Tests : LinqIntegrationTest +{ + public CSharp4967Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Set_Nested_should_work() + { + var collection = Fixture.Collection; + var update = Builders.Update + .Pipeline(new EmptyPipelineDefinition() + .Set(c => new MyDocument + { + Nested = new MyNestedDocument + { + ValueCopy = c.Value, + }, + })); + + var renderedUpdate = update.Render(new(collection.DocumentSerializer, BsonSerializer.SerializerRegistry)).AsBsonArray; + renderedUpdate.Count.Should().Be(1); + renderedUpdate[0].Should().Be("{ $set : { Nested : { ValueCopy : '$Value' } } }"); + + collection.UpdateMany("{ }", update); + + var updatedDocument = collection.FindSync("{}").Single(); + updatedDocument.Nested.ValueCopy.Should().Be("Value"); + } + + public class MyDocument + { + public int Id { get; set; } + public string Value { get; set; } + public string AnotherValue { get; set; } + public MyNestedDocument Nested { get; set; } + } + + public class MyNestedDocument + { + public string ValueCopy { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new MyDocument { Id = 1, Value = "Value" } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..d188c18fa1f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,225 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_to_derived_value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs new file mode 100644 index 00000000000..30f3a73072a --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5519Tests.cs @@ -0,0 +1,66 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5519Tests : LinqIntegrationTest +{ + public CSharp5519Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Array_constant_Any_should_serialize_array_correctly() + { + var collection = Fixture.Collection; + var array = new[] { E.A, E.B }; + + var find = collection.Find(x => array.Any(e => x.E == e)); + + var filter = TranslateFindFilter(collection, find); + filter.Should().Be("{ E : { $in : ['A', 'B'] } }"); + + var results = find.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + + public class C + { + public int Id { get; set; } + [BsonRepresentation(BsonType.String)] public E E { get; set; } + } + + public enum E { A, B, C } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, E = E.A }, + new C { Id = 2, E = E.B }, + new C { Id = 3, E = E.C } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs new file mode 100644 index 00000000000..10c9e294982 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5532Tests.cs @@ -0,0 +1,199 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Core.Misc; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5532Tests : LinqIntegrationTest +{ + private static readonly ObjectId id1 = ObjectId.Parse("111111111111111111111111"); + private static readonly ObjectId id2 = ObjectId.Parse("222222222222222222222222"); + private static readonly ObjectId id3 = ObjectId.Parse("333333333333333333333333"); + + public CSharp5532Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void Filter_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find(x => x.Parts.Any(a => a.Refs.Any(b => jobIds.Contains(b.id)))); + + var filter = TranslateFindFilter(collection, find); + + filter.Should().Be("{ Parts : { $elemMatch : { Refs : { $elemMatch : { _id : { $in : [ObjectId('222222222222222222222222')] } } } } } }"); + } + + [Fact] + public void Projection_should_translate_correctly() + { + var collection = Fixture.Collection; + List jobIds = [id2.ToString()]; + + var find = collection + .Find("{}") + .Project(chain => + new + { + chain.Parts + .First(p => p.Refs.Any(j => jobIds.Contains(j.id))) + .Refs.First(j => jobIds.Contains(j.id)).id + });; + + var projectionTranslation = TranslateFindProjection(collection, find); + + var expectedTranslation = + """ + { + _id : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : + { + $let : + { + vars : + { + this : + { + $arrayElemAt : + [ + { + $filter : + { + input : "$Parts", + as : "p", + cond : + { + $anyElementTrue : + { + $map : + { + input : "$$p.Refs", + as : "j", + in : { $in : ["$$j._id", [{ "$oid" : "222222222222222222222222" }]] } + } + } + }, + limit : 1 + } + }, + 0 + ] + } + }, + in : "$$this.Refs" + } + }, + as : "j", + cond : { $in : ['$$j._id', [{ "$oid" : "222222222222222222222222" }]] }, + limit : 1 + } + }, + 0 + ] + } + }, + in : "$$this._id" + } + } + } + """; + if (!Feature.FilterLimit.IsSupported(CoreTestConfiguration.MaxWireVersion)) + { + expectedTranslation = Regex.Replace(expectedTranslation, @",\s+limit : 1", ""); + } + + projectionTranslation.Should().Be(expectedTranslation); + } + + public class Document + { + [BsonId] + [BsonRepresentation(BsonType.ObjectId)] + public string id { get; set; } + } + + public class Chain : Document + { + public ICollection Parts { get; set; } = new List(); + } + + public class Unit + { + public ICollection Refs { get; set; } + + public Unit() + { + Refs = new List(); + } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new Chain + { + id = "0102030405060708090a0b0c", + Parts = new List() + { + new() + { + Refs = new List() + { + new() + { + id = id1.ToString(), + }, + new() + { + id = id2.ToString(), + }, + new() + { + id = id3.ToString(), + }, + } + } + } + } + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs similarity index 68% rename from tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs rename to tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs index f6f7ace6d48..de10de29a04 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializerTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Serializers/AsEnumUnderlyingTypeSerializerTests.cs @@ -22,7 +22,7 @@ namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Serializers { - public class EnumUnderlyingTypeSerializerTests + public class AsEnumUnderlyingTypeSerializerTests { private static readonly IBsonSerializer __enumSerializer1 = new ESerializer1(); private static readonly IBsonSerializer __enumSerializer2 = new ESerializer2(); @@ -30,8 +30,8 @@ public class EnumUnderlyingTypeSerializerTests [Fact] public void Equals_derived_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new DerivedFromEnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new DerivedFromAsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -41,7 +41,7 @@ public void Equals_derived_should_return_false() [Fact] public void Equals_null_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(null); @@ -51,7 +51,7 @@ public void Equals_null_should_return_false() [Fact] public void Equals_object_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var y = new object(); var result = x.Equals(y); @@ -62,7 +62,7 @@ public void Equals_object_should_return_false() [Fact] public void Equals_self_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(x); @@ -72,8 +72,8 @@ public void Equals_self_should_return_true() [Fact] public void Equals_with_equal_fields_should_return_true() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.Equals(y); @@ -83,8 +83,8 @@ public void Equals_with_equal_fields_should_return_true() [Fact] public void Equals_with_not_equal_field_should_return_false() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); - var y = new EnumUnderlyingTypeSerializer(__enumSerializer2); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); + var y = new AsEnumUnderlyingTypeSerializer(__enumSerializer2); var result = x.Equals(y); @@ -94,18 +94,18 @@ public void Equals_with_not_equal_field_should_return_false() [Fact] public void GetHashCode_should_return_zero() { - var x = new EnumUnderlyingTypeSerializer(__enumSerializer1); + var x = new AsEnumUnderlyingTypeSerializer(__enumSerializer1); var result = x.GetHashCode(); result.Should().Be(0); } - internal class DerivedFromEnumUnderlyingTypeSerializer : EnumUnderlyingTypeSerializer + internal class DerivedFromAsEnumUnderlyingTypeSerializer : AsEnumUnderlyingTypeSerializer where TEnum : Enum where TEnumUnderlyingType : struct { - public DerivedFromEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) + public DerivedFromAsEnumUnderlyingTypeSerializer(IBsonSerializer enumSerializer) : base(enumSerializer) { } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 10f3f2a5d14..08f071902ab 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -31,8 +31,8 @@ public class ModuloComparisonExpressionToFilterTranslatorTests [Fact] public void Translate_should_return_expected_result_with_byte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Byte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Byte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -44,8 +44,8 @@ public void Translate_should_return_expected_result_with_byte_arguments() [Fact] public void Translate_should_return_expected_result_with_decimal_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Decimal % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -57,8 +57,8 @@ public void Translate_should_return_expected_result_with_decimal_arguments() [Fact] public void Translate_should_return_expected_result_with_double_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Double % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Double % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -70,8 +70,8 @@ public void Translate_should_return_expected_result_with_double_arguments() [Fact] public void Translate_should_return_expected_result_with_float_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Float % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Float % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -83,8 +83,8 @@ public void Translate_should_return_expected_result_with_float_arguments() [Fact] public void Translate_should_return_expected_result_with_int16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -96,8 +96,8 @@ public void Translate_should_return_expected_result_with_int16_arguments() [Fact] public void Translate_should_return_expected_result_with_int32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -109,8 +109,8 @@ public void Translate_should_return_expected_result_with_int32_arguments() [Fact] public void Translate_should_return_expected_result_with_int64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.Int64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -122,8 +122,8 @@ public void Translate_should_return_expected_result_with_int64_arguments() [Fact] public void Translate_should_return_expected_result_with_sbyte_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.SByte % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.SByte % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -135,8 +135,8 @@ public void Translate_should_return_expected_result_with_sbyte_arguments() [Fact] public void Translate_should_return_expected_result_with_uint16_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt16 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -148,8 +148,8 @@ public void Translate_should_return_expected_result_with_uint16_arguments() [Fact] public void Translate_should_return_expected_result_with_uint32_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt32 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -161,8 +161,8 @@ public void Translate_should_return_expected_result_with_uint32_arguments() [Fact] public void Translate_should_return_expected_result_with_uint64_arguments() { - var (parameter, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); - var context = CreateContext(parameter); + var (lambdaExpression, expression) = CreateExpression((C c) => c.UInt64 % 2 == 1); + var context = CreateContext(lambdaExpression); var canTranslate = ModuloComparisonExpressionToFilterTranslator.CanTranslate(expression.Left, expression.Right, out var moduloExpression, out var remainderExpression); canTranslate.Should().BeTrue(); @@ -180,19 +180,19 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue modFilterOperation.Remainder.Should().Be(remainder); } - private TranslationContext CreateContext(ParameterExpression parameter) + private TranslationContext CreateContext(LambdaExpression lambda) { + var parameter = lambda.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(lambda, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } - private (ParameterExpression, BinaryExpression) CreateExpression(Expression> lambda) + private (LambdaExpression, BinaryExpression) CreateExpression(Expression> lambda) { - var parameter = lambda.Parameters.Single(); var expression = (BinaryExpression)lambda.Body; - return (parameter, expression); + return (lambda, expression); } private class C diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index a8f7428079b..5c8f9e967e1 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -641,7 +641,7 @@ private ProjectedResult Group(Expression Project(Expression var query = __collection.AsQueryable().Select(projector); var provider = (MongoQueryProvider)query.Provider; + var inputSerializer = (IBsonSerializer)provider.PipelineInputSerializer; + var serializerRegistry = provider.Collection.Settings.SerializerRegistry; var translationOptions = new ExpressionTranslationOptions { EnableClientSideProjections = false }; - var executableQuery = ExpressionToExecutableQueryTranslator.Translate(provider, query.Expression, translationOptions); - var projection = executableQuery.Pipeline.Ast.Stages.First().Render()["$project"].AsBsonDocument; + var renderedProjection = LinqProviderAdapter.TranslateExpressionToProjection( + projector, + inputSerializer, + serializerRegistry, + translationOptions); + + var projection = renderedProjection.Document; var value = query.Take(1).FirstOrDefault(); return new ProjectedResult diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs index fa01543be13..a4c53d6fd03 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs @@ -1184,7 +1184,7 @@ private void Assert(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(translationOptions: null); + var context = TranslationContext.Create(expression, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index c96a2a96b9a..2e306b35c5a 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1152,9 +1152,9 @@ public List Assert(IMongoCollection collection, { filter = (Expression>)LinqExpressionPreprocessor.Preprocess(filter); - var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(translationOptions: null); + var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); + var context = TranslationContext.Create(filter, parameter, serializer, translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body);