@@ -216,11 +216,36 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
   return output;
 }
 
+API_AVAILABLE(ios(10.0), macos(10.13))
+Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
+  MPSImage* X = imageFromTensor(input);
+  std::vector<int64_t> outputSize = input.sizes().vec();
+  std::vector<int64_t> textureSize = outputSize;
+  if (input.dim() == 2) {
+    textureSize = {outputSize[0], outputSize[1], 1, 1};
+  }
+  MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
+  MPSImage* Y = [MPSImage temporaryImageFromSize:input.sizes().vec()
+                                   commandBuffer:commandBuffer];
+  [neuron encodeToCommandBuffer:commandBuffer.buffer
+                    sourceImage:X
+               destinationImage:Y];
+  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
+  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
+  metalTensor.texture()->copyFromTexture(Y);
+  return input;
+}
+
 API_AVAILABLE(ios(10.0), macos(10.13))
 Tensor relu(const Tensor& input) {
   return neuronKernel(input, [MPSCNNNeuronOp relu]);
 }
 
+API_AVAILABLE(ios(10.0), macos(10.13))
+Tensor& relu_(Tensor& input) {
+  return neuronKernel_(input, [MPSCNNNeuronOp relu]);
+}
+
 API_AVAILABLE(ios(10.0), macos(10.13))
 Tensor sigmoid(const Tensor& input) {
   return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
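Note on the hunk above: neuronKernel_ encodes the neuron op into a temporary MPSImage and then copies the result back over the input tensor's backing texture, so the returned Tensor& aliases input, which is the standard ATen in-place contract. A minimal sketch of how such kernel pairs are typically exposed through dispatch, assuming this backend registers against the Metal dispatch key via TORCH_LIBRARY_IMPL (the registration site and exact schema strings are assumptions, not part of this diff):

// Sketch only: the real registration lives outside this file; schema
// names here follow the standard aten operator naming convention.
#include <torch/library.h>

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl("relu", TORCH_FN(relu));    // out-of-place: fresh output texture
  m.impl("relu_", TORCH_FN(relu_));  // in-place: result copied back to input
}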
@@ -356,12 +381,50 @@ Tensor binaryElementwiseKernel(
   return output;
 }
 
+API_AVAILABLE(ios(10.0), macos(10.13))
+Tensor& binaryElementwiseKernel_(
+    Tensor& input1,
+    const Tensor& input2,
+    NSString* arrayKernel,
+    NSString* nonarrayKernel) {
+  MPSImage* X1 = imageFromTensor(input1);
+  MPSImage* X2 = imageFromTensor(input2);
+  std::vector<int64_t> outputSize = input1.sizes().vec();
+  MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
+  MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
+  TORCH_CHECK([cb1 isEqual:cb2], @"inputs have different command buffers");
+  MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1];
+  id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
+      pipelineState:kernelFor(X1, arrayKernel, nonarrayKernel)];
+  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
+  [encoder setComputePipelineState:state];
+  [encoder setTexture:[X1 texture] atIndex:0];
+  [encoder setTexture:[X2 texture] atIndex:1];
+  [encoder setTexture:[Y texture] atIndex:2];
+  const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
+  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
+          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
+  [encoder endEncoding];
+  [X1 markRead];
+  [X2 markRead];
+  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
+  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
+  metalTensor.texture()->copyFromTexture(Y);
+  return input1;
+}
+
 API_AVAILABLE(ios(10.0), macos(10.13))
 Tensor add(const Tensor& input1, const Tensor& input2) {
   return binaryElementwiseKernel(
       input1, input2, @"elementwise_add", @"elementwise_add_nonarray");
 }
 
+API_AVAILABLE(ios(10.0), macos(10.13))
+Tensor& add_(Tensor& input1, const Tensor& input2) {
+  return binaryElementwiseKernel_(
+      input1, input2, @"elementwise_add", @"elementwise_add_nonarray");
+}
+
 API_AVAILABLE(ios(10.0), macos(10.13))
 Tensor sub(const Tensor& input1, const Tensor& input2) {
   return binaryElementwiseKernel(
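Since binaryElementwiseKernel_ ends by copying the encoded result back into input1's texture, add_ mutates its left operand and returns the same tensor, while add allocates a fresh output. The TORCH_CHECK guards the single-encoder assumption: both inputs must have been recorded on the same MetalCommandBuffer, because the kernel is encoded exactly once on cb1. A small usage sketch, assuming tensors already resident on the metal backend (the .metal() conversion used here is an assumption, not shown in this diff):

// Usage sketch; shapes must match, since binaryElementwiseKernel_ sizes
// the destination from input1 and binds input2 as a second source texture.
at::Tensor a = at::rand({1, 4, 8, 8}).metal();
at::Tensor b = at::rand({1, 4, 8, 8}).metal();
at::Tensor c = add(a, b);  // out-of-place: result lands in a new texture
add_(a, b);                // in-place: a's backing texture now holds a + b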
@@ -510,6 +573,35 @@ Tensor upsample_nearest2d_vec(
   return output;
 }
 
+Tensor flatten_using_ints(
+    const Tensor& input,
+    int64_t start_dim,
+    int64_t end_dim) {
+  start_dim = maybe_wrap_dim(start_dim, input.dim());
+  end_dim = maybe_wrap_dim(end_dim, input.dim());
+  TORCH_CHECK(
+      start_dim <= end_dim,
+      "flatten() has invalid args: start_dim cannot come after end_dim");
+  std::vector<int64_t> shape;
+  if (input.dim() == 0) {
+    return input.reshape({1});
+  }
+  if (start_dim == end_dim) {
+    return input;
+  }
+  auto slice_numel =
+      prod_intlist(input.sizes().slice(start_dim, end_dim - start_dim + 1));
+  shape.reserve(input.dim() - end_dim + start_dim);
+  for (int64_t i = 0; i < start_dim; i++) {
+    shape.push_back(input.size(i));
+  }
+  shape.push_back(slice_numel);
+  for (int64_t i = end_dim + 1; i < input.dim(); i++) {
+    shape.push_back(input.size(i));
+  }
+  return input.reshape(shape);
+}
+
 Tensor copy_to_host(const Tensor& input) {
   MPSImage* X = imageFromTensor(input);
   MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
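The shape arithmetic in flatten_using_ints is worth checking in isolation: dims start_dim through end_dim collapse into a single axis whose length is the product of the collapsed sizes, with the leading and trailing dims carried over unchanged. A standalone sketch with plain std::vector (the helper name is illustrative, and prod_intlist over the slice is replaced by std::accumulate):

// Illustrative re-implementation of the flatten shape computation above.
// flattenedShape({2, 3, 4, 5}, 1, 2) yields {2, 12, 5}: dims 1 and 2
// (sizes 3 and 4) collapse into a single axis of length 12.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<int64_t> flattenedShape(
    const std::vector<int64_t>& sizes,
    int64_t start,
    int64_t end) {
  std::vector<int64_t> shape;
  shape.reserve(sizes.size() - (end - start));
  // Dims before the flattened range pass through unchanged.
  for (int64_t i = 0; i < start; i++) {
    shape.push_back(sizes[i]);
  }
  // The collapsed axis is the product of sizes[start..end].
  shape.push_back(std::accumulate(
      sizes.begin() + start,
      sizes.begin() + end + 1,
      int64_t{1},
      std::multiplies<int64_t>()));
  // Dims after the flattened range pass through unchanged.
  for (int64_t i = end + 1; i < static_cast<int64_t>(sizes.size()); i++) {
    shape.push_back(sizes[i]);
  }
  return shape;
}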