Skip to content

Commit af4f476

Browse files
committed
Polishing.
1 parent 60f4e57 commit af4f476

File tree

19 files changed

+456
-328
lines changed

19 files changed

+456
-328
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.jdbc.core.convert;
1717

1818
import java.util.*;
19+
import java.util.function.BiFunction;
1920
import java.util.function.Function;
2021
import java.util.function.Predicate;
2122
import java.util.stream.Collectors;
@@ -118,7 +119,7 @@ public class SqlGenerator {
118119

119120
/**
120121
* Create a basic select structure with all the necessary joins
121-
*
122+
*
122123
* @param table the table to base the select on
123124
* @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
124125
* {@literal true} will be skipped when determining columns to select.
@@ -188,6 +189,8 @@ private Condition getSubselectCondition(AggregatePath path,
188189
Table subSelectTable = Table.create(parentPathTableInfo.qualifiedTableName());
189190

190191
Map<AggregatePath, Column> selectFilterColumns = new TreeMap<>();
192+
193+
// TODO: cannot we simply pass on the columnInfos?
191194
parentPathTableInfo.effectiveIdColumnInfos().forEach( //
192195
(ap, ci) -> //
193196
selectFilterColumns.put(ap, subSelectTable.column(ci.name())) //
@@ -471,6 +474,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
471474
* @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
472475
*/
473476
String createDeleteByPath(PersistentPropertyPath<RelationalPersistentProperty> path) {
477+
// TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
478+
// See SqlGeneratorEmbeddedUnitTests.deleteByPath
474479
return createDeleteByPathAndCriteria(mappingContext.getAggregatePath(path), this::equalityCondition);
475480
}
476481

@@ -490,12 +495,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
490495
*/
491496
private Condition inCondition(Map<AggregatePath, Column> columnMap) {
492497

493-
List<Column> columns = List.copyOf(columnMap.values());
498+
Collection<Column> columns = columnMap.values();
494499

495-
if (columns.size() == 1) {
496-
return Conditions.in(columns.get(0), getBindMarker(IDS_SQL_PARAMETER));
497-
}
498-
return Conditions.in(TupleExpression.create(columns), getBindMarker(IDS_SQL_PARAMETER));
500+
return Conditions.in(columns.size() == 1 ? columns.iterator().next() : TupleExpression.create(columns),
501+
getBindMarker(IDS_SQL_PARAMETER));
499502
}
500503

501504
/**
@@ -504,44 +507,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
504507
*/
505508
private Condition equalityCondition(Map<AggregatePath, Column> columnMap) {
506509

507-
AggregatePath.ColumnInfos idColumnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
510+
Assert.isTrue(!columnMap.isEmpty(), "Column map must not be empty");
508511

509-
Condition result = null;
510-
for (Map.Entry<AggregatePath, Column> entry : columnMap.entrySet()) {
511-
BindMarker bindMarker = getBindMarker(idColumnInfos.get(entry.getKey()).name());
512-
Comparison singleCondition = entry.getValue().isEqualTo(bindMarker);
512+
AggregatePath.ColumnInfos idColumnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
513513

514-
result = result == null ? singleCondition : result.and(singleCondition);
515-
}
516-
Assert.state(result != null, "We need at least one condition");
517-
return result;
514+
return createPredicate(columnMap, (aggregatePath, column) -> {
515+
return column.isEqualTo(getBindMarker(idColumnInfos.get(aggregatePath).name()));
516+
});
518517
}
519518

520519
/**
521520
* Constructs a function for constructing where a condition. The where condition will be of the form
522521
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
523522
*/
524523
private Condition isNotNullCondition(Map<AggregatePath, Column> columnMap) {
524+
return createPredicate(columnMap, (aggregatePath, column) -> column.isNotNull());
525+
}
526+
527+
/**
528+
* Constructs a function for constructing where a condition. The where condition will be of the form
529+
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
530+
*/
531+
private static Condition createPredicate(Map<AggregatePath, Column> columnMap,
532+
BiFunction<AggregatePath, Column, Condition> conditionFunction) {
525533

526534
Condition result = null;
527-
for (Column column : columnMap.values()) {
528-
Condition singleCondition = column.isNotNull();
535+
for (Map.Entry<AggregatePath, Column> entry : columnMap.entrySet()) {
529536

537+
Condition singleCondition = conditionFunction.apply(entry.getKey(), entry.getValue());
530538
result = result == null ? singleCondition : result.and(singleCondition);
531539
}
532540
Assert.state(result != null, "We need at least one condition");
533541
return result;
534542
}
535543

536544
private String createFindOneSql() {
537-
538545
return render(selectBuilder().where(equalityIdWhereCondition()).build());
539546
}
540547

541548
private Condition equalityIdWhereCondition() {
549+
return equalityIdWhereCondition(getIdColumns());
550+
}
551+
552+
private Condition equalityIdWhereCondition(Iterable<Column> columns) {
553+
554+
Assert.isTrue(columns.iterator().hasNext(), "Identifier columns must not be empty");
542555

543556
Condition aggregate = null;
544-
for (Column column : getIdColumns()) {
557+
for (Column column : columns) {
545558

546559
Comparison condition = column.isEqualTo(getBindMarker(column.getName()));
547560
aggregate = aggregate == null ? condition : aggregate.and(condition);
@@ -766,19 +779,13 @@ Join getJoin(AggregatePath path) {
766779
Table parentTable = sqlContext.getTable(idDefiningParentPath);
767780
AggregatePath.ColumnInfos idColumnInfos = idDefiningParentPath.getTableInfo().idColumnInfos();
768781

769-
final Condition[] joinCondition = { null };
770-
backRefColumnInfos.forEach((ap, ci) -> {
771-
772-
Condition elementalCondition = currentTable.column(ci.name())
773-
.isEqualTo(parentTable.column(idColumnInfos.get(ap).name()));
774-
joinCondition[0] = joinCondition[0] == null ? elementalCondition : joinCondition[0].and(elementalCondition);
775-
});
782+
Condition joinCondition = backRefColumnInfos.reduce(Conditions.unrestricted(), (aggregatePath, columnInfo) -> {
776783

777-
return new Join( //
778-
currentTable, //
779-
joinCondition[0] //
780-
);
784+
return currentTable.column(columnInfo.name())
785+
.isEqualTo(parentTable.column(idColumnInfos.get(aggregatePath).name()));
786+
}, Condition::and);
781787

788+
return new Join(currentTable, joinCondition);
782789
}
783790

784791
private String createFindAllInListSql() {
@@ -917,6 +924,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
917924

918925
Map<AggregatePath, Column> columns = new TreeMap<>();
919926
AggregatePath.ColumnInfos columnInfos = path.getTableInfo().backReferenceColumnInfos();
927+
928+
// TODO: cannot we simply pass on the columnInfos?
920929
columnInfos.forEach((ag, ci) -> columns.put(ag, table.column(ci.name())));
921930

922931
if (isFirstNonRoot(path)) {
@@ -970,17 +979,20 @@ private Table getTable() {
970979
*/
971980
private Column getSingleNonNullColumn() {
972981

982+
// getColumn() is slightly different from the code in any(…). Why?
983+
// AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
984+
// return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
985+
973986
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
974987
return columnInfos.any((ap, ci) -> sqlContext.getTable(columnInfos.fullPath(ap)).column(ci.name()).as(ci.alias()));
975988
}
976989

977990
private List<Column> getIdColumns() {
978991

979992
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
980-
List<Column> result = new ArrayList<>(columnInfos.size());
981-
columnInfos.forEach((ap, ci) -> result.add(sqlContext.getColumn(columnInfos.fullPath(ap))));
982993

983-
return result;
994+
return columnInfos
995+
.toColumnList((aggregatePath, columnInfo) -> sqlContext.getColumn(columnInfos.fullPath(aggregatePath)));
984996
}
985997

986998
private Column getVersionColumn() {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import java.sql.SQLType;
1919
import java.util.ArrayList;
20+
import java.util.Collection;
2021
import java.util.List;
2122
import java.util.Map;
2223
import java.util.function.BiFunction;
23-
import java.util.function.Function;
2424
import java.util.function.Predicate;
2525

2626
import org.springframework.data.jdbc.core.mapping.JdbcValue;
@@ -34,9 +34,7 @@
3434
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
3535
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
3636
import org.springframework.data.relational.core.sql.SqlIdentifier;
37-
import org.springframework.jdbc.support.JdbcUtils;
3837
import org.springframework.lang.Nullable;
39-
import org.springframework.util.Assert;
4038

4139
/**
4240
* Creates the {@link SqlIdentifierParameterSource} for various SQL operations, dialect identifier processing rules and
@@ -45,9 +43,11 @@
4543
* @author Jens Schauder
4644
* @author Chirag Tailor
4745
* @author Mikhail Polivakha
46+
* @author Mark Paluch
4847
* @since 2.4
4948
*/
5049
public class SqlParametersFactory {
50+
5151
private final RelationalMappingContext context;
5252
private final JdbcConverter converter;
5353

@@ -119,24 +119,20 @@ <T> SqlIdentifierParameterSource forUpdate(T instance, Class<T> domainType) {
119119
*/
120120
<T> SqlIdentifierParameterSource forQueryById(Object id, Class<T> domainType) {
121121

122-
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
123-
124-
RelationalPersistentEntity<T> entity = getRequiredPersistentEntity(domainType);
125-
RelationalPersistentProperty singleIdProperty = entity.getRequiredIdProperty();
122+
return doWithIdentifiers(domainType, (columns, idProperty, complexId) -> {
126123

127-
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(singleIdProperty);
124+
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
125+
BiFunction<Object, AggregatePath, Object> valueExtractor = getIdMapper(complexId);
128126

129-
Function<AggregatePath, Object> valueExtractor = complexId == null ? ap -> id
130-
: ap -> complexId.getPropertyPathAccessor(id).getProperty(ap.getRequiredPersistentPropertyPath());
127+
columns.forEach((ap, ci) -> addConvertedPropertyValue( //
128+
parameterSource, //
129+
ap.getRequiredLeafProperty(), //
130+
valueExtractor.apply(id, ap), //
131+
ci.name() //
132+
));
131133

132-
context.getAggregatePath(entity).getTableInfo().idColumnInfos() //
133-
.forEach((ap, ci) -> addConvertedPropertyValue( //
134-
parameterSource, //
135-
ap.getRequiredLeafProperty(), //
136-
valueExtractor.apply(ap), //
137-
ci.name() //
138-
));
139-
return parameterSource;
134+
return parameterSource;
135+
});
140136
}
141137

142138
/**
@@ -149,29 +145,44 @@ <T> SqlIdentifierParameterSource forQueryById(Object id, Class<T> domainType) {
149145
*/
150146
<T> SqlIdentifierParameterSource forQueryByIds(Iterable<?> ids, Class<T> domainType) {
151147

152-
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
148+
return doWithIdentifiers(domainType, (columns, idProperty, complexId) -> {
153149

154-
RelationalPersistentEntity<?> entity = context.getRequiredPersistentEntity(domainType);
155-
RelationalPersistentProperty singleIdProperty = entity.getRequiredIdProperty();
156-
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(singleIdProperty);
157-
AggregatePath.ColumnInfos idColumnInfos = context.getAggregatePath(entity).getTableInfo().idColumnInfos();
150+
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
158151

159-
BiFunction<Object, AggregatePath, Object> valueExtractor = complexId == null ? (id, ap) -> id
160-
: (id, ap) -> complexId.getPropertyPathAccessor(id).getProperty(ap.getRequiredPersistentPropertyPath());
152+
BiFunction<Object, AggregatePath, Object> valueExtractor = getIdMapper(complexId);
161153

162-
List<Object[]> parameterValues = new ArrayList<>();
163-
for (Object id : ids) {
154+
List<Object[]> parameterValues = new ArrayList<>(ids instanceof Collection<?> c ? c.size() : 16);
155+
for (Object id : ids) {
164156

165-
List<Object> tupleList = new ArrayList<>();
166-
idColumnInfos.forEach((ap, ci) -> {
167-
tupleList.add(valueExtractor.apply(id, ap));
168-
});
169-
parameterValues.add(tupleList.toArray(new Object[0]));
170-
}
157+
Object[] tupleList = new Object[columns.size()];
171158

172-
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, parameterValues);
159+
int i = 0;
160+
for (AggregatePath path : columns.paths()) {
161+
tupleList[i++] = valueExtractor.apply(id, path);
162+
}
173163

174-
return parameterSource;
164+
parameterValues.add(tupleList);
165+
}
166+
167+
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, parameterValues);
168+
return parameterSource;
169+
});
170+
}
171+
172+
private <T> T doWithIdentifiers(Class<?> domainType, IdentifierCallback<T> callback) {
173+
174+
RelationalPersistentEntity<?> entity = context.getRequiredPersistentEntity(domainType);
175+
RelationalPersistentProperty idProperty = entity.getRequiredIdProperty();
176+
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(idProperty);
177+
AggregatePath.ColumnInfos columns = context.getAggregatePath(entity).getTableInfo().idColumnInfos();
178+
179+
return callback.doWithIdentifiers(columns, idProperty, complexId);
180+
}
181+
182+
interface IdentifierCallback<T> {
183+
184+
T doWithIdentifiers(AggregatePath.ColumnInfos columns, RelationalPersistentProperty idProperty,
185+
RelationalPersistentEntity<?> complexId);
175186
}
176187

177188
/**
@@ -191,6 +202,16 @@ SqlIdentifierParameterSource forQueryByIdentifier(Identifier identifier) {
191202
return parameterSource;
192203
}
193204

205+
private BiFunction<Object, AggregatePath, Object> getIdMapper(@Nullable RelationalPersistentEntity<?> complexId) {
206+
207+
if (complexId == null) {
208+
return (id, aggregatePath) -> id;
209+
}
210+
211+
return (id, aggregatePath) -> complexId.getPropertyPathAccessor(id)
212+
.getProperty(aggregatePath.getRequiredPersistentPropertyPath());
213+
}
214+
194215
private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSource,
195216
RelationalPersistentProperty property, @Nullable Object value, SqlIdentifier name) {
196217

@@ -219,28 +240,6 @@ private void addConvertedValue(SqlIdentifierParameterSource parameterSource, @Nu
219240
jdbcValue.getJdbcType().getVendorTypeNumber());
220241
}
221242

222-
private void addConvertedPropertyValuesAsList(SqlIdentifierParameterSource parameterSource,
223-
RelationalPersistentProperty property, Iterable<?> values) {
224-
225-
List<Object> convertedIds = new ArrayList<>();
226-
JdbcValue jdbcValue = null;
227-
for (Object id : values) {
228-
229-
Class<?> columnType = converter.getColumnType(property);
230-
SQLType sqlType = converter.getTargetSqlType(property);
231-
232-
jdbcValue = converter.writeJdbcValue(id, columnType, sqlType);
233-
convertedIds.add(jdbcValue.getValue());
234-
}
235-
236-
Assert.state(jdbcValue != null, "JdbcValue must be not null at this point; Please report this as a bug");
237-
238-
SQLType jdbcType = jdbcValue.getJdbcType();
239-
int typeNumber = jdbcType == null ? JdbcUtils.TYPE_UNKNOWN : jdbcType.getVendorTypeNumber();
240-
241-
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, convertedIds, typeNumber);
242-
}
243-
244243
@SuppressWarnings("unchecked")
245244
private <S> RelationalPersistentEntity<S> getRequiredPersistentEntity(Class<S> domainType) {
246245
return (RelationalPersistentEntity<S>) context.getRequiredPersistentEntity(domainType);

0 commit comments

Comments
 (0)