Skip to content

Commit f16252d

Browse files
authored
[AINode] Support output time column for model inference
1 parent 24a93f5 commit f16252d

File tree

9 files changed

+82
-12
lines changed

9 files changed

+82
-12
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java

+8-7
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,22 @@ public void ModelOperationTest() {
178178

179179
@Test
180180
public void callInferenceTest() {
181-
String sql = "CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\")";
181+
String sql =
182+
"CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\", generateTime=true)";
182183
String sql2 = "CALL INFERENCE(identity, \"select s2,s0,s1 from root.AI.data\")";
183184
String sql3 =
184-
"CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\", predict_length=3)";
185+
"CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\", predict_length=3, generateTime=true)";
185186
try (Connection connection = EnvFactory.getEnv().getConnection();
186187
Statement statement = connection.createStatement()) {
187188

188189
try (ResultSet resultSet = statement.executeQuery(sql)) {
189190
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
190-
checkHeader(resultSetMetaData, "output0,output1,output2");
191+
checkHeader(resultSetMetaData, "Time,output0,output1,output2");
191192
int count = 0;
192193
while (resultSet.next()) {
193-
float s0 = resultSet.getFloat(1);
194-
float s1 = resultSet.getFloat(2);
195-
float s2 = resultSet.getFloat(3);
194+
float s0 = resultSet.getFloat(2);
195+
float s1 = resultSet.getFloat(3);
196+
float s2 = resultSet.getFloat(4);
196197

197198
assertEquals(s0, count + 1.0, 0.0001);
198199
assertEquals(s1, count + 2.0, 0.0001);
@@ -221,7 +222,7 @@ public void callInferenceTest() {
221222

222223
try (ResultSet resultSet = statement.executeQuery(sql3)) {
223224
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
224-
checkHeader(resultSetMetaData, "output0,output1,output2");
225+
checkHeader(resultSetMetaData, "Time,output0,output1,output2");
225226
int count = 0;
226227
while (resultSet.next()) {
227228
count++;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java

+41-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import com.google.common.util.concurrent.Futures;
3838
import com.google.common.util.concurrent.ListenableFuture;
39+
import org.apache.tsfile.block.column.Column;
3940
import org.apache.tsfile.block.column.ColumnBuilder;
4041
import org.apache.tsfile.enums.TSDataType;
4142
import org.apache.tsfile.read.common.block.TsBlock;
@@ -81,13 +82,20 @@ public class InferenceOperator implements ProcessOperator {
8182
private final TsBlockSerde serde = new TsBlockSerde();
8283
private InferenceWindowType windowType = null;
8384

85+
private final boolean generateTimeColumn;
86+
private long maxTimestamp;
87+
private long minTimestamp;
88+
private long interval;
89+
private long currentRowIndex;
90+
8491
public InferenceOperator(
8592
OperatorContext operatorContext,
8693
Operator child,
8794
ModelInferenceDescriptor modelInferenceDescriptor,
8895
ExecutorService modelInferenceExecutor,
8996
List<String> targetColumnNames,
9097
List<String> inputColumnNames,
98+
boolean generateTimeColumn,
9199
long maxRetainedSize,
92100
long maxReturnSize) {
93101
this.operatorContext = operatorContext;
@@ -106,6 +114,14 @@ public InferenceOperator(
106114
if (modelInferenceDescriptor.getInferenceWindowParameter() != null) {
107115
windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType();
108116
}
117+
118+
if (generateTimeColumn) {
119+
this.interval = 0;
120+
this.minTimestamp = Long.MAX_VALUE;
121+
this.maxTimestamp = Long.MIN_VALUE;
122+
this.currentRowIndex = 0;
123+
}
124+
this.generateTimeColumn = generateTimeColumn;
109125
}
110126

111127
@Override
@@ -140,6 +156,15 @@ public boolean hasNext() throws Exception {
140156
return !finished || (results != null && results.size() != resultIndex);
141157
}
142158

159+
private void fillTimeColumn(TsBlock tsBlock) {
160+
Column timeColumn = tsBlock.getTimeColumn();
161+
long[] time = timeColumn.getLongs();
162+
for (int i = 0; i < time.length; i++) {
163+
time[i] = maxTimestamp + interval * currentRowIndex;
164+
currentRowIndex++;
165+
}
166+
}
167+
143168
@Override
144169
public TsBlock next() throws Exception {
145170
if (inferenceExecutionFuture == null) {
@@ -156,6 +181,9 @@ public TsBlock next() throws Exception {
156181

157182
if (results != null && resultIndex != results.size()) {
158183
TsBlock tsBlock = serde.deserialize(results.get(resultIndex));
184+
if (generateTimeColumn) {
185+
fillTimeColumn(tsBlock);
186+
}
159187
resultIndex++;
160188
return tsBlock;
161189
}
@@ -177,6 +205,9 @@ public TsBlock next() throws Exception {
177205

178206
finished = true;
179207
TsBlock resultTsBlock = serde.deserialize(inferenceResp.inferenceResult.get(0));
208+
if (generateTimeColumn) {
209+
fillTimeColumn(resultTsBlock);
210+
}
180211
results = inferenceResp.inferenceResult;
181212
resultIndex++;
182213
return resultTsBlock;
@@ -194,7 +225,12 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
194225
ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders();
195226
totalRow += inputTsBlock.getPositionCount();
196227
for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
197-
timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
228+
long timestamp = inputTsBlock.getTimeByIndex(i);
229+
if (generateTimeColumn) {
230+
minTimestamp = Math.min(minTimestamp, timestamp);
231+
maxTimestamp = Math.max(maxTimestamp, timestamp);
232+
}
233+
timeColumnBuilder.writeLong(timestamp);
198234
for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) {
199235
columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i);
200236
}
@@ -259,6 +295,10 @@ private TsBlock preProcess(TsBlock inputTsBlock) {
259295

260296
private void submitInferenceTask() {
261297

298+
if (generateTimeColumn) {
299+
interval = (maxTimestamp - minTimestamp) / totalRow;
300+
}
301+
262302
TsBlock inputTsBlock = inputTsBlockBuilder.build();
263303

264304
TsBlock finalInputTsBlock = preProcess(inputTsBlock);

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -1695,7 +1695,8 @@ static void analyzeOutput(
16951695
.getModelInferenceDescriptor()
16961696
.setOutputColumnNames(
16971697
columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList()));
1698-
analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true));
1698+
boolean isIgnoreTimestamp = !queryStatement.isGenerateTime();
1699+
analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp));
16991700
return;
17001701
}
17011702

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java

+3
Original file line numberDiff line numberDiff line change
@@ -4527,6 +4527,9 @@ public Statement visitCallInference(IoTDBSqlParser.CallInferenceContext ctx) {
45274527
"Window Function(e.g. HEAD, TAIL, COUNT) should be set in value when key is 'WINDOW' in CALL INFERENCE");
45284528
}
45294529
parseWindowFunctionInInference(valueContext.windowFunction(), statement);
4530+
} else if (paramKey.equalsIgnoreCase("GENERATETIME")) {
4531+
statement.setGenerateTime(
4532+
Boolean.parseBoolean(parseAttributeValue(valueContext.attributeValue())));
45304533
} else {
45314534
statement.addInferenceAttribute(
45324535
paramKey, parseAttributeValue(valueContext.attributeValue()));

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java

+1
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ public LogicalPlanBuilder planInference(Analysis analysis) {
13831383
context.getQueryId().genPlanNodeId(),
13841384
root,
13851385
analysis.getModelInferenceDescriptor(),
1386+
!analysis.getRespDatasetHeader().isIgnoreTimestamp(),
13861387
analysis.getOutputExpressions().stream()
13871388
.map(expressionStringPair -> expressionStringPair.left.getExpressionString())
13881389
.collect(Collectors.toList()));

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ public PlanNode visitQuery(QueryStatement queryStatement, MPPQueryContext contex
234234
}
235235

236236
if (queryStatement.hasModelInference()) {
237-
planBuilder.planInference(analysis);
237+
planBuilder = planBuilder.planInference(analysis);
238238
}
239239

240240
// plan select into

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java

+1
Original file line numberDiff line numberDiff line change
@@ -2310,6 +2310,7 @@ public Operator visitInference(InferenceNode node, LocalExecutionPlanContext con
23102310
FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
23112311
node.getInputColumnNames(),
23122312
node.getChild().getOutputColumnNames(),
2313+
node.isGenerateTimeColumn(),
23132314
maxRetainedSize,
23142315
maxReturnSize);
23152316
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java

+16-2
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,29 @@ public class InferenceNode extends SingleChildProcessNode {
4040

4141
// the column order in select item which reflects the real input order
4242
private final List<String> targetColumnNames;
43+
private boolean generateTimeColumn = false;
4344

4445
public InferenceNode(
4546
PlanNodeId id,
4647
PlanNode child,
4748
ModelInferenceDescriptor modelInferenceDescriptor,
49+
boolean generateTimeColumn,
4850
List<String> targetColumnNames) {
4951
super(id, child);
5052
this.modelInferenceDescriptor = modelInferenceDescriptor;
5153
this.targetColumnNames = targetColumnNames;
54+
this.generateTimeColumn = generateTimeColumn;
5255
}
5356

5457
public InferenceNode(
5558
PlanNodeId id,
5659
ModelInferenceDescriptor modelInferenceDescriptor,
60+
boolean generateTimeColumn,
5761
List<String> inputColumnNames) {
5862
super(id);
5963
this.modelInferenceDescriptor = modelInferenceDescriptor;
6064
this.targetColumnNames = inputColumnNames;
65+
this.generateTimeColumn = generateTimeColumn;
6166
}
6267

6368
public ModelInferenceDescriptor getModelInferenceDescriptor() {
@@ -68,14 +73,19 @@ public List<String> getInputColumnNames() {
6873
return targetColumnNames;
6974
}
7075

76+
public boolean isGenerateTimeColumn() {
77+
return generateTimeColumn;
78+
}
79+
7180
@Override
7281
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
7382
return visitor.visitInference(this, context);
7483
}
7584

7685
@Override
7786
public PlanNode clone() {
78-
return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor, targetColumnNames);
87+
return new InferenceNode(
88+
getPlanNodeId(), child, modelInferenceDescriptor, generateTimeColumn, targetColumnNames);
7989
}
8090

8191
@Override
@@ -87,22 +97,26 @@ public List<String> getOutputColumnNames() {
8797
protected void serializeAttributes(ByteBuffer byteBuffer) {
8898
PlanNodeType.INFERENCE.serialize(byteBuffer);
8999
modelInferenceDescriptor.serialize(byteBuffer);
100+
ReadWriteIOUtils.write(generateTimeColumn, byteBuffer);
90101
ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer);
91102
}
92103

93104
@Override
94105
protected void serializeAttributes(DataOutputStream stream) throws IOException {
95106
PlanNodeType.INFERENCE.serialize(stream);
96107
modelInferenceDescriptor.serialize(stream);
108+
ReadWriteIOUtils.write(generateTimeColumn, stream);
97109
ReadWriteIOUtils.writeStringList(targetColumnNames, stream);
98110
}
99111

100112
public static InferenceNode deserialize(ByteBuffer buffer) {
101113
ModelInferenceDescriptor modelInferenceDescriptor =
102114
ModelInferenceDescriptor.deserialize(buffer);
115+
boolean generateTimeColumn = ReadWriteIOUtils.readBool(buffer);
103116
List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer);
104117
PlanNodeId planNodeId = PlanNodeId.deserialize(buffer);
105-
return new InferenceNode(planNodeId, modelInferenceDescriptor, inputColumnNames);
118+
return new InferenceNode(
119+
planNodeId, modelInferenceDescriptor, generateTimeColumn, inputColumnNames);
106120
}
107121

108122
@Override

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java

+9
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,18 @@ public class QueryStatement extends AuthorityInformationStatement {
139139
// [IoTDB-AI] used for model inference, which will be removed in the future
140140
private String modelName;
141141
private boolean hasModelInference = false;
142+
private boolean generateTime = false;
142143
private InferenceWindow inferenceWindow = null;
143144
private Map<String, String> inferenceAttribute = null;
144145

146+
public void setGenerateTime(boolean generateTime) {
147+
this.generateTime = generateTime;
148+
}
149+
150+
public boolean isGenerateTime() {
151+
return generateTime;
152+
}
153+
145154
public void setModelName(String modelName) {
146155
this.modelName = modelName;
147156
}

0 commit comments

Comments
 (0)