Skip to content

Commit cfbfa07

Browse files
committed
Translate nested CASE to simpler COALESCE
1 parent c1cf255 commit cfbfa07

File tree

7 files changed

+284
-1
lines changed

7 files changed

+284
-1
lines changed

src/EntityFramework5.Npgsql/EntityFramework5.Npgsql.csproj

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
<Compile Include="..\EntityFramework6.Npgsql\Properties\AssemblyInfo.cs" />
6464
<Compile Include="..\EntityFramework6.Npgsql\Spatial\PostgisDataReader.cs" />
6565
<Compile Include="..\EntityFramework6.Npgsql\Spatial\PostgisServices.cs" />
66+
<Compile Include="..\EntityFramework6.Npgsql\SqlGenerators\CaseIsNullToCoalesceReducer.cs" />
67+
<Compile Include="..\EntityFramework6.Npgsql\SqlGenerators\DbExpressionDeepEqual.cs" />
6668
<Compile Include="..\EntityFramework6.Npgsql\SqlGenerators\PendingProjectsNode.cs" />
6769
<Compile Include="..\EntityFramework6.Npgsql\SqlGenerators\SqlBaseGenerator.cs" />
6870
<Compile Include="..\EntityFramework6.Npgsql\SqlGenerators\SqlDeleteGenerator.cs" />

src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj

+2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@
7676
<Compile Include="Properties\AssemblyInfo.cs" />
7777
<Compile Include="Spatial\PostgisDataReader.cs" />
7878
<Compile Include="Spatial\PostgisServices.cs" />
79+
<Compile Include="SqlGenerators\CaseIsNullToCoalesceReducer.cs" />
7980
<Compile Include="SqlGenerators\PendingProjectsNode.cs" />
81+
<Compile Include="SqlGenerators\DbExpressionDeepEqual.cs" />
8082
<Compile Include="SqlGenerators\SqlBaseGenerator.cs" />
8183
<Compile Include="SqlGenerators\SqlDeleteGenerator.cs" />
8284
<Compile Include="SqlGenerators\SqlInsertGenerator.cs" />

src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ public override ReadOnlyCollection<EdmFunction> GetStoreFunctions()
359359
.ToList()
360360
.AsReadOnly();
361361

362-
static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo)
362+
internal static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo)
363363
{
364364
if (method == null)
365365
throw new ArgumentNullException(nameof(method));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
using System.Collections.Generic;
2+
using System.Data.Entity.Core.Metadata.Edm;
3+
using System.Linq;
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Data.Common;
7+
using System.Data.Entity.Core.Common.CommandTrees.ExpressionBuilder;
8+
using System.Diagnostics;
9+
#if ENTITIES6
10+
using System.Globalization;
11+
using System.Data.Entity.Core.Common.CommandTrees;
12+
using System.Data.Entity.Core.Metadata.Edm;
13+
#else
14+
using System.Data.Common.CommandTrees;
15+
using System.Data.Metadata.Edm;
16+
#endif
17+
using JetBrains.Annotations;
18+
19+
20+
namespace Npgsql.SqlGenerators
21+
{
22+
public class CaseIsNullToCoalesceReducer
23+
{
24+
public static DbFunctionExpression InvokeCoalesceExpression(params DbExpression[] argumentExpressions)
25+
{
26+
var fromClrType = PrimitiveType
27+
.GetEdmPrimitiveTypes()
28+
.FirstOrDefault(t => t.ClrEquivalentType == typeof(string));
29+
30+
int i=0;
31+
var func = EdmFunction.Create(
32+
"coalesce",
33+
"Npgsql",
34+
DataSpace.SSpace,
35+
new EdmFunctionPayload
36+
{
37+
ParameterTypeSemantics = ParameterTypeSemantics.AllowImplicitConversion,
38+
Schema = string.Empty,
39+
IsBuiltIn = true,
40+
IsAggregate = false,
41+
IsFromProviderManifest = true,
42+
StoreFunctionName = "coalesce",
43+
IsComposable = true,
44+
ReturnParameters = new[]
45+
{
46+
FunctionParameter.Create("ReturnType", fromClrType,ParameterMode.ReturnValue)
47+
},
48+
Parameters = argumentExpressions.Select(
49+
x => FunctionParameter.Create(
50+
"p" + (i++).ToString(),fromClrType,ParameterMode.In)).ToList()
51+
},
52+
new List<MetadataProperty>());
53+
54+
return func.Invoke(argumentExpressions);
55+
}
56+
57+
public static DbFunctionExpression UnnestCoalesceInvocations(DbFunctionExpression dbFunctionExpression)
58+
{
59+
var args = new List<DbExpression>();
60+
foreach (var arg in dbFunctionExpression.Arguments)
61+
{
62+
if(arg is DbFunctionExpression funcCall
63+
&& funcCall.Function.NamespaceName=="Npgsql"
64+
&& funcCall.Function.Name=="coalesce")
65+
{
66+
args.AddRange(funcCall.Arguments);
67+
}
68+
else
69+
{
70+
args.Add(arg);
71+
}
72+
}
73+
return InvokeCoalesceExpression(args.ToArray());
74+
}
75+
76+
public static DbExpression TransformCoalesce(DbExpression expression)
77+
{
78+
if (expression is DbCaseExpression case2)
79+
{
80+
return TransformCoalesce(case2);
81+
}
82+
83+
if (expression is DbIsNullExpression nullExp)
84+
{
85+
return TransformCoalesce(nullExp.Argument).IsNull();
86+
}
87+
return expression;
88+
}
89+
90+
public static DbExpression TransformCoalesce(DbCaseExpression expression)
91+
{
92+
expression = DbExpressionBuilder.Case(
93+
expression.When.Select(TransformCoalesce),
94+
expression.Then.Select(TransformCoalesce),
95+
expression.Else);
96+
97+
var lastWhen = expression.When.Count-1;
98+
if (expression.When[lastWhen].ExpressionKind == DbExpressionKind.IsNull)
99+
{
100+
var is_null = expression.When[lastWhen] as DbIsNullExpression;
101+
if (DbExpressionDeepEqual.DeepEqual(is_null.Argument,expression.Else))
102+
{
103+
var coalesceInvocation = InvokeCoalesceExpression(is_null.Argument, expression.Then[lastWhen]);
104+
coalesceInvocation = UnnestCoalesceInvocations(coalesceInvocation);
105+
106+
if (expression.When.Count == 1)
107+
{
108+
return coalesceInvocation;
109+
}
110+
111+
var simplifiendCase = DbExpressionBuilder.Case(
112+
expression.When.Take(lastWhen),
113+
expression.Then.Take(lastWhen),
114+
coalesceInvocation);
115+
116+
return TransformCoalesce(simplifiendCase);
117+
}
118+
return expression;
119+
}
120+
return expression;
121+
}
122+
}
123+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using System.Data.Entity.Core.Common.CommandTrees;
2+
using System.Data.Entity.Core.Metadata.Edm;
3+
using System.Linq;
4+
5+
namespace Npgsql.SqlGenerators
6+
{
7+
public class DbExpressionDeepEqual
8+
{
9+
public static bool DeepEqual(DbExpression e1, DbExpression e2)
10+
{
11+
if (e1.Equals(e2)) return true;
12+
if (e1.GetType() != e2.GetType()) return false;
13+
if (!e1.ExpressionKind.Equals(e2.ExpressionKind)) return false;
14+
if (!DeepEqual(e1.ResultType,e2.ResultType)) return false;
15+
16+
if (e1 is DbFunctionExpression f1 && e2 is DbFunctionExpression f2)
17+
{
18+
return DeepEqual(f1,f2);
19+
}
20+
if (e1 is DbConstantExpression c1 && e2 is DbConstantExpression c2)
21+
{
22+
return c1.Value.Equals(c2.Value);
23+
}
24+
if (e1 is DbBinaryExpression b1 && e2 is DbBinaryExpression b2)
25+
{
26+
return DeepEqual(b1,b2);
27+
}
28+
if (e1 is DbUnaryExpression u1 && e2 is DbUnaryExpression u2)
29+
{
30+
return DeepEqual(u1,u2);
31+
}
32+
if (e1 is DbVariableReferenceExpression v1 && e2 is DbVariableReferenceExpression v2)
33+
{
34+
return DeepEqual(v1,v2);
35+
}
36+
37+
return false;
38+
}
39+
40+
static bool DeepEqual(TypeUsage r1, TypeUsage r2)
41+
{
42+
if (r1.EdmType != r2.EdmType) return false;
43+
return true;
44+
}
45+
46+
private static bool DeepEqual(DbFunctionExpression f1, DbFunctionExpression f2)
47+
{
48+
if (!f1.Function.Name.Equals(f2.Function.Name)) return false;
49+
if (!f1.Function.NamespaceName.Equals(f2.Function.NamespaceName)) return false;
50+
if (!f1.Arguments.Count.Equals(f2.Arguments.Count)) return false;
51+
52+
var argumenst_equals = f1.Arguments
53+
.Zip(f2.Arguments, (a, b) => DeepEqual(a, b))
54+
.All(areEquals => areEquals);
55+
56+
return argumenst_equals;
57+
}
58+
59+
private static bool DeepEqual(DbBinaryExpression b1, DbBinaryExpression b2)
60+
{
61+
if (!DeepEqual(b1.Left,b2.Left)) return false;
62+
if (!DeepEqual(b1.Right,b2.Right)) return false;
63+
64+
return true;
65+
}
66+
67+
private static bool DeepEqual(DbUnaryExpression u1, DbUnaryExpression u2)
68+
{
69+
return DeepEqual(u1.Argument,u2.Argument);
70+
}
71+
72+
private static bool DeepEqual(DbVariableReferenceExpression v1, DbVariableReferenceExpression v2)
73+
{
74+
return DeepEqual(v1.VariableName,v1.VariableName);
75+
}
76+
}
77+
}

src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs

+16
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,16 @@ protected string GetDbType(EdmType edmType)
829829

830830
public override VisitedExpression Visit([NotNull] DbCaseExpression expression)
831831
{
832+
var result = CaseIsNullToCoalesceReducer.TransformCoalesce(expression);
833+
if (result is DbCaseExpression case2)
834+
{
835+
expression = case2;
836+
}
837+
else
838+
{
839+
return result.Accept(this);
840+
}
841+
832842
var caseExpression = new LiteralExpression(" CASE ");
833843
for (var i = 0; i < expression.When.Count && i < expression.Then.Count; ++i)
834844
{
@@ -1191,6 +1201,12 @@ VisitedExpression VisitFunction(EdmFunction function, IList<DbExpression> args,
11911201
throw new NotSupportedException("cast type name argument must be a constant expression.");
11921202

11931203
return new CastExpression(args[0].Accept(this), typeNameExpression.Value.ToString());
1204+
}else if (functionName == "coalesce")
1205+
{
1206+
var coalesceFuncCall = new FunctionExpression("coalesce");
1207+
foreach (var a in args)
1208+
coalesceFuncCall.AddArgument(a.Accept(this));
1209+
return coalesceFuncCall;
11941210
}
11951211
}
11961212

test/EntityFramework6.Npgsql.Tests/EntityFrameworkBasicTests.cs

+63
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,69 @@ public void Test_issue_27_select_ef_generated_literals_from_inner_select()
735735
}
736736
}
737737

738+
[Test]
739+
public void Test_issue_60_and_62()
740+
{
741+
using (var context = new BloggingContext(ConnectionString))
742+
{
743+
context.Database.Log = Console.Out.WriteLine;
744+
745+
context.Blogs.Add( new Blog { Name = "Hello" });
746+
context.SaveChanges();
747+
748+
string string_value = "string_value";
749+
var query = context.Blogs.Select(b => string_value + "_postfijo").Take(1);
750+
var blog_title = query.First();
751+
Assert.That(blog_title, Is.EqualTo("string_value_postfijo"));
752+
StringAssert.DoesNotContain("case", query.ToString().ToLower() );
753+
}
754+
}
755+
756+
[Test]
757+
public void TestNullPropagation_1()
758+
{
759+
using (var context = new BloggingContext(ConnectionString))
760+
{
761+
context.Database.Log = Console.Out.WriteLine;
762+
763+
context.Blogs.Add( new Blog { Name = "Hello" });
764+
context.SaveChanges();
765+
766+
string valor_string = "string_value";
767+
var query = context.Blogs.Select(b => (valor_string ?? "otro_valor") + "_postfijo").Take(1);
768+
var blog_title = query.First();
769+
Assert.That(blog_title, Is.EqualTo("string_value_postfijo"));
770+
771+
var query_sql = query.ToString().ToLower();
772+
StringAssert.DoesNotContain("case", query.ToString().ToLower() );
773+
StringAssert.Contains("coalesce(@p__linq__0,e'otro_valor',e'')", query_sql);
774+
}
775+
}
776+
777+
[Test]
778+
public void TestNullPropagation_2()
779+
{
780+
using (var context = new BloggingContext(ConnectionString))
781+
{
782+
context.Database.Log = Console.Out.WriteLine;
783+
784+
context.Blogs.Add( new Blog { Name = "Hello" });
785+
context.SaveChanges();
786+
787+
string string_value1 = "string_value1";
788+
string string_value2 = "string_value2";
789+
string string_value3 = "string_value3";
790+
791+
var query = context.Blogs.Select(b => (string_value1 ?? string_value2 ?? string_value3) + "_postfijo").Take(1);
792+
var blog_title = query.First();
793+
Assert.That(blog_title, Is.EqualTo("string_value1_postfijo"));
794+
795+
var query_sql = query.ToString().ToLower();
796+
StringAssert.DoesNotContain("case", query_sql );
797+
StringAssert.Contains("coalesce(@p__linq__0,@p__linq__1,@p__linq__2,e'')", query_sql);
798+
}
799+
}
800+
738801
[Test]
739802
public void TestTableValuedStoredFunctions()
740803
{

0 commit comments

Comments
 (0)