Skip to content

Commit b63ddd6

Browse files
xta0facebook-github-bot
authored and committed
[OSS][Metal] Support Resnet models
Summary: This diff adds the ops that were missing for running the ResNet models from Torchvision. Moving the tensors to the GPU significantly improves performance, as shown below (iPhone 11). Time running on CPU (ms): ``` forward took: 166.115 forward took: 150.722 forward took: 150.383 forward took: 150.345 forward took: 150.761 forward took: 150.533 forward took: 150.588 forward took: 150.812 forward took: 150.925 forward took: 150.25 ``` Time running on GPU (ms): ``` forward took: 39.9355 forward took: 41.3531 forward took: 41.798 forward took: 40.4744 forward took: 39.5181 forward took: 42.6464 forward took: 41.2658 forward took: 40.0862 forward took: 42.3533 forward took: 41.9348 ``` Discrepancy in results: ``` GPU: "(623, 4.6211)", "(111, 3.8809)", "(499, 3.8555)", "(596, 3.8047)", "(473, 3.7422)", "(846, 3.5762)", "(892, 3.5449)", "(813, 3.5098)", "(446, 3.5020)", "(902, 3.4980)" CPU: "(623, 4.4229)", "(499, 3.8321)", "(596, 3.6192)", "(111, 3.5295)", "(813, 3.4848)", "(584, 3.3979)", "(418, 3.3357)", "(473, 3.2760)", "(846, 3.2745)", "(902, 3.2376)" ``` Test Plan: {F340824316} Reviewed By: IvanKobzarev Differential Revision: D24416294 fbshipit-source-id: 12c9199ade0b76a7aa8a3838eddc4c19c79b6f37
1 parent 9371944 commit b63ddd6

File tree

3 files changed

+122
-0
lines changed

3 files changed

+122
-0
lines changed

Diff for: aten/src/ATen/native/metal/MetalAten.mm

+24
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ Tensor relu(const Tensor& input) {
160160
return mpscnn::relu(input);
161161
}
162162

163+
// Metal-backend entry point for the in-place ReLU op ("relu_").
// Rejects tensors that are not Metal tensors, then defers to the MPSCNN
// implementation and returns the reference it hands back.
Tensor& relu_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  Tensor& activated = mpscnn::relu_(input);
  return activated;
}
167+
163168
Tensor sigmoid(const Tensor& input) {
164169
TORCH_CHECK(input.is_metal());
165170
return mpscnn::sigmoid(input);
@@ -192,6 +197,14 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
192197
return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal());
193198
}
194199

200+
// Metal-backend entry point for the in-place add op ("add_.Tensor").
//
// input1 must be a Metal tensor; input2 is converted with .metal() when it is
// not one already. Both tensors must have the same rank and matching sizes in
// dimensions 2 and 3.
//
// NOTE: `alpha` is accepted to match the operator schema but is not forwarded
// to mpscnn::add_, so this function does not apply any scaling itself.
//
// Returns the reference produced by mpscnn::add_ for input1.
Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
  TORCH_CHECK(input1.is_metal());
  TORCH_CHECK(input1.dim() == input2.dim());
  // The checks below index sizes() at 2 and 3; make sure those indices exist
  // before reading them so lower-rank tensors fail cleanly instead of reading
  // past the end of the size array.
  TORCH_CHECK(
      input1.dim() >= 4,
      "add_ on Metal expects tensors with at least 4 dimensions");
  TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
  TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
  return mpscnn::add_(input1, input2.is_metal() ? input2 : input2.metal());
}
207+
195208
Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
196209
TORCH_CHECK(input1.is_metal());
197210
TORCH_CHECK(input1.dim() == input2.dim());
@@ -223,23 +236,34 @@ Tensor reshape(const Tensor& input, IntArrayRef shape) {
223236
return mpscnn::reshape(input, shape);
224237
}
225238

239+
// Metal-backend entry point for "flatten.using_ints".
// Verifies the input is a Metal tensor and forwards the dimension range to the
// MPSCNN implementation unchanged.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  TORCH_CHECK(input.is_metal());
  Tensor flattened = mpscnn::flatten_using_ints(input, start_dim, end_dim);
  return flattened;
}
246+
226247
// Registers the Metal implementations of the aten operators defined in this
// file. The registration order is kept as-is.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // Convolution, arithmetic and matrix ops.
  m.impl("conv2d", TORCH_FN(conv2d));
  m.impl("add.Tensor", TORCH_FN(add_Tensor));
  m.impl("add_.Tensor", TORCH_FN(add__Tensor));
  m.impl("addmm", TORCH_FN(addmm));

  // Tensor allocation.
  m.impl_UNBOXED("empty.memory_format", empty);
  m.impl("empty_strided", TORCH_FN(empty_strided));

  // Activations, pooling and remaining element-wise ops.
  m.impl("log_softmax.int", TORCH_FN(log_softmax_int));
  m.impl("max_pool2d", TORCH_FN(max_pool2d));
  m.impl("mul.Tensor", TORCH_FN(mul_Tensor));
  m.impl("relu", TORCH_FN(relu));
  m.impl("relu_", TORCH_FN(relu_));
  m.impl("sigmoid", TORCH_FN(sigmoid));
  m.impl("sub.Tensor", TORCH_FN(sub_Tensor));

  // Resampling and shape manipulation.
  m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec));
  m.impl("view", TORCH_FN(view));
  m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
  m.impl("hardtanh_", TORCH_FN(hardtanh_));
  m.impl("reshape", TORCH_FN(reshape));
  m.impl("flatten.using_ints", TORCH_FN(flatten_using_ints));
}
244268

245269
} // namespace metal

Diff for: aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h

+6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size);
3030

3131
Tensor relu(const Tensor& input);
3232

33+
// In-place variant of relu(): applies the ReLU neuron to `input`, writes the
// result back into `input`, and returns that same tensor.
Tensor& relu_(Tensor& input);
34+
3335
Tensor sigmoid(const Tensor& input);
3436

3537
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val);
@@ -44,6 +46,8 @@ Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight);
4446

4547
Tensor add(const Tensor& input1, const Tensor& input2);
4648

49+
// In-place counterpart of add(): runs the element-wise add kernel on input1
// and input2, stores the result in input1, and returns input1.
Tensor& add_(Tensor& input1, const Tensor& input2);
50+
4751
Tensor sub(const Tensor& input1, const Tensor& input2);
4852

4953
Tensor mul(const Tensor& input1, const Tensor& input2);
@@ -55,6 +59,8 @@ Tensor upsample_nearest2d_vec(
5559
c10::optional<IntArrayRef> output_size,
5660
c10::optional<ArrayRef<double>> scale_factors);
5761

62+
// Collapses dimensions start_dim through end_dim (inclusive) of `input` into a
// single dimension via reshape. Both indices are passed through
// maybe_wrap_dim, and start_dim must not come after end_dim.
Tensor flatten_using_ints(const Tensor & input, int64_t start_dim, int64_t end_dim);
63+
5864
Tensor copy_to_host(const Tensor& input);
5965

6066
} // namespace mpscnn

Diff for: aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm

+92
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,36 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
216216
return output;
217217
}
218218

219+
// Applies `neuron` to `input` in place.
//
// The neuron is encoded from the input's image into a temporary image on the
// input's command buffer; the temporary result is then copied back into the
// texture owned by `input`, and `input` itself is returned.
//
// A 2-D input of size {d0, d1} is given a texture size of {d0, d1, 1, 1}.
// The temporary image must be allocated with that padded size (previously the
// padded size was computed but the raw 2-element size was passed instead).
// NOTE(review): this assumes temporaryImageFromSize: wants a 4-element size,
// as the padding logic implies -- confirm against its implementation.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
  MPSImage* X = imageFromTensor(input);
  const std::vector<int64_t> outputSize = input.sizes().vec();
  std::vector<int64_t> textureSize = outputSize;
  if (input.dim() == 2) {
    textureSize = {outputSize[0], outputSize[1], 1, 1};
  }
  MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
  MPSImage* Y = [MPSImage temporaryImageFromSize:textureSize
                                   commandBuffer:commandBuffer];
  [neuron encodeToCommandBuffer:commandBuffer.buffer
                    sourceImage:X
               destinationImage:Y];
  // Write the result back into the tensor's own texture so the operation is
  // observable through `input`.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input;
}
238+
219239
// Out-of-place ReLU: runs the shared ReLU neuron through neuronKernel and
// returns the tensor it produces.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor relu(const Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel(input, reluNeuron);
}
223243

244+
// In-place ReLU: runs the shared ReLU neuron through neuronKernel_ and returns
// the reference it yields.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& relu_(Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel_(input, reluNeuron);
}
248+
224249
API_AVAILABLE(ios(10.0), macos(10.13))
225250
Tensor sigmoid(const Tensor& input) {
226251
return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
@@ -356,12 +381,50 @@ Tensor binaryElementwiseKernel(
356381
return output;
357382
}
358383

384+
// Runs a binary element-wise compute kernel on input1 and input2 and stores
// the result in input1.
//
// arrayKernel / nonarrayKernel are the names of the two kernel variants;
// kernelFor() selects one of them based on input1's image. Both inputs must
// share the same command buffer. The kernel writes into a temporary image
// sized like input1, which is then copied into input1's texture.
//
// Returns input1.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& binaryElementwiseKernel_(
    Tensor& input1,
    const Tensor& input2,
    NSString* arrayKernel,
    NSString* nonarrayKernel) {
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  std::vector<int64_t> outputSize = input1.sizes().vec();
  MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
  MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
  // Use a C string for the message: an NSString literal would be formatted as
  // a pointer value rather than as readable text.
  TORCH_CHECK([cb1 isEqual:cb2], "inputs have different command buffer");
  MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1];
  id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
      pipelineState:kernelFor(X1, arrayKernel, nonarrayKernel)];
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  // Texture slots: 0 and 1 are the operands, 2 is the destination.
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  [X1 markRead];
  [X2 markRead];
  // Copy the temporary result into input1's texture to complete the in-place
  // update.
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input1;
}
415+
359416
// Out-of-place element-wise add: dispatches to binaryElementwiseKernel with
// the two "elementwise_add" kernel variants.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor add(const Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel(
      input1, input2, arrayKernelName, nonarrayKernelName);
}
364421

422+
// In-place element-wise add: dispatches to binaryElementwiseKernel_ with the
// two "elementwise_add" kernel variants and returns the reference it yields.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& add_(Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel_(
      input1, input2, arrayKernelName, nonarrayKernelName);
}
427+
365428
API_AVAILABLE(ios(10.0), macos(10.13))
366429
Tensor sub(const Tensor& input1, const Tensor& input2) {
367430
return binaryElementwiseKernel(
@@ -510,6 +573,35 @@ Tensor upsample_nearest2d_vec(
510573
return output;
511574
}
512575

576+
// Flattens dimensions [start_dim, end_dim] of `input` into a single dimension.
//
// Both indices go through maybe_wrap_dim against the input's rank, and
// start_dim must not exceed end_dim afterwards. A 0-dim input is reshaped to
// {1}; when start_dim == end_dim the input is returned as-is. Otherwise the
// result is input.reshape(...) with the selected range replaced by the product
// of its sizes.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  const int64_t rank = input.dim();
  start_dim = maybe_wrap_dim(start_dim, rank);
  end_dim = maybe_wrap_dim(end_dim, rank);
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");
  if (rank == 0) {
    return input.reshape({1});
  }
  if (start_dim == end_dim) {
    return input;
  }

  const auto dims = input.sizes();
  const int64_t collapsedCount = end_dim - start_dim + 1;

  std::vector<int64_t> flatShape;
  flatShape.reserve(rank - collapsedCount + 1);
  // Leading dimensions before the flattened range.
  flatShape.insert(flatShape.end(), dims.begin(), dims.begin() + start_dim);
  // The flattened range collapses to the product of its sizes.
  flatShape.push_back(prod_intlist(dims.slice(start_dim, collapsedCount)));
  // Trailing dimensions after the flattened range.
  flatShape.insert(flatShape.end(), dims.begin() + end_dim + 1, dims.end());
  return input.reshape(flatShape);
}
604+
513605
Tensor copy_to_host(const Tensor& input) {
514606
MPSImage* X = imageFromTensor(input);
515607
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);

0 commit comments

Comments
 (0)