diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 8be22c7330dbbb..b3f9539b2fc611 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -33,6 +33,7 @@ import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.FunctionUtil; import org.apache.doris.catalog.MapType; +import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarFunction; import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.StructType; @@ -58,11 +59,13 @@ import org.apache.doris.nereids.trees.expressions.BitNot; import org.apache.doris.nereids.trees.expressions.BitOr; import org.apache.doris.nereids.trees.expressions.BitXor; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IntegralDivide; import org.apache.doris.nereids.trees.expressions.Mod; import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; @@ -100,6 +103,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -460,6 +464,8 @@ private void analyzeUdaf() throws AnalysisException { function.setBinaryType(binaryType); function.setChecksum(checksum); function.setNullableMode(returnNullMode); + function.setStaticLoad(isStaticLoad); + function.setExpirationTime(expirationTime); } private void analyzeUdf() throws AnalysisException { @@ -887,10 +893,86 @@ private TFunctionBinaryType getFunctionBinaryType(String type) { } private void analyzeAliasFunction(ConnectContext ctx) throws AnalysisException { + if (parameters.size() != argsDef.getArgTypes().length) { + throw new AnalysisException( + "Alias function [" + functionName + "] args number is not equal to parameters number"); + } + List exprs; + List typeDefParams = new ArrayList<>(); + if (originFunction instanceof org.apache.doris.nereids.trees.expressions.functions.Function) { + exprs = originFunction.getArguments(); + } else if (originFunction instanceof Cast) { + exprs = originFunction.children(); + DataType targetType = originFunction.getDataType(); + Type type = targetType.toCatalogDataType(); + if (type.isScalarType()) { + ScalarType scalarType = (ScalarType) type; + PrimitiveType primitiveType = scalarType.getPrimitiveType(); + switch (primitiveType) { + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: + case DECIMAL256: + case DECIMALV2: + if (!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) { + typeDefParams.add(scalarType.getScalarPrecisionStr()); + } + if (!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) { + typeDefParams.add(scalarType.getScalarScaleStr()); + } + break; + case CHAR: + case VARCHAR: + if (!Strings.isNullOrEmpty(scalarType.getLenStr())) { + typeDefParams.add(scalarType.getLenStr()); + } + break; + default: + throw new AnalysisException("Alias type is invalid: " + primitiveType); + } + } + } else { + throw new AnalysisException("Not supported expr type: " + originFunction); + } + Set set = new HashSet<>(); + for (String str : parameters) { + if (!set.add(str)) { + throw new AnalysisException( + "Alias function [" + functionName + "] has duplicate parameter [" + str + "]."); + } + boolean existFlag = false; + // check exprs + for (Expression expr : exprs) { + existFlag |= checkParams(expr, str); + } + // check targetTypeDef + for (String typeDefParam : typeDefParams) { + existFlag |= typeDefParam.equals(str); + } + if (!existFlag) { + throw new AnalysisException("Alias function [" + functionName + "] do not contain parameter [" + str + + "]. typeDefParams=" + + typeDefParams.stream().map(String::toString).collect(Collectors.joining(", "))); + } + } function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(), Type.VARCHAR, argsDef.isVariadic(), parameters, translateToLegacyExpr(originFunction, ctx)); } + private boolean checkParams(Expression expr, String param) { + for (Expression e : expr.children()) { + if (checkParams(e, param)) { + return true; + } + } + if (expr instanceof Slot) { + if (param.equals(((Slot) expr).getName())) { + return true; + } + } + return false; + } + /** * translate to legacy expr, which do not need complex expression and table columns */ diff --git a/regression-test/suites/ddl_p0/test_alias_function.groovy b/regression-test/suites/ddl_p0/test_alias_function.groovy index 7793de925531fb..41e950719d1aac 100644 --- a/regression-test/suites/ddl_p0/test_alias_function.groovy +++ b/regression-test/suites/ddl_p0/test_alias_function.groovy @@ -24,4 +24,11 @@ suite("test_alias_function") { sql """DROP FUNCTION IF EXISTS mesh_udf_test2(INT,INT)""" sql """CREATE ALIAS FUNCTION mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))""" qt_sql1 """select mesh_udf_test2(1,2);""" + + + sql """DROP FUNCTION IF EXISTS userlevel(bigint)""" + test { + sql """create GLOBAL ALIAS FUNCTION userlevel(bigint) with PARAMETER(level_score) as (CASE WHEN level_score < 0 THEN 0 WHEN level_score < 1000 THEN 1 WHEN level_score < 5000 THEN 2 WHEN level_score < 10000 THEN 3 WHEN level_score < 407160000 THEN 29 ELSE 30 END);""" + exception "Not supported expr type" + } }