
Commit 5ebc35b (parent 4fd2091)

torch.lstsq specific case

File tree: 8 files changed, 315 insertions(+), 0 deletions(-)


examples/lstsq/export_model.py

(new file, +32 lines)

import numpy as np
import torch
from torch import nn
from .lstsq import LSTSQ


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, A, B):
        return LSTSQ.apply(B, A)


# Solves min_X ||AX - B|| where A has shape Mx2 and B has shape MxN
def export(M, N):
    np.random.seed(324)
    torch.manual_seed(32)

    model = Model()
    A = torch.rand([M, 2])
    B = torch.rand([M, N])

    with torch.no_grad():
        torch.onnx.export(model, (A, B), 'model.onnx',
                          input_names=['input', 'input1'],
                          output_names=['output'],
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

    ref = model(A, B)
    np.save('inp', A.detach().numpy())
    np.save('inp1', B.detach().numpy())
    np.save('ref', ref.detach().numpy())
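
Note: a minimal, hypothetical sanity check (not part of this commit) for the exported artifacts. It reloads the arrays saved above and compares the reference output against NumPy's own least-squares solver; the file names follow the np.save calls in export():

import numpy as np
from examples.lstsq.export_model import export

export(5, 1000)
A = np.load('inp.npy')    # M x 2
B = np.load('inp1.npy')   # M x N
ref = np.load('ref.npy')  # 2 x N, produced by torch.lstsq
X, *_ = np.linalg.lstsq(A, B, rcond=None)
assert np.allclose(X, ref, atol=1e-5)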

examples/lstsq/lstsq.py

(new file, +32 lines)

import torch


# Reference implementation: solves min_X ||AX - B|| via a QR decomposition
# of A (assumes A has exactly 2 columns)
def solve_squares(B, A):
    print("A", A.shape)
    print("B", B.shape)

    def prod(vec0, vec1):
        return (vec0 * vec1).sum()

    # Returns the unit vector in the direction of vec
    def norm(vec):
        return vec / (vec * vec).sum().sqrt()

    # Gram-Schmidt orthonormalization of the two columns of A
    col0 = norm(A[:, 0])
    col1 = norm(A[:, 1] - prod(A[:, 1], col0) * col0)

    Q = torch.stack((col0, col1), dim=1)
    R = torch.tensor([[prod(A[:, 0], col0), prod(A[:, 1], col0)],
                      [0, prod(A[:, 1], col1)]])

    # X = inverse(R) * transpose(Q) * B
    X = torch.matmul(torch.inverse(R), Q.transpose(1, 0))
    X = torch.matmul(X, B)
    return X


class LSTSQ(torch.autograd.Function):
    @staticmethod
    def symbolic(g, input, A):
        return g.op("lstsq", input, A)

    @staticmethod
    def forward(self, input, A):
        # torch.lstsq returns an M x N solution tensor; only the first
        # n (= 2) rows hold the least-squares solution
        return torch.lstsq(input, A)[0][:2]
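
Note: a quick consistency check (a sketch, not part of this commit) for solve_squares above. For a well-conditioned M x 2 system, the QR-based solution should agree with the closed-form normal-equations solution X = inverse(A^T A) A^T B:

import torch
from examples.lstsq.lstsq import solve_squares

torch.manual_seed(0)
A = torch.rand(8, 2)
B = torch.rand(8, 3)
X_qr = solve_squares(B, A)
X_ne = torch.inverse(A.T @ A) @ A.T @ B
assert torch.allclose(X_qr, X_ne, atol=1e-4)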

tests/run_tests.py

(+8 lines)

@@ -32,6 +32,7 @@ def run_test(convert_ir=True, test_onnx=False, num_inputs=1, threshold=1e-5):
     ref = np.load('ref.npy')

     ie = IECore()
+    print(get_extensions_path())
     ie.add_extension(get_extensions_path(), 'CPU')
     ie.set_config({'CONFIG_FILE': 'user_ie_extensions/gpu_extensions.xml'}, 'GPU')

@@ -145,3 +146,10 @@ def test_deformable_conv():
     )
     run_test(num_inputs=2, threshold=2e-5)
     run_test(num_inputs=2, test_onnx=True, threshold=2e-5)
+
+
+def test_lstsq():
+    from examples.lstsq.export_model import export
+
+    export(5, 1000)
+    run_test(num_inputs=2, test_onnx=True)

user_ie_extensions/cpu_kernel.hpp

(+16 lines)

@@ -127,4 +127,20 @@ class CalculateGridImpl : public InferenceEngine::ILayerExecImpl {
     std::string error;
 };

+class LSTSQImpl : public InferenceEngine::ILayerExecImpl {
+public:
+    explicit LSTSQImpl(const std::shared_ptr<ngraph::Node>& node);
+    InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig> &conf,
+                                                           InferenceEngine::ResponseDesc *resp) noexcept override;
+    InferenceEngine::StatusCode init(InferenceEngine::LayerConfig &config,
+                                     InferenceEngine::ResponseDesc *resp) noexcept override;
+    InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr> &inputs,
+                                        std::vector<InferenceEngine::Blob::Ptr> &outputs,
+                                        InferenceEngine::ResponseDesc *resp) noexcept override;
+private:
+    std::vector<ngraph::Shape> inShapes;
+    ngraph::Shape outShape;
+    std::string error;
+};
+
 }  // namespace TemplateExtension

user_ie_extensions/extension.cpp

(+10 lines)

@@ -49,6 +49,10 @@ Extension::Extension() {
         ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
         return {std::make_shared<CalculateGridOp>(ng_inputs.at(0))};
     });
+    ngraph::onnx_import::register_operator(LSTSQOp::type_info.name, 1, "", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
+        ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
+        return {std::make_shared<LSTSQOp>(ng_inputs.at(0), ng_inputs.at(1))};
+    });
 }

 Extension::~Extension() {

@@ -59,6 +63,7 @@ Extension::~Extension() {
     ngraph::onnx_import::unregister_operator(SparseConvOp::type_info.name, 1, "org.open3d");
     ngraph::onnx_import::unregister_operator(SparseConvTransposeOp::type_info.name, 1, "org.open3d");
     ngraph::onnx_import::unregister_operator(CalculateGridOp::type_info.name, 1, "org.open3d");
+    ngraph::onnx_import::unregister_operator(LSTSQOp::type_info.name, 1, "");
 }

 //! [extension:GetVersion]

@@ -85,6 +90,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
     opset.insert<SparseConvOp>();
     opset.insert<SparseConvTransposeOp>();
     opset.insert<CalculateGridOp>();
+    opset.insert<LSTSQOp>();
     opsets["extension"] = opset;
     return opsets;
 }

@@ -98,6 +104,7 @@ std::vector<std::string> Extension::getImplTypes(const std::shared_ptr<ngraph::N
         std::dynamic_pointer_cast<SparseConvOp>(node) ||
         std::dynamic_pointer_cast<SparseConvTransposeOp>(node) ||
         std::dynamic_pointer_cast<CalculateGridOp>(node) ||
+        std::dynamic_pointer_cast<LSTSQOp>(node) ||
         std::dynamic_pointer_cast<IFFTOp>(node) ||
         std::dynamic_pointer_cast<FFTOp>(node)) {
         return {"CPU"};

@@ -129,6 +136,9 @@ InferenceEngine::ILayerImpl::Ptr Extension::getImplementation(const std::shared_
     if (std::dynamic_pointer_cast<CalculateGridOp>(node) && implType == "CPU") {
         return std::make_shared<CalculateGridImpl>(node);
     }
+    if (std::dynamic_pointer_cast<LSTSQOp>(node) && implType == "CPU") {
+        return std::make_shared<LSTSQImpl>(node);
+    }
     return nullptr;
 }
 //! [extension:getImplementation]
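
Note: LSTSQ.symbolic in examples/lstsq/lstsq.py emits a node with op type "lstsq" in the default ("") domain, which is the same name/domain pair register_operator binds above. A hypothetical way to inspect this after running export():

import onnx

model = onnx.load('model.onnx')
print([(n.domain, n.op_type) for n in model.graph.node])  # expect [('', 'lstsq')]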

user_ie_extensions/lstsq_impl.cpp

(new file, +162 lines)

// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "cpu_kernel.hpp"
#include "op.hpp"
#include <details/ie_exception.hpp>
#include <ie_layouts.h>
#include "ie_parallel.hpp"

#include <cmath>
#include <cstring>
#include <limits>
#include <numeric>

using namespace TemplateExtension;

//! [cpu_implementation:ctor]
LSTSQImpl::LSTSQImpl(const std::shared_ptr<ngraph::Node> &node) {
    try {
        auto castedNode = std::dynamic_pointer_cast<LSTSQOp>(node);
        if (!castedNode)
            THROW_IE_EXCEPTION << "Cannot create implementation for unknown operation!";
        if (castedNode->inputs().size() != 2 || castedNode->outputs().size() != 1)
            THROW_IE_EXCEPTION << "Cannot create implementation for operation with incorrect number of inputs or outputs!";
        if (castedNode->get_input_partial_shape(0).is_dynamic() || castedNode->get_output_partial_shape(0).is_dynamic())
            THROW_IE_EXCEPTION << "Cannot create implementation for op with dynamic shapes!";
        if (castedNode->get_input_shape(0).size() != 2 || castedNode->get_output_shape(0).size() != 2)
            THROW_IE_EXCEPTION << "Operation supports only 2d tensors for input and output.";
        if (castedNode->get_input_element_type(0) != ngraph::element::f32 || castedNode->get_output_element_type(0) != ngraph::element::f32)
            THROW_IE_EXCEPTION << "Operation supports only FP32 tensors.";
        inShapes.resize(2);
        for (size_t i = 0; i < inShapes.size(); ++i)
            inShapes[i] = castedNode->get_input_shape(i);
        outShape = castedNode->get_output_shape(0);
    } catch (InferenceEngine::details::InferenceEngineException& ex) {
        error = ex.what();
    }
}
//! [cpu_implementation:ctor]

//! [cpu_implementation:getSupportedConfigurations]
InferenceEngine::StatusCode LSTSQImpl::getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig> &conf,
                                                                  InferenceEngine::ResponseDesc *resp) noexcept {
    std::vector<InferenceEngine::DataConfig> inDataConfig;
    std::vector<InferenceEngine::DataConfig> outDataConfig;
    // Allow any offset before data
    size_t offset((std::numeric_limits<size_t>::max)());

    // Input shapes
    for (const auto& shape : inShapes) {
        InferenceEngine::SizeVector order(shape.size());
        std::iota(order.begin(), order.end(), 0);

        InferenceEngine::DataConfig inpConf;
        inpConf.desc = InferenceEngine::TensorDesc(InferenceEngine::Precision::FP32, shape, {shape, order, offset});
        inDataConfig.push_back(inpConf);
    }

    // Output shape
    InferenceEngine::SizeVector order(outShape.size());
    std::iota(order.begin(), order.end(), 0);

    InferenceEngine::DataConfig outConf;
    outConf.desc = InferenceEngine::TensorDesc(InferenceEngine::Precision::FP32, outShape, {outShape, order, offset});
    outDataConfig.push_back(outConf);

    InferenceEngine::LayerConfig layerConfig;
    layerConfig.inConfs = inDataConfig;
    layerConfig.outConfs = outDataConfig;

    conf.push_back(layerConfig);
    return InferenceEngine::StatusCode::OK;
}
//! [cpu_implementation:getSupportedConfigurations]

//! [cpu_implementation:init]
InferenceEngine::StatusCode LSTSQImpl::init(InferenceEngine::LayerConfig &config, InferenceEngine::ResponseDesc *resp) noexcept {
    try {
        if (config.inConfs.size() != 2 || config.outConfs.size() != 1) {
            THROW_IE_EXCEPTION << "Operation cannot be initialized with incorrect number of inputs/outputs!";
        }

        if (config.inConfs[0].desc.getDims().size() != 2 || config.outConfs[0].desc.getDims().size() != 2) {
            THROW_IE_EXCEPTION << "Operation can be initialized only with 2d input/output tensors!";
        }

        if (config.outConfs[0].desc.getPrecision() != InferenceEngine::Precision::FP32 ||
            config.inConfs[0].desc.getPrecision() != InferenceEngine::Precision::FP32) {
            THROW_IE_EXCEPTION << "Operation supports only FP32 precisions!";
        }
    } catch (InferenceEngine::details::InferenceEngineException& ex) {
        if (resp) {
            strncpy(resp->msg, ex.what(), sizeof(resp->msg) - 1);
            resp->msg[sizeof(resp->msg) - 1] = 0;
        }
        return InferenceEngine::GENERAL_ERROR;
    }

    return InferenceEngine::OK;
}
//! [cpu_implementation:init]

//! [cpu_implementation:execute]
InferenceEngine::StatusCode LSTSQImpl::execute(std::vector<InferenceEngine::Blob::Ptr> &inputs,
                                               std::vector<InferenceEngine::Blob::Ptr> &outputs,
                                               InferenceEngine::ResponseDesc *resp) noexcept {
    const float* B = inputs[0]->cbuffer().as<float*>();
    const float* A = inputs[1]->cbuffer().as<float*>();
    float* out = outputs[0]->buffer().as<float*>();

    // Perform the QR factorization A = QR. This implementation works only
    // for A with exactly 2 columns. B has shape MxN.
    const size_t M = inputs[0]->getTensorDesc().getDims()[0];
    const size_t N = inputs[0]->getTensorDesc().getDims()[1];

    std::vector<float> Q(M * 2);
    std::vector<float> R(4, 0.0f);
    float norm0 = 0.0f;
    float product = 0.0f;  // dot product of the second column of A with the first column of Q
    for (size_t i = 0; i < M; ++i) {
        float val = A[i * 2];
        product += A[i * 2 + 1] * val;
        norm0 += val * val;
    }
    norm0 = sqrtf(norm0);
    product /= norm0;
    R[1] = product;

    float norm1 = 0.0f;
    for (size_t i = 0; i < M; ++i) {
        float val = A[i * 2] / norm0;
        Q[i * 2] = val;
        R[0] += A[i * 2] * val;

        // Orthogonalize the second column against the first
        val = A[i * 2 + 1] - product * val;
        Q[i * 2 + 1] = val;
        norm1 += val * val;
        R[3] += A[i * 2 + 1] * val;
    }
    norm1 = sqrtf(norm1);
    for (size_t i = 0; i < M; ++i) {
        Q[i * 2 + 1] /= norm1;
    }
    R[3] /= norm1;

    // Invert the upper-triangular R matrix
    float scale = 1.0f / (R[0] * R[3]);
    std::vector<float> R_inv{R[3] * scale, -R[1] * scale, 0.0f, R[0] * scale};

    // Output is inverse(R) * transpose(Q) * B; fold inverse(R) into Q first
    for (size_t i = 0; i < M; ++i) {
        Q[i * 2] = R_inv[0] * Q[i * 2] + R_inv[1] * Q[i * 2 + 1];
        Q[i * 2 + 1] *= R_inv[3];
    }

    for (size_t i = 0; i < N; ++i) {
        out[i] = 0.0f;
        out[N + i] = 0.0f;
        for (size_t j = 0; j < M; ++j) {
            out[i] += Q[j * 2] * B[j * N + i];
            out[N + i] += Q[j * 2 + 1] * B[j * N + i];
        }
    }
    return InferenceEngine::OK;
}
//! [cpu_implementation:execute]
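
Note: for reference, a NumPy transcription (a sketch, not part of this commit) of the two-column QR arithmetic the kernel performs. A is M x 2, B is M x N; the result is the 2 x N solution X = inverse(R) * transpose(Q) * B:

import numpy as np

def lstsq_2col(A, B):
    q0 = A[:, 0] / np.linalg.norm(A[:, 0])  # first orthonormal column
    r00 = A[:, 0] @ q0
    r01 = A[:, 1] @ q0                      # "product" in the kernel
    v = A[:, 1] - r01 * q0                  # orthogonalize second column
    q1 = v / np.linalg.norm(v)
    r11 = A[:, 1] @ q1
    Q = np.stack([q0, q1], axis=1)          # M x 2, orthonormal columns
    R_inv = np.array([[1.0 / r00, -r01 / (r00 * r11)],
                      [0.0,        1.0 / r11]])
    return R_inv @ (Q.T @ B)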

user_ie_extensions/lstsq_op.cpp

(new file, +41 lines)

// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op.hpp"

using namespace TemplateExtension;

constexpr ngraph::NodeTypeInfo LSTSQOp::type_info;

//! [op:ctor]
LSTSQOp::LSTSQOp(
    const ngraph::Output<ngraph::Node>& B,
    const ngraph::Output<ngraph::Node>& A
)
    : Op({B, A}) {
    constructor_validate_and_infer_types();
}
//! [op:ctor]

//! [op:validate]
void LSTSQOp::validate_and_infer_types() {
    // Output has the shape of B (input 0) with the first dimension set to 2,
    // i.e. a 2xN solution for an Mx2 system
    auto outShape = get_input_partial_shape(0);
    outShape[0] = 2;
    set_output_type(0, get_input_element_type(0), outShape);
}
//! [op:validate]

//! [op:copy]
std::shared_ptr<ngraph::Node> LSTSQOp::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
    if (new_args.size() != 2) {
        throw ngraph::ngraph_error("Incorrect number of new arguments");
    }
    return std::make_shared<LSTSQOp>(new_args.at(0), new_args.at(1));
}
//! [op:copy]

//! [op:visit_attributes]
bool LSTSQOp::visit_attributes(ngraph::AttributeVisitor &visitor) {
    return true;
}
//! [op:visit_attributes]

user_ie_extensions/op.hpp

(+14 lines)

@@ -116,6 +116,20 @@ class CalculateGridOp : public ngraph::op::Op {
     bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
 };

+class LSTSQOp : public ngraph::op::Op {
+public:
+    static constexpr ngraph::NodeTypeInfo type_info{"lstsq", 0};
+    const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }
+
+    LSTSQOp() = default;
+    LSTSQOp(const ngraph::Output<ngraph::Node>& B,
+            const ngraph::Output<ngraph::Node>& A);
+    void validate_and_infer_types() override;
+    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
+    bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
+};
+
 //! [op:header]

 }  // namespace TemplateExtension
