
Commit e801568

Add support for prelu

1 parent 56491a2 commit e801568

4 files changed: +14 -5 lines changed

Diff for: common/flatbuffers_helper.h (+2)
@@ -150,6 +150,8 @@ inline std::string layer_type_to_str(flatbnn::LayerType type) {
             return "split";
         case flatbnn::LayerType::Shuffle:
             return "shuffle";
+        case flatbnn::LayerType::PRelu:
+            return "prelu";
         default:
             BNN_ASSERT(false, "Missing type in this function");
     }
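This keeps the string table in step with the new layer type: without the added case, layer_type_to_str(flatbnn::LayerType::PRelu) would fall through to the BNN_ASSERT in the default branch, which matters because Net::prepare (below) uses this function to format its "Not supported op" error.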

Diff for: dabnn/layers/PRelu.h (+2 -2)
@@ -11,8 +11,8 @@ class PRelu : public Layer {
     MatCP data_mat;
     MatCP slope_mat;

-    PRelu(NetCP net, const std::string &name, css data)
-        : Layer(net, name, "PRelu"), data_mat(mat(data)) {}
+    PRelu(NetCP net, const std::string &name, css data, css slope)
+        : Layer(net, name, "PRelu"), data_mat(mat(data)), slope_mat(mat(slope)) {}
     virtual void forward_impl() const;
 };
 } // namespace bnn
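The constructor now takes the slope blob's name alongside the data blob and resolves both to Mats (css appears to be the repo's shorthand for a const string reference). The body of forward_impl is not part of this diff; for reference, PRelu computes f(x) = x for x >= 0 and f(x) = slope * x otherwise, with a per-channel slope broadcast over the spatial dimensions. A minimal scalar sketch of those semantics, assuming NHWC float buffers (names are illustrative, not dabnn's implementation):

#include <cstddef>

// Illustrative per-channel PRelu over an NHWC float buffer. dabnn's real
// forward_impl (not shown in this diff) works on its Mat type and is
// presumably vectorized.
void prelu_forward(const float *data, const float *slope, float *out,
                   std::size_t n, std::size_t h, std::size_t w,
                   std::size_t c) {
    const std::size_t total = n * h * w * c;
    for (std::size_t i = 0; i < total; i++) {
        const float x = data[i];
        const float a = slope[i % c];  // slope broadcast over N, H, W
        out[i] = x >= 0.f ? x : a * x;
    }
}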

Diff for: dabnn/net.cpp (+7)
@@ -22,6 +22,7 @@
 #include <dabnn/layers/Relu.h>
 #include <dabnn/layers/Shuffle.h>
 #include <dabnn/layers/Split.h>
+#include <dabnn/layers/PRelu.h>

 using std::string;
 using std::vector;
@@ -253,6 +254,12 @@ void Net::prepare() {
                     std::make_shared<Shuffle>(get_weak(), name, input));
                 break;
             }
+            case flatbnn::LayerType::PRelu: {
+                ADD_INPLACE_LAYER(prelu, Eltwise, input, slope, output);
+                layers.push_back(
+                    std::make_shared<PRelu>(get_weak(), name, input, slope));
+                break;
+            }
             default: {
                 throw std::runtime_error("Not supported op " +
                                          layer_type_to_str(layer->type()));
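Note that the input and slope names passed to the PRelu constructor are not declared in the case body itself, so they presumably come from the existing ADD_INPLACE_LAYER macro (defined elsewhere in net.cpp, outside this diff), which reads the blob names out of the layer's FlatBuffers param before the layer object is constructed.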

Diff for: tools/onnx2bnn/OnnxConverter.cpp (+3 -3)
@@ -356,7 +356,7 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Pool completed";
         } else if (op == "PRelu") {
-            VLOG(5) << "Start converting Relu";
+            VLOG(5) << "Start converting PRelu";
             auto input_name = m(node.input(0));
             auto slope_name = m(node.input(1));
             const auto onnx_slope_tensor = onnx_float_tensors_.at(slope_name);
@@ -366,7 +366,7 @@ std::vector<std::string> OnnxConverter::Convert(
             BNN_ASSERT(
                 (slope_shape.size() == 3 && slope_shape[1] == 1 &&
                  slope_shape[2] == 1) ||
-                    onnx_slope_tensor.data == {1},
+                    onnx_slope_tensor.data == std::vector<float>{1},
                 "PRelu only support scalr slope or per-channel slope for now");
             const Shape flat_slope_shape{slope_shape[0]};
             auto flat_slope_tensor = flatbnn::CreateTensorDirect(
@@ -382,7 +382,7 @@ std::vector<std::string> OnnxConverter::Convert(
             flatbnn::CreateLayer(builder_, flatbnn::LayerType::PRelu, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, param);
             layers_.push_back(layer);
-            VLOG(5) << "Converting Relu completed";
+            VLOG(5) << "Converting PRelu completed";
         } else if (op == "Relu") {
             VLOG(5) << "Start converting Relu";
             auto input_name = m(node.input(0));
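Two changes in this hunk are worth spelling out. The copy-pasted "Relu" log strings are corrected to "PRelu", and the scalar-slope comparison is fixed because a braced list such as {1} cannot appear as an operand of ==; spelling out std::vector<float>{1} gives operator== concrete operands. The assertion then admits either a per-channel slope of shape (C, 1, 1) or that single-element slope, which is flattened to the rank-1 shape {C}. A standalone restatement of that check, with hypothetical names of my own (not OnnxConverter's API):

#include <cassert>
#include <cstdint>
#include <vector>

// Restates the converter's slope handling: accept a per-channel (C, 1, 1)
// slope or the single-element slope {1}, and flatten it to a rank-1 shape
// {C} for the dabnn PRelu layer.
std::vector<uint32_t> flatten_prelu_slope(
    const std::vector<uint32_t> &slope_shape,
    const std::vector<float> &slope_data) {
    const bool per_channel = slope_shape.size() == 3 &&
                             slope_shape[1] == 1 && slope_shape[2] == 1;
    const bool scalar_one = slope_data == std::vector<float>{1};
    assert((per_channel || scalar_one) &&
           "PRelu only supports scalar or per-channel slope for now");
    return {slope_shape[0]};  // e.g. (64, 1, 1) -> (64)
}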
