diff --git a/tpu/src/main/java/tpu/CreateQueuedResource.java b/tpu/src/main/java/tpu/CreateQueuedResource.java new file mode 100644 index 00000000000..421acff2d04 --- /dev/null +++ b/tpu/src/main/java/tpu/CreateQueuedResource.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +//[START tpu_queued_resources_create] +import com.google.cloud.tpu.v2alpha1.CreateQueuedResourceRequest; +import com.google.cloud.tpu.v2alpha1.Node; +import com.google.cloud.tpu.v2alpha1.QueuedResource; +import com.google.cloud.tpu.v2alpha1.TpuClient; +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class CreateQueuedResource { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project you want to create a node. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which to create the TPU. + // For more information about supported TPU types for specific zones, + // see https://cloud.google.com/tpu/docs/regions-zones + String zone = "us-central1-f"; + // The name for your TPU. + String nodeName = "YOUR_NODE_ID"; + // The accelerator type that specifies the version and size of the Cloud TPU you want to create. + // For more information about supported accelerator types for each TPU version, + // see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions. + String tpuType = "v2-8"; + // Software version that specifies the version of the TPU runtime to install. + // For more information see https://cloud.google.com/tpu/docs/runtimes + String tpuSoftwareVersion = "tpu-vm-tf-2.14.1"; + // The name for your Queued Resource. + String queuedResourceId = "QUEUED_RESOURCE_ID"; + + createQueuedResource( + projectId, zone, queuedResourceId, nodeName, tpuType, tpuSoftwareVersion); + } + + // Creates a Queued Resource + public static QueuedResource createQueuedResource(String projectId, String zone, + String queuedResourceId, String nodeName, String tpuType, String tpuSoftwareVersion) + throws IOException, ExecutionException, InterruptedException, TimeoutException { + String resource = String.format("projects/%s/locations/%s/queuedResources/%s", + projectId, zone, queuedResourceId); + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create()) { + String parent = String.format("projects/%s/locations/%s", projectId, zone); + Node node = + Node.newBuilder() + .setName(nodeName) + .setAcceleratorType(tpuType) + .setRuntimeVersion(tpuSoftwareVersion) + .setQueuedResource(resource) + .build(); + + QueuedResource queuedResource = + QueuedResource.newBuilder() + .setName(queuedResourceId) + .setTpu( + QueuedResource.Tpu.newBuilder() + .addNodeSpec( + QueuedResource.Tpu.NodeSpec.newBuilder() + .setParent(parent) + .setNode(node) + .setNodeId(nodeName) + .build()) + .build()) + .build(); + + CreateQueuedResourceRequest request = + CreateQueuedResourceRequest.newBuilder() + .setParent(parent) + .setQueuedResourceId(queuedResourceId) + .setQueuedResource(queuedResource) + .build(); + + return tpuClient.createQueuedResourceAsync(request).get(1, TimeUnit.MINUTES); + } + } +} +//[END tpu_queued_resources_create] \ No newline at end of file diff --git a/tpu/src/main/java/tpu/DeleteQueuedResource.java b/tpu/src/main/java/tpu/DeleteQueuedResource.java new file mode 100644 index 00000000000..9f0e123a43e --- /dev/null +++ b/tpu/src/main/java/tpu/DeleteQueuedResource.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tpu; + +//[START tpu_queued_resources_delete] +import com.google.cloud.tpu.v2alpha1.DeleteQueuedResourceRequest; +import com.google.cloud.tpu.v2alpha1.TpuClient; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +public class DeleteQueuedResource { + public static void main(String[] args) + throws IOException, ExecutionException, InterruptedException { + // TODO(developer): Replace these variables before running the sample. + // Project ID or project number of the Google Cloud project. + String projectId = "YOUR_PROJECT_ID"; + // The zone in which the TPU was created. + String zone = "us-central1-f"; + // The name for your Queued Resource. + String queuedResourceId = "QUEUED_RESOURCE_ID"; + + deleteQueuedResource(projectId, zone, queuedResourceId); + } + + // Deletes a Queued Resource asynchronously. + public static void deleteQueuedResource(String projectId, String zone, String queuedResourceId) + throws ExecutionException, InterruptedException, IOException { + String name = String.format("projects/%s/locations/%s/queuedResources/%s", + projectId, zone, queuedResourceId); + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. + try (TpuClient tpuClient = TpuClient.create()) { + // Before deleting the queued resource it is required to delete the TPU VM. + // For more information about deleting TPU + // see https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm + + DeleteQueuedResourceRequest request = + DeleteQueuedResourceRequest.newBuilder().setName(name).build(); + + tpuClient.deleteQueuedResourceAsync(request).get(); + } + } +} +//[END tpu_queued_resources_delete] diff --git a/tpu/src/main/java/tpu/GetQueuedResource.java b/tpu/src/main/java/tpu/GetQueuedResource.java index a17c2b41f79..588987a25f0 100644 --- a/tpu/src/main/java/tpu/GetQueuedResource.java +++ b/tpu/src/main/java/tpu/GetQueuedResource.java @@ -28,7 +28,7 @@ public static void main(String[] args) throws IOException { // Project ID or project number of the Google Cloud project. String projectId = "YOUR_PROJECT_ID"; // The zone in which the TPU was created. - String zone = "europe-west4-a"; + String zone = "us-central1-f"; // The name for your Queued Resource. String queuedResourceId = "QUEUED_RESOURCE_ID"; @@ -50,4 +50,4 @@ public static QueuedResource getQueuedResource( } } } -//[END tpu_queued_resources_get] +//[END tpu_queued_resources_get] \ No newline at end of file diff --git a/tpu/src/test/java/tpu/QueuedResourceIT.java b/tpu/src/test/java/tpu/QueuedResourceIT.java index 15dd2e768cc..92f96298829 100644 --- a/tpu/src/test/java/tpu/QueuedResourceIT.java +++ b/tpu/src/test/java/tpu/QueuedResourceIT.java @@ -17,6 +17,7 @@ package tpu; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; @@ -33,6 +34,7 @@ import com.google.cloud.tpu.v2alpha1.TpuSettings; import java.io.IOException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.runner.RunWith; @@ -40,7 +42,7 @@ import org.mockito.MockedStatic; @RunWith(JUnit4.class) -@Timeout(value = 10) +@Timeout(value = 2, unit = TimeUnit.MINUTES) public class QueuedResourceIT { private static final String PROJECT_ID = "project-id"; private static final String ZONE = "europe-west4-a"; @@ -50,6 +52,30 @@ public class QueuedResourceIT { private static final String QUEUED_RESOURCE_NAME = "queued-resource"; private static final String NETWORK_NAME = "default"; + @Test + public void testCreateQueuedResource() throws Exception { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + QueuedResource mockQueuedResource = mock(QueuedResource.class); + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get(anyLong(), any(TimeUnit.class))).thenReturn(mockQueuedResource); + + QueuedResource returnedQueuedResource = + CreateQueuedResource.createQueuedResource( + PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME, + TPU_TYPE, TPU_SOFTWARE_VERSION); + + verify(mockTpuClient, times(1)) + .createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)); + verify(mockFuture, times(1)).get(anyLong(), any(TimeUnit.class)); + assertEquals(returnedQueuedResource, mockQueuedResource); + } + } + @Test public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { @@ -113,6 +139,25 @@ public void testDeleteForceQueuedResource() } } + @Test + public void testDeleteQueuedResource() + throws IOException, ExecutionException, InterruptedException { + try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) { + TpuClient mockTpuClient = mock(TpuClient.class); + OperationFuture mockFuture = mock(OperationFuture.class); + + mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient); + when(mockTpuClient.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class))) + .thenReturn(mockFuture); + when(mockFuture.get()).thenReturn(null); + + DeleteQueuedResource.deleteQueuedResource(PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME); + + verify(mockTpuClient, times(1)) + .deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class)); + } + } + @Test public void testCreateQueuedResourceWithStartupScript() throws Exception { try (MockedStatic mockedTpuClient = mockStatic(TpuClient.class)) {