diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java index 5e5ca800402..83f9beb7cf7 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/it/ITDmlReturningTest.java @@ -22,12 +22,13 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeFalse; import com.google.cloud.spanner.AsyncResultSet; import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; +import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.KeySet; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.ParallelIntegrationTest; import com.google.cloud.spanner.ResultSet; @@ -39,15 +40,15 @@ import com.google.cloud.spanner.connection.StatementResult; import com.google.cloud.spanner.connection.StatementResult.ResultType; import com.google.cloud.spanner.connection.TransactionMode; -import com.google.cloud.spanner.testing.EmulatorSpannerHelper; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; +import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import org.junit.Before; @@ -81,12 +82,7 @@ public class ITDmlReturningTest extends ITAbstractSpannerTest { + " SingerId BIGINT PRIMARY KEY," + " FirstName character varying(1024)," + " LastName character varying(1024))"); - private final Map IS_INITIALIZED = new HashMap<>(); - - public ITDmlReturningTest() { - IS_INITIALIZED.put(Dialect.GOOGLE_STANDARD_SQL, false); - IS_INITIALIZED.put(Dialect.POSTGRESQL, false); - } + private static final Set IS_INITIALIZED = new HashSet<>(); @Parameter public Dialect dialect; @@ -96,41 +92,34 @@ public static Object[] data() { } private boolean checkAndSetInitialized() { - if ((dialect == Dialect.GOOGLE_STANDARD_SQL) && !IS_INITIALIZED.get(dialect)) { - IS_INITIALIZED.put(dialect, true); - return true; - } - if ((dialect == Dialect.POSTGRESQL) && !IS_INITIALIZED.get(dialect)) { - IS_INITIALIZED.put(dialect, true); - return true; - } - return false; + return !IS_INITIALIZED.add(dialect); } @Before public void setupTable() { - assumeFalse( - "DML Returning is not supported in the emulator", EmulatorSpannerHelper.isUsingEmulator()); - if (checkAndSetInitialized()) { + if (!checkAndSetInitialized()) { database = env.getTestHelper() .createTestDatabase(dialect, Collections.singleton(DDL_MAP.get(dialect))); - List firstNames = Arrays.asList("ABC", "ABC", "DEF", "PQR", "ABC"); - List lastNames = Arrays.asList("XYZ", "DEF", "XYZ", "ABC", "GHI"); - List mutations = new ArrayList<>(); - for (int id = 1; id <= 5; id++) { - mutations.add( - Mutation.newInsertBuilder("SINGERS") - .set("SINGERID") - .to(id) - .set("FIRSTNAME") - .to(firstNames.get(id - 1)) - .set("LASTNAME") - .to(lastNames.get(id - 1)) - .build()); - } - env.getTestHelper().getDatabaseClient(database).write(mutations); } + DatabaseClient client = env.getTestHelper().getDatabaseClient(database); + client.write(ImmutableList.of(Mutation.delete("SINGERS", KeySet.all()))); + + List firstNames = Arrays.asList("ABC", "ABC", "DEF", "PQR", "ABC"); + List lastNames = Arrays.asList("XYZ", "DEF", "XYZ", "ABC", "GHI"); + List mutations = new ArrayList<>(); + for (int id = 1; id <= 5; id++) { + mutations.add( + Mutation.newInsertBuilder("SINGERS") + .set("SINGERID") + .to(id) + .set("FIRSTNAME") + .to(firstNames.get(id - 1)) + .set("LASTNAME") + .to(lastNames.get(id - 1)) + .build()); + } + env.getTestHelper().getDatabaseClient(database).write(mutations); } @Test @@ -211,9 +200,9 @@ public void testDmlReturningExecuteUpdateAsync() { public void testDmlReturningExecuteBatchUpdate() { try (Connection connection = createConnection()) { connection.setAutocommit(false); - final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect); + final Statement updateStmt = Preconditions.checkNotNull(UPDATE_RETURNING_MAP.get(dialect)); long[] counts = - connection.executeBatchUpdate(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT)); + connection.executeBatchUpdate(ImmutableList.of(updateStmt, updateStmt, updateStmt)); assertArrayEquals(counts, new long[] {3, 3, 3}); } } @@ -222,10 +211,10 @@ public void testDmlReturningExecuteBatchUpdate() { public void testDmlReturningExecuteBatchUpdateAsync() { try (Connection connection = createConnection()) { connection.setAutocommit(false); - final Statement UPDATE_STMT = UPDATE_RETURNING_MAP.get(dialect); + final Statement updateStmt = Preconditions.checkNotNull(UPDATE_RETURNING_MAP.get(dialect)); long[] counts = connection - .executeBatchUpdateAsync(ImmutableList.of(UPDATE_STMT, UPDATE_STMT, UPDATE_STMT)) + .executeBatchUpdateAsync(ImmutableList.of(updateStmt, updateStmt, updateStmt)) .get(); assertArrayEquals(counts, new long[] {3, 3, 3}); } catch (ExecutionException | InterruptedException e) {