Skip to content

Commit 4971ea1

Browse files
committed
CSHARP-4985: Verify that operands to numeric operators in LINQ expressions are represented as numbers on the server.
1 parent f9f582e commit 4971ea1

19 files changed

+254
-24
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs

+48
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
using MongoDB.Bson.Serialization.Options;
2222
using MongoDB.Bson.Serialization.Serializers;
2323
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
2425

2526
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2627
{
@@ -35,6 +36,11 @@ public static void EnsureRepresentationIsArray(Expression expression, IBsonSeria
3536
}
3637
}
3738

39+
public static void EnsureRepresentationIsNumeric(Expression expression, AggregationExpression translation)
40+
{
41+
EnsureRepresentationIsNumeric(expression, translation.Serializer);
42+
}
43+
3844
public static void EnsureRepresentationIsNumeric(Expression expression, IBsonSerializer serializer)
3945
{
4046
var representation = GetRepresentation(serializer);
@@ -56,6 +62,11 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
5662
return GetRepresentation(downcastingSerializer.DerivedSerializer);
5763
}
5864

65+
if (serializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer)
66+
{
67+
return GetRepresentation(enumUnderlyingTypeSerializer.EnumSerializer);
68+
}
69+
5970
if (serializer is IImpliedImplementationInterfaceSerializer impliedImplementationSerializer)
6071
{
6172
return GetRepresentation(impliedImplementationSerializer.ImplementationSerializer);
@@ -82,6 +93,11 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
8293
return keyValuePairSerializer.Representation;
8394
}
8495

96+
if (serializer is INullableSerializer nullableSerializer)
97+
{
98+
return GetRepresentation(nullableSerializer.ValueSerializer);
99+
}
100+
85101
// for backward compatibility assume that any remaining implementers of IBsonDocumentSerializer are represented as documents
86102
if (serializer is IBsonDocumentSerializer)
87103
{
@@ -97,6 +113,15 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
97113
return BsonType.Undefined;
98114
}
99115

116+
public static bool IsIntegerRepresentation(BsonType representation)
117+
{
118+
return representation switch
119+
{
120+
BsonType.Int32 or BsonType.Int64 => true,
121+
_ => false
122+
};
123+
}
124+
100125
public static bool IsNumericRepresentation(BsonType representation)
101126
{
102127
return representation switch
@@ -111,6 +136,29 @@ public static bool IsRepresentedAsDocument(IBsonSerializer serializer)
111136
return SerializationHelper.GetRepresentation(serializer) == BsonType.Document;
112137
}
113138

139+
public static bool IsRepresentedAsInteger(IBsonSerializer serializer)
140+
{
141+
var representation = GetRepresentation(serializer);
142+
return IsIntegerRepresentation(representation);
143+
}
144+
145+
public static bool IsRepresentedAsIntegerOrNullableInteger(AggregationExpression translation)
146+
{
147+
return IsRepresentedAsIntegerOrNullableInteger(translation.Serializer);
148+
}
149+
150+
public static bool IsRepresentedAsIntegerOrNullableInteger(IBsonSerializer serializer)
151+
{
152+
if (serializer is INullableSerializer nullableSerializer)
153+
{
154+
return IsRepresentedAsInteger(nullableSerializer.ValueSerializer);
155+
}
156+
else
157+
{
158+
return IsRepresentedAsInteger(serializer);
159+
}
160+
}
161+
114162
public static BsonValue SerializeValue(IBsonSerializer serializer, ConstantExpression constantExpression, Expression containingExpression)
115163
{
116164
var value = constantExpression.Value;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs

+18-15
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
using System;
1717
using System.Linq.Expressions;
18-
using MongoDB.Bson;
1918
using MongoDB.Bson.Serialization;
2019
using MongoDB.Bson.Serialization.Serializers;
2120
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
@@ -82,6 +81,12 @@ public static AggregationExpression Translate(TranslationContext context, Binary
8281
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
8382
}
8483

84+
if (IsArithmeticExpression(expression))
85+
{
86+
SerializationHelper.EnsureRepresentationIsNumeric(leftExpression, leftTranslation);
87+
SerializationHelper.EnsureRepresentationIsNumeric(rightExpression, rightTranslation);
88+
}
89+
8590
var ast = expression.NodeType switch
8691
{
8792
ExpressionType.Add => AstExpression.Add(leftTranslation.Ast, rightTranslation.Ast),
@@ -184,7 +189,7 @@ private static bool IsAddOrSubtractExpression(Expression expression)
184189

185190
private static bool IsArithmeticExpression(BinaryExpression expression)
186191
{
187-
return expression.Type.IsNumeric() && IsArithmeticOperator(expression.NodeType);
192+
return expression.Type.IsNumericOrNullableNumeric() && IsArithmeticOperator(expression.NodeType);
188193
}
189194

190195
private static bool IsArithmeticOperator(ExpressionType nodeType)
@@ -304,31 +309,29 @@ private static AggregationExpression TranslateEnumExpression(TranslationContext
304309
leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
305310
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
306311

312+
AggregationExpression enumTranslation, operandTranslation;
307313
if (IsEnumOrConvertEnumToUnderlyingType(leftExpression))
308314
{
309-
serializer = leftTranslation.Serializer;
315+
enumTranslation = leftTranslation;
316+
operandTranslation = rightTranslation;
310317
}
311318
else
312319
{
313-
serializer = rightTranslation.Serializer;
320+
enumTranslation = rightTranslation;
321+
operandTranslation = leftTranslation;
314322
}
315323

316-
var representation = BsonType.Int32; // assume an integer representation unless we can determine otherwise
317-
var valueSerializer = serializer;
318-
if (valueSerializer is INullableSerializer nullableSerializer)
324+
if (!SerializationHelper.IsRepresentedAsIntegerOrNullableInteger(enumTranslation))
319325
{
320-
valueSerializer = nullableSerializer.ValueSerializer;
321-
}
322-
if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer &&
323-
enumUnderlyingTypeSerializer.EnumSerializer is IHasRepresentationSerializer withRepresentationSerializer)
324-
{
325-
representation = withRepresentationSerializer.Representation;
326+
throw new ExpressionNotSupportedException(expression, because: "arithmetic on enums is only allowed when the enum is represented as an integer");
326327
}
327328

328-
if (representation != BsonType.Int32 && representation != BsonType.Int64)
329+
if (!SerializationHelper.IsRepresentedAsIntegerOrNullableInteger(operandTranslation))
329330
{
330-
throw new ExpressionNotSupportedException(expression, because: "arithmetic on enums is only allowed when the enum is represented as an integer");
331+
throw new ExpressionNotSupportedException(expression, because: "the value being added to or subtracted from an enum must be represented as an integer");
331332
}
333+
334+
serializer = enumTranslation.Serializer;
332335
}
333336
else
334337
{

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AbsMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
4343
{
4444
var valueExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
4545
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
46+
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, valueTranslation);
4647
var ast = AstExpression.Abs(valueTranslation.Ast);
4748
return new AggregationExpression(expression, ast, valueTranslation.Serializer);
4849
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CeilingMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
3536
var ast = AstExpression.Ceil(argumentTranslation.Ast);
3637
var serializer = BsonSerializer.LookupSerializer(expression.Type);
3738
return new AggregationExpression(expression, ast, serializer);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,9 @@ public static AggregationExpression Translate(TranslationContext context, Method
144144
{
145145
throw new ExpressionNotSupportedException(valueExpression, expression);
146146
}
147-
var representation = timeSpanSerializer.Representation;
147+
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, timeSpanSerializer);
148+
148149
var serializerUnits = timeSpanSerializer.Units;
149-
if (representation != BsonType.Int32 && representation != BsonType.Int64 && representation != BsonType.Double)
150-
{
151-
throw new ExpressionNotSupportedException(valueExpression, expression);
152-
}
153150
(unit, amount) = serializerUnits switch
154151
{
155152
TimeSpanUnits.Ticks => ("millisecond", AstExpression.Divide(valueTranslation.Ast, (double)TimeSpan.TicksPerMillisecond)),
@@ -174,6 +171,8 @@ public static AggregationExpression Translate(TranslationContext context, Method
174171
else
175172
{
176173
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
174+
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, valueTranslation);
175+
177176
(unit, amount) = method.Name switch
178177
{
179178
"AddTicks" => ("millisecond", AstExpression.Divide(valueTranslation.Ast, (double)TimeSpan.TicksPerMillisecond)),

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ExpMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
3536
var ast = AstExpression.Exp(argumentTranslation.Ast);
3637
return new AggregationExpression(expression, ast, new DoubleSerializer());
3738
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FloorMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
3536
var ast = AstExpression.Floor(argumentTranslation.Ast);
3637
var serializer = BsonSerializer.LookupSerializer(expression.Type);
3738
return new AggregationExpression(expression, ast, serializer);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfAnyMethodToAggregationExpressionTranslator.cs

+2
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ string TranslateAnyOf(ReadOnlyCollection<Expression> arguments)
115115

116116
var startIndexExpression = arguments[1];
117117
var startIndexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
118+
SerializationHelper.EnsureRepresentationIsNumeric(startIndexExpression, startIndexTranslation);
118119
return AstExpression.UseVarIfNotSimple("startIndex", startIndexTranslation.Ast);
119120
}
120121

@@ -127,6 +128,7 @@ string TranslateAnyOf(ReadOnlyCollection<Expression> arguments)
127128

128129
var countExpression = arguments[2];
129130
var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
131+
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);
130132
return AstExpression.UseVarIfNotSimple("count", countTranslation.Ast);
131133
}
132134

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/IndexOfMethodToAggregationExpressionTranslator.cs

+12-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,18 @@ public static AggregationExpression Translate(TranslationContext context, Method
8383
{
8484
var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
8585
var valueTranslation = TranslateValue();
86-
var startIndexTranslation = startIndexExpression == null ? null : ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
87-
var countTranslation = countExpression == null ? null : ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
86+
AggregationExpression startIndexTranslation = null;
87+
if (startIndexExpression != null)
88+
{
89+
startIndexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
90+
SerializationHelper.EnsureRepresentationIsNumeric(startIndexExpression, startIndexTranslation);
91+
}
92+
AggregationExpression countTranslation = null;
93+
if (countExpression != null)
94+
{
95+
countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
96+
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);
97+
}
8898
var ordinal = GetOrdinalFromComparisonType();
8999

90100
var endAst = CreateEndAst(startIndexTranslation?.Ast, countTranslation?.Ast);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/LogMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
3536
AstExpression ast;
3637
if (method.Is(MathMethod.LogWithNewBase))
3738
{

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/PowMethodToAggregationExpressionTranslator.cs

+2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var xExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var xTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, xExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(xExpression, xTranslation);
3536
var yExpression = ConvertHelper.RemoveWideningConvert(arguments[1]);
3637
var yTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, yExpression);
38+
SerializationHelper.EnsureRepresentationIsNumeric(yExpression, yTranslation);
3739
var ast = AstExpression.Pow(xTranslation.Ast, yTranslation.Ast);
3840
return new AggregationExpression(expression, ast, new DoubleSerializer());
3941
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RangeMethodToAggregationExpressionTranslator.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ public static AggregationExpression Translate(TranslationContext context, Method
3434
{
3535
var startExpression = arguments[0];
3636
var startTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startExpression);
37-
var (startVar, startAst) = AstExpression.UseVarIfNotSimple("start", startTranslation.Ast);
37+
SerializationHelper.EnsureRepresentationIsNumeric(startExpression, startTranslation);
3838
var countExpression = arguments[1];
3939
var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
40+
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);
41+
42+
var (startVar, startAst) = AstExpression.UseVarIfNotSimple("start", startTranslation.Ast);
4043
var (countVar, countAst) = AstExpression.UseVarIfNotSimple("count", countTranslation.Ast);
44+
4145
var ast = AstExpression.Let(
4246
startVar,
4347
countVar,

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/RoundMethodToAggregationExpressionTranslator.cs

+2
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ public static AggregationExpression Translate(TranslationContext context, Method
4646
{
4747
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
4848
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
49+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
4950

5051
AstExpression ast;
5152
if (method.IsOneOf(__roundWithPlaceMethods))
5253
{
5354
var placeExpression = arguments[1];
5455
var placeTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, placeExpression);
56+
SerializationHelper.EnsureRepresentationIsNumeric(placeExpression, placeTranslation);
5557
ast = AstExpression.Round(argumentTranslation.Ast, placeTranslation.Ast);
5658
}
5759
else

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SqrtMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
3232
{
3333
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
3434
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
35+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
3536
var ast = AstExpression.Sqrt(argumentTranslation.Ast);
3637
return new AggregationExpression(expression, ast, new DoubleSerializer());
3738
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SubstringMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ private static AggregationExpression TranslateHelper(TranslationContext context,
5353
{
5454
var stringTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, stringExpression);
5555
var startIndexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
56+
SerializationHelper.EnsureRepresentationIsNumeric(startIndexExpression, startIndexTranslation);
5657

5758
AstExpression ast;
5859
if (lengthExpression == null)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/TruncateMethodToAggregationExpressionTranslator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
7373
{
7474
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
7575
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
76+
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
7677
var ast = AstExpression.Trunc(argumentTranslation.Ast);
7778
return new AggregationExpression(expression, ast, argumentTranslation.Serializer);
7879
}

0 commit comments

Comments
 (0)