16
16
package org .springframework .data .jdbc .core .convert ;
17
17
18
18
import java .util .*;
19
+ import java .util .function .BiFunction ;
19
20
import java .util .function .Function ;
20
21
import java .util .function .Predicate ;
21
22
import java .util .stream .Collectors ;
@@ -118,7 +119,7 @@ public class SqlGenerator {
118
119
119
120
/**
120
121
* Create a basic select structure with all the necessary joins
121
- *
122
+ *
122
123
* @param table the table to base the select on
123
124
* @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
124
125
* {@literal true} will be skipped when determining columns to select.
@@ -188,6 +189,8 @@ private Condition getSubselectCondition(AggregatePath path,
188
189
Table subSelectTable = Table .create (parentPathTableInfo .qualifiedTableName ());
189
190
190
191
Map <AggregatePath , Column > selectFilterColumns = new TreeMap <>();
192
+
193
+ // TODO: cannot we simply pass on the columnInfos?
191
194
parentPathTableInfo .effectiveIdColumnInfos ().forEach ( //
192
195
(ap , ci ) -> //
193
196
selectFilterColumns .put (ap , subSelectTable .column (ci .name ())) //
@@ -471,6 +474,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
471
474
* @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
472
475
*/
473
476
String createDeleteByPath (PersistentPropertyPath <RelationalPersistentProperty > path ) {
477
+ // TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
478
+ // See SqlGeneratorEmbeddedUnitTests.deleteByPath
474
479
return createDeleteByPathAndCriteria (mappingContext .getAggregatePath (path ), this ::equalityCondition );
475
480
}
476
481
@@ -490,12 +495,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
490
495
*/
491
496
private Condition inCondition (Map <AggregatePath , Column > columnMap ) {
492
497
493
- List <Column > columns = List . copyOf ( columnMap .values () );
498
+ Collection <Column > columns = columnMap .values ();
494
499
495
- if (columns .size () == 1 ) {
496
- return Conditions .in (columns .get (0 ), getBindMarker (IDS_SQL_PARAMETER ));
497
- }
498
- return Conditions .in (TupleExpression .create (columns ), getBindMarker (IDS_SQL_PARAMETER ));
500
+ return Conditions .in (columns .size () == 1 ? columns .iterator ().next () : TupleExpression .create (columns ),
501
+ getBindMarker (IDS_SQL_PARAMETER ));
499
502
}
500
503
501
504
/**
@@ -504,44 +507,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
504
507
*/
505
508
private Condition equalityCondition (Map <AggregatePath , Column > columnMap ) {
506
509
507
- AggregatePath . ColumnInfos idColumnInfos = mappingContext . getAggregatePath ( entity ). getTableInfo (). idColumnInfos ( );
510
+ Assert . isTrue (! columnMap . isEmpty (), "Column map must not be empty" );
508
511
509
- Condition result = null ;
510
- for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
511
- BindMarker bindMarker = getBindMarker (idColumnInfos .get (entry .getKey ()).name ());
512
- Comparison singleCondition = entry .getValue ().isEqualTo (bindMarker );
512
+ AggregatePath .ColumnInfos idColumnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
513
513
514
- result = result == null ? singleCondition : result .and (singleCondition );
515
- }
516
- Assert .state (result != null , "We need at least one condition" );
517
- return result ;
514
+ return createPredicate (columnMap , (aggregatePath , column ) -> {
515
+ return column .isEqualTo (getBindMarker (idColumnInfos .get (aggregatePath ).name ()));
516
+ });
518
517
}
519
518
520
519
/**
521
520
* Constructs a function for constructing where a condition. The where condition will be of the form
522
521
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
523
522
*/
524
523
private Condition isNotNullCondition (Map <AggregatePath , Column > columnMap ) {
524
+ return createPredicate (columnMap , (aggregatePath , column ) -> column .isNotNull ());
525
+ }
526
+
527
+ /**
528
+ * Constructs a function for constructing where a condition. The where condition will be of the form
529
+ * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
530
+ */
531
+ private static Condition createPredicate (Map <AggregatePath , Column > columnMap ,
532
+ BiFunction <AggregatePath , Column , Condition > conditionFunction ) {
525
533
526
534
Condition result = null ;
527
- for (Column column : columnMap .values ()) {
528
- Condition singleCondition = column .isNotNull ();
535
+ for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
529
536
537
+ Condition singleCondition = conditionFunction .apply (entry .getKey (), entry .getValue ());
530
538
result = result == null ? singleCondition : result .and (singleCondition );
531
539
}
532
540
Assert .state (result != null , "We need at least one condition" );
533
541
return result ;
534
542
}
535
543
536
544
private String createFindOneSql () {
537
-
538
545
return render (selectBuilder ().where (equalityIdWhereCondition ()).build ());
539
546
}
540
547
541
548
private Condition equalityIdWhereCondition () {
549
+ return equalityIdWhereCondition (getIdColumns ());
550
+ }
551
+
552
+ private Condition equalityIdWhereCondition (Iterable <Column > columns ) {
553
+
554
+ Assert .isTrue (columns .iterator ().hasNext (), "Identifier columns must not be empty" );
542
555
543
556
Condition aggregate = null ;
544
- for (Column column : getIdColumns () ) {
557
+ for (Column column : columns ) {
545
558
546
559
Comparison condition = column .isEqualTo (getBindMarker (column .getName ()));
547
560
aggregate = aggregate == null ? condition : aggregate .and (condition );
@@ -766,19 +779,13 @@ Join getJoin(AggregatePath path) {
766
779
Table parentTable = sqlContext .getTable (idDefiningParentPath );
767
780
AggregatePath .ColumnInfos idColumnInfos = idDefiningParentPath .getTableInfo ().idColumnInfos ();
768
781
769
- final Condition [] joinCondition = { null };
770
- backRefColumnInfos .forEach ((ap , ci ) -> {
771
-
772
- Condition elementalCondition = currentTable .column (ci .name ())
773
- .isEqualTo (parentTable .column (idColumnInfos .get (ap ).name ()));
774
- joinCondition [0 ] = joinCondition [0 ] == null ? elementalCondition : joinCondition [0 ].and (elementalCondition );
775
- });
782
+ Condition joinCondition = backRefColumnInfos .reduce (Conditions .unrestricted (), (aggregatePath , columnInfo ) -> {
776
783
777
- return new Join ( //
778
- currentTable , //
779
- joinCondition [0 ] //
780
- );
784
+ return currentTable .column (columnInfo .name ())
785
+ .isEqualTo (parentTable .column (idColumnInfos .get (aggregatePath ).name ()));
786
+ }, Condition ::and );
781
787
788
+ return new Join (currentTable , joinCondition );
782
789
}
783
790
784
791
private String createFindAllInListSql () {
@@ -917,6 +924,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
917
924
918
925
Map <AggregatePath , Column > columns = new TreeMap <>();
919
926
AggregatePath .ColumnInfos columnInfos = path .getTableInfo ().backReferenceColumnInfos ();
927
+
928
+ // TODO: cannot we simply pass on the columnInfos?
920
929
columnInfos .forEach ((ag , ci ) -> columns .put (ag , table .column (ci .name ())));
921
930
922
931
if (isFirstNonRoot (path )) {
@@ -970,17 +979,20 @@ private Table getTable() {
970
979
*/
971
980
private Column getSingleNonNullColumn () {
972
981
982
+ // getColumn() is slightly different from the code in any(…). Why?
983
+ // AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
984
+ // return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
985
+
973
986
AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
974
987
return columnInfos .any ((ap , ci ) -> sqlContext .getTable (columnInfos .fullPath (ap )).column (ci .name ()).as (ci .alias ()));
975
988
}
976
989
977
990
private List <Column > getIdColumns () {
978
991
979
992
AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
980
- List <Column > result = new ArrayList <>(columnInfos .size ());
981
- columnInfos .forEach ((ap , ci ) -> result .add (sqlContext .getColumn (columnInfos .fullPath (ap ))));
982
993
983
- return result ;
994
+ return columnInfos
995
+ .toColumnList ((aggregatePath , columnInfo ) -> sqlContext .getColumn (columnInfos .fullPath (aggregatePath )));
984
996
}
985
997
986
998
private Column getVersionColumn () {
0 commit comments