34
34
import java .util .stream .Collectors ;
35
35
import org .apache .commons .lang3 .tuple .ImmutablePair ;
36
36
import org .apache .commons .lang3 .tuple .Pair ;
37
- import org .apache .commons .math3 .analysis .function .Exp ;
38
37
import org .opensearch .sql .DataSourceSchemaName ;
39
38
import org .opensearch .sql .analysis .symbol .Namespace ;
40
39
import org .opensearch .sql .analysis .symbol .Symbol ;
84
83
import org .opensearch .sql .expression .LiteralExpression ;
85
84
import org .opensearch .sql .expression .NamedExpression ;
86
85
import org .opensearch .sql .expression .ReferenceExpression ;
87
- import org .opensearch .sql .expression .aggregation .AggregationState ;
88
86
import org .opensearch .sql .expression .aggregation .Aggregator ;
89
- import org .opensearch .sql .expression .aggregation .AvgAggregator ;
90
87
import org .opensearch .sql .expression .aggregation .NamedAggregator ;
91
88
import org .opensearch .sql .expression .function .BuiltinFunctionName ;
92
89
import org .opensearch .sql .expression .function .BuiltinFunctionRepository ;
@@ -350,26 +347,42 @@ public LogicalPlan visitFieldSummary(FieldSummary node, AnalysisContext context)
350
347
TypeEnvironment env = context .peek ();
351
348
Map <String , ExprType > fieldsMap = env .lookupAllFields (Namespace .FIELD_NAME );
352
349
350
+ if (node .getIncludeFields () != null ) {
351
+ List <String > includeFields =
352
+ node .getIncludeFields ().stream ()
353
+ .map (expr -> ((Field ) expr ).getField ().toString ())
354
+ .toList ();
355
+
356
+ Map <String , ExprType > filteredFields = new HashMap <>();
357
+ for (String field : includeFields ) {
358
+ if (fieldsMap .containsKey (field )) {
359
+ filteredFields .put (field , fieldsMap .get (field ));
360
+ }
361
+ }
362
+ fieldsMap = filteredFields ;
363
+ }
364
+
353
365
ImmutableList .Builder <NamedAggregator > aggregatorBuilder = new ImmutableList .Builder <>();
354
- Map <String , String > aggregatorToFieldNameMap = new HashMap <String , String >();
366
+ Map <String , Map . Entry < String , ExprType >> aggregatorToFieldNameMap = new HashMap <>();
355
367
356
368
for (Map .Entry <String , ExprType > entry : fieldsMap .entrySet ()) {
357
369
ExprType fieldType = entry .getValue ();
358
370
String fieldName = entry .getKey ();
359
371
ReferenceExpression fieldExpression = DSL .ref (fieldName , fieldType );
360
372
361
373
aggregatorBuilder .add (new NamedAggregator ("Count" + fieldName , DSL .count (fieldExpression )));
362
- aggregatorToFieldNameMap .put ("Count" + fieldName , fieldName );
363
- aggregatorBuilder .add (new NamedAggregator ("Distinct" + fieldName , DSL .distinctCount (fieldExpression )));
364
- aggregatorToFieldNameMap .put ("Distinct" + fieldName , fieldName );
374
+ aggregatorToFieldNameMap .put ("Count" + fieldName , entry );
375
+ aggregatorBuilder .add (
376
+ new NamedAggregator ("Distinct" + fieldName , DSL .distinctCount (fieldExpression )));
377
+ aggregatorToFieldNameMap .put ("Distinct" + fieldName , entry );
365
378
366
379
if (ExprCoreType .numberTypes ().contains (fieldType )) {
367
- aggregatorBuilder .add (new NamedAggregator ("Avg" + fieldName , DSL .avg (fieldExpression )));
368
- aggregatorToFieldNameMap .put ("Avg" + fieldName , fieldName );
369
380
aggregatorBuilder .add (new NamedAggregator ("Max" + fieldName , DSL .max (fieldExpression )));
370
- aggregatorToFieldNameMap .put ("Max" + fieldName , fieldName );
381
+ aggregatorToFieldNameMap .put ("Max" + fieldName , entry );
371
382
aggregatorBuilder .add (new NamedAggregator ("Min" + fieldName , DSL .min (fieldExpression )));
372
- aggregatorToFieldNameMap .put ("Min" + fieldName , fieldName );
383
+ aggregatorToFieldNameMap .put ("Min" + fieldName , entry );
384
+ aggregatorBuilder .add (new NamedAggregator ("Avg" + fieldName , DSL .avg (fieldExpression )));
385
+ aggregatorToFieldNameMap .put ("Avg" + fieldName , entry );
373
386
}
374
387
}
375
388
@@ -386,6 +399,7 @@ public LogicalPlan visitFieldSummary(FieldSummary node, AnalysisContext context)
386
399
newEnv .define (new Symbol (Namespace .FIELD_NAME , "Avg" ), ExprCoreType .DOUBLE );
387
400
newEnv .define (new Symbol (Namespace .FIELD_NAME , "Max" ), ExprCoreType .DOUBLE );
388
401
newEnv .define (new Symbol (Namespace .FIELD_NAME , "Min" ), ExprCoreType .DOUBLE );
402
+ newEnv .define (new Symbol (Namespace .FIELD_NAME , "Type" ), ExprCoreType .STRING );
389
403
390
404
return new LogicalFieldSummary (child , aggregators , groupBys , aggregatorToFieldNameMap );
391
405
}
0 commit comments