Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AINode] Support output time column for model inference #15069

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,22 @@ public void ModelOperationTest() {

@Test
public void callInferenceTest() {
String sql = "CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\")";
String sql =
"CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\", generateTime=true)";
String sql2 = "CALL INFERENCE(identity, \"select s2,s0,s1 from root.AI.data\")";
String sql3 =
"CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\", predict_length=3)";
"CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\", predict_length=3, generateTime=true)";
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {

try (ResultSet resultSet = statement.executeQuery(sql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "output0,output1,output2");
checkHeader(resultSetMetaData, "Time,output0,output1,output2");
int count = 0;
while (resultSet.next()) {
float s0 = resultSet.getFloat(1);
float s1 = resultSet.getFloat(2);
float s2 = resultSet.getFloat(3);
float s0 = resultSet.getFloat(2);
float s1 = resultSet.getFloat(3);
float s2 = resultSet.getFloat(4);

assertEquals(s0, count + 1.0, 0.0001);
assertEquals(s1, count + 2.0, 0.0001);
Expand Down Expand Up @@ -221,7 +222,7 @@ public void callInferenceTest() {

try (ResultSet resultSet = statement.executeQuery(sql3)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "output0,output1,output2");
checkHeader(resultSetMetaData, "Time,output0,output1,output2");
int count = 0;
while (resultSet.next()) {
count++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
Expand Down Expand Up @@ -81,13 +82,20 @@ public class InferenceOperator implements ProcessOperator {
private final TsBlockSerde serde = new TsBlockSerde();
private InferenceWindowType windowType = null;

private final boolean generateTimeColumn;
private long maxTimestamp;
private long minTimestamp;
private long interval;
private long currentRowIndex;

public InferenceOperator(
OperatorContext operatorContext,
Operator child,
ModelInferenceDescriptor modelInferenceDescriptor,
ExecutorService modelInferenceExecutor,
List<String> targetColumnNames,
List<String> inputColumnNames,
boolean generateTimeColumn,
long maxRetainedSize,
long maxReturnSize) {
this.operatorContext = operatorContext;
Expand All @@ -106,6 +114,14 @@ public InferenceOperator(
if (modelInferenceDescriptor.getInferenceWindowParameter() != null) {
windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType();
}

if (generateTimeColumn) {
this.interval = 0;
this.minTimestamp = Long.MAX_VALUE;
this.maxTimestamp = Long.MIN_VALUE;
this.currentRowIndex = 0;
}
this.generateTimeColumn = generateTimeColumn;
}

@Override
Expand Down Expand Up @@ -140,6 +156,15 @@ public boolean hasNext() throws Exception {
return !finished || (results != null && results.size() != resultIndex);
}

  /**
   * Overwrites the timestamps of a deserialized inference-result block with generated values.
   *
   * <p>Row {@code k} (counted across all result blocks via the {@code currentRowIndex} field) is
   * assigned {@code maxTimestamp + interval * k}, where {@code interval} is computed in
   * {@code submitInferenceTask()} from the observed input time range. Note that the first
   * generated row ({@code currentRowIndex == 0}) is assigned exactly {@code maxTimestamp}, i.e.
   * it collides with the last input timestamp — NOTE(review): confirm this overlap is intended;
   * otherwise the multiplier should be {@code currentRowIndex + 1}.
   *
   * <p>NOTE(review): this writes into the {@code long[]} returned by {@code Column#getLongs()},
   * which assumes the column exposes its backing array rather than a defensive copy — verify
   * against the tsfile {@code Column} implementation.
   *
   * @param tsBlock inference result whose time column is rewritten in place
   */
  private void fillTimeColumn(TsBlock tsBlock) {
    Column timeColumn = tsBlock.getTimeColumn();
    long[] time = timeColumn.getLongs();
    for (int i = 0; i < time.length; i++) {
      // currentRowIndex persists across blocks so consecutive result blocks continue the sequence.
      time[i] = maxTimestamp + interval * currentRowIndex;
      currentRowIndex++;
    }
  }

@Override
public TsBlock next() throws Exception {
if (inferenceExecutionFuture == null) {
Expand All @@ -156,6 +181,9 @@ public TsBlock next() throws Exception {

if (results != null && resultIndex != results.size()) {
TsBlock tsBlock = serde.deserialize(results.get(resultIndex));
if (generateTimeColumn) {
fillTimeColumn(tsBlock);
}
resultIndex++;
return tsBlock;
}
Expand All @@ -177,6 +205,9 @@ public TsBlock next() throws Exception {

finished = true;
TsBlock resultTsBlock = serde.deserialize(inferenceResp.inferenceResult.get(0));
if (generateTimeColumn) {
fillTimeColumn(resultTsBlock);
}
results = inferenceResp.inferenceResult;
resultIndex++;
return resultTsBlock;
Expand All @@ -194,7 +225,12 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders();
totalRow += inputTsBlock.getPositionCount();
for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
long timestamp = inputTsBlock.getTimeByIndex(i);
if (generateTimeColumn) {
minTimestamp = Math.min(minTimestamp, timestamp);
maxTimestamp = Math.max(maxTimestamp, timestamp);
}
timeColumnBuilder.writeLong(timestamp);
for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) {
columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i);
}
Expand Down Expand Up @@ -259,6 +295,10 @@ private TsBlock preProcess(TsBlock inputTsBlock) {

private void submitInferenceTask() {

if (generateTimeColumn) {
interval = (maxTimestamp - minTimestamp) / totalRow;
}

TsBlock inputTsBlock = inputTsBlockBuilder.build();

TsBlock finalInputTsBlock = preProcess(inputTsBlock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,8 @@ static void analyzeOutput(
.getModelInferenceDescriptor()
.setOutputColumnNames(
columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList()));
analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true));
boolean isIgnoreTimestamp = !queryStatement.isGenerateTime();
analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4527,6 +4527,9 @@ public Statement visitCallInference(IoTDBSqlParser.CallInferenceContext ctx) {
"Window Function(e.g. HEAD, TAIL, COUNT) should be set in value when key is 'WINDOW' in CALL INFERENCE");
}
parseWindowFunctionInInference(valueContext.windowFunction(), statement);
} else if (paramKey.equalsIgnoreCase("GENERATETIME")) {
statement.setGenerateTime(
Boolean.parseBoolean(parseAttributeValue(valueContext.attributeValue())));
} else {
statement.addInferenceAttribute(
paramKey, parseAttributeValue(valueContext.attributeValue()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ public LogicalPlanBuilder planInference(Analysis analysis) {
context.getQueryId().genPlanNodeId(),
root,
analysis.getModelInferenceDescriptor(),
!analysis.getRespDatasetHeader().isIgnoreTimestamp(),
analysis.getOutputExpressions().stream()
.map(expressionStringPair -> expressionStringPair.left.getExpressionString())
.collect(Collectors.toList()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public PlanNode visitQuery(QueryStatement queryStatement, MPPQueryContext contex
}

if (queryStatement.hasModelInference()) {
planBuilder.planInference(analysis);
planBuilder = planBuilder.planInference(analysis);
}

// plan select into
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,7 @@ public Operator visitInference(InferenceNode node, LocalExecutionPlanContext con
FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
node.getInputColumnNames(),
node.getChild().getOutputColumnNames(),
node.isGenerateTimeColumn(),
maxRetainedSize,
maxReturnSize);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,29 @@ public class InferenceNode extends SingleChildProcessNode {

// the column order in select item which reflects the real input order
private final List<String> targetColumnNames;
private boolean generateTimeColumn = false;

public InferenceNode(
PlanNodeId id,
PlanNode child,
ModelInferenceDescriptor modelInferenceDescriptor,
boolean generateTimeColumn,
List<String> targetColumnNames) {
super(id, child);
this.modelInferenceDescriptor = modelInferenceDescriptor;
this.targetColumnNames = targetColumnNames;
this.generateTimeColumn = generateTimeColumn;
}

public InferenceNode(
PlanNodeId id,
ModelInferenceDescriptor modelInferenceDescriptor,
boolean generateTimeColumn,
List<String> inputColumnNames) {
super(id);
this.modelInferenceDescriptor = modelInferenceDescriptor;
this.targetColumnNames = inputColumnNames;
this.generateTimeColumn = generateTimeColumn;
}

public ModelInferenceDescriptor getModelInferenceDescriptor() {
Expand All @@ -68,14 +73,19 @@ public List<String> getInputColumnNames() {
return targetColumnNames;
}

  /** @return whether a synthetic Time column should be generated for the inference output. */
  public boolean isGenerateTimeColumn() {
    return generateTimeColumn;
  }

  /** Double-dispatch entry point for {@link PlanVisitor} implementations. */
  @Override
  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
    return visitor.visitInference(this, context);
  }

@Override
public PlanNode clone() {
return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor, targetColumnNames);
return new InferenceNode(
getPlanNodeId(), child, modelInferenceDescriptor, generateTimeColumn, targetColumnNames);
}

@Override
Expand All @@ -87,22 +97,26 @@ public List<String> getOutputColumnNames() {
  /**
   * Serializes this node's attributes into {@code byteBuffer}.
   *
   * <p>Write order (node type, model descriptor, generate-time-column flag, target column names)
   * must stay in lockstep with {@code deserialize(ByteBuffer)} and the {@code DataOutputStream}
   * overload below.
   */
  protected void serializeAttributes(ByteBuffer byteBuffer) {
    PlanNodeType.INFERENCE.serialize(byteBuffer);
    modelInferenceDescriptor.serialize(byteBuffer);
    ReadWriteIOUtils.write(generateTimeColumn, byteBuffer);
    ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer);
  }

  /**
   * Stream variant of attribute serialization; write order mirrors the {@code ByteBuffer}
   * overload and {@code deserialize(ByteBuffer)} exactly.
   *
   * @throws IOException if writing to the stream fails
   */
  @Override
  protected void serializeAttributes(DataOutputStream stream) throws IOException {
    PlanNodeType.INFERENCE.serialize(stream);
    modelInferenceDescriptor.serialize(stream);
    ReadWriteIOUtils.write(generateTimeColumn, stream);
    ReadWriteIOUtils.writeStringList(targetColumnNames, stream);
  }

public static InferenceNode deserialize(ByteBuffer buffer) {
ModelInferenceDescriptor modelInferenceDescriptor =
ModelInferenceDescriptor.deserialize(buffer);
boolean generateTimeColumn = ReadWriteIOUtils.readBool(buffer);
List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer);
PlanNodeId planNodeId = PlanNodeId.deserialize(buffer);
return new InferenceNode(planNodeId, modelInferenceDescriptor, inputColumnNames);
return new InferenceNode(
planNodeId, modelInferenceDescriptor, generateTimeColumn, inputColumnNames);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,18 @@ public class QueryStatement extends AuthorityInformationStatement {
// [IoTDB-AI] used for model inference, which will be removed in the future
private String modelName;
private boolean hasModelInference = false;
private boolean generateTime = false;
private InferenceWindow inferenceWindow = null;
private Map<String, String> inferenceAttribute = null;

  /** Enables/disables generation of an output Time column for CALL INFERENCE results. */
  public void setGenerateTime(boolean generateTime) {
    this.generateTime = generateTime;
  }

  /** @return true if the statement was parsed with the {@code generateTime=true} attribute. */
  public boolean isGenerateTime() {
    return generateTime;
  }

  /** Sets the model name used for model inference. */
  public void setModelName(String modelName) {
    this.modelName = modelName;
  }
Expand Down
Loading