Skip to content

Commit 8bfa2e9

Browse files
authored
Add validation for unsupported type/identifier/commands (opensearch-project#3195)
Signed-off-by: Tomoyuki Morita <[email protected]>
1 parent 6911faf commit 8bfa2e9

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java

+35
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext;
3131
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext;
3232
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext;
33+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierContext;
34+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ErrorCapturingIdentifierExtraContext;
3335
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext;
3436
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext;
3537
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext;
@@ -43,6 +45,7 @@
4345
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext;
4446
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext;
4547
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext;
48+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LiteralTypeContext;
4649
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext;
4750
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext;
4851
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext;
@@ -77,7 +80,9 @@
7780
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext;
7881
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext;
7982
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext;
83+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TypeContext;
8084
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext;
85+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsupportedHiveNativeCommandsContext;
8186
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
8287

8388
/** This visitor validate grammar using GrammarElementValidator */
@@ -584,4 +589,34 @@ private void validateAllowed(SQLGrammarElement element) {
584589
throw new IllegalArgumentException(element + " is not allowed.");
585590
}
586591
}
592+
593+
@Override
594+
public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) {
595+
ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra();
596+
if (extra.children != null) {
597+
throw new IllegalArgumentException("Invalid identifier: " + ctx.getText());
598+
}
599+
return super.visitErrorCapturingIdentifier(ctx);
600+
}
601+
602+
@Override
603+
public Void visitLiteralType(LiteralTypeContext ctx) {
604+
if (ctx.unsupportedType != null) {
605+
throw new IllegalArgumentException("Unsupported typed literal: " + ctx.getText());
606+
}
607+
return super.visitLiteralType(ctx);
608+
}
609+
610+
@Override
611+
public Void visitType(TypeContext ctx) {
612+
if (ctx.unsupportedType != null) {
613+
throw new IllegalArgumentException("Unsupported data type: " + ctx.getText());
614+
}
615+
return super.visitType(ctx);
616+
}
617+
618+
@Override
619+
public Void visitUnsupportedHiveNativeCommands(UnsupportedHiveNativeCommandsContext ctx) {
620+
throw new IllegalArgumentException("Unsupported command.");
621+
}
587622
}

async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java

+59-1
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,56 @@ void testValidateFlintExtensionQuery() {
571571
UUID.randomUUID().toString(), DataSourceType.SECURITY_LAKE));
572572
}
573573

574+
@Test
575+
void testInvalidIdentifier() {
576+
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
577+
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);
578+
v.ng("SELECT a.b.c as a-b-c FROM abc");
579+
v.ok("SELECT a.b.c as `a-b-c` FROM abc");
580+
v.ok("SELECT a.b.c as a_b_c FROM abc");
581+
582+
v.ng("SELECT a.b.c FROM a-b-c");
583+
v.ng("SELECT a.b.c FROM a.b-c");
584+
v.ok("SELECT a.b.c FROM b.c.`a-b-c`");
585+
v.ok("SELECT a.b.c FROM `a-b-c`");
586+
}
587+
588+
@Test
589+
void testUnsupportedType() {
590+
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
591+
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);
592+
593+
v.ng("SELECT cast ( a as DateTime ) FROM tbl");
594+
v.ok("SELECT cast ( a as DATE ) FROM tbl");
595+
v.ok("SELECT cast ( a as Date ) FROM tbl");
596+
v.ok("SELECT cast ( a as Timestamp ) FROM tbl");
597+
}
598+
599+
@Test
600+
void testUnsupportedTypedLiteral() {
601+
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
602+
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);
603+
604+
v.ng("SELECT DATETIME '2024-10-11'");
605+
v.ok("SELECT DATE '2024-10-11'");
606+
v.ok("SELECT TIMESTAMP '2024-10-11'");
607+
}
608+
609+
@Test
610+
void testUnsupportedHiveNativeCommand() {
611+
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> true);
612+
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK);
613+
614+
v.ng("CREATE ROLE aaa");
615+
v.ng("SHOW GRANT");
616+
v.ng("EXPORT TABLE");
617+
v.ng("ALTER TABLE aaa NOT CLUSTERED");
618+
v.ng("START TRANSACTION");
619+
v.ng("COMMIT");
620+
v.ng("ROLLBACK");
621+
v.ng("DFS");
622+
}
623+
574624
@AllArgsConstructor
575625
private static class VerifyValidator {
576626
private final SQLQueryValidator validator;
@@ -580,10 +630,18 @@ public void ok(TestElement query) {
580630
runValidate(query.getQueries());
581631
}
582632

633+
public void ok(String query) {
634+
runValidate(query);
635+
}
636+
583637
public void ng(TestElement query) {
638+
Arrays.stream(query.getQueries()).forEach(this::ng);
639+
}
640+
641+
public void ng(String query) {
584642
assertThrows(
585643
IllegalArgumentException.class,
586-
() -> runValidate(query.getQueries()),
644+
() -> runValidate(query),
587645
"The query should throw: query=`" + query.toString() + "`");
588646
}
589647

0 commit comments

Comments
 (0)