Skip to content

Commit 84d1c3d

Browse files
authored
chore(accelerated_op): use correct Python Ctype for pybind11 function prototype (#52)
1 parent 5b5b21d commit 84d1c3d

13 files changed

+162
-156
lines changed

.github/workflows/lint.yml

+6-8
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,9 @@ jobs:
5454
run: |
5555
python -m pip install --upgrade pip setuptools
5656
57-
- name: Install dependencies
58-
run: |
59-
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
60-
-r tests/requirements.txt
61-
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
62-
-r docs/requirements.txt
63-
6457
- name: Install TorchOpt
6558
run: |
66-
python -m pip install -e .
59+
python -m pip install -vvv -e '.[lint]'
6760
6861
- name: pre-commit
6962
run: |
@@ -97,6 +90,11 @@ jobs:
9790
run: |
9891
make mypy
9992
93+
- name: Install dependencies
94+
run: |
95+
python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
96+
-r docs/requirements.txt
97+
10098
- name: docstyle
10199
run: |
102100
make docstyle

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jobs:
7272
7373
- name: Install TorchOpt
7474
run: |
75-
python -m pip install -e .
75+
python -m pip install -vvv -e .
7676
7777
- name: Test with pytest
7878
run: |

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/TorchOpt/pull/45).
1616
- Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42).
1717

18+
### Changed
19+
20+
- Use correct Python Ctype for pybind11 function prototype [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/TorchOpt/pull/52).
21+
1822
------
1923

2024
## [0.4.2] - 2022-07-26

conda-recipe.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ dependencies:
7676
- mypy
7777
- flake8
7878
- flake8-bugbear
79-
- doc8
79+
- doc8 < 1.0.0a0
8080
- pydocstyle
8181
- clang-format
8282
- clang-tools # clang-tidy

include/adam_op/adam_op.h

+15-12
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,35 @@
2323
namespace torchopt {
2424
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
2525
const torch::Tensor& mu,
26-
const torch::Tensor& nu, const float b1,
27-
const float b2, const float eps,
28-
const float eps_root, const int count);
26+
const torch::Tensor& nu, const pyfloat_t b1,
27+
const pyfloat_t b2, const pyfloat_t eps,
28+
const pyfloat_t eps_root,
29+
const pyuint_t count);
2930

3031
torch::Tensor adamForwardMu(const torch::Tensor& updates,
31-
const torch::Tensor& mu, const float b1);
32+
const torch::Tensor& mu, const pyfloat_t b1);
3233

3334
torch::Tensor adamForwardNu(const torch::Tensor& updates,
34-
const torch::Tensor& nu, const float b2);
35+
const torch::Tensor& nu, const pyfloat_t b2);
3536

3637
torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
37-
const torch::Tensor& new_nu, const float b1,
38-
const float b2, const float eps,
39-
const float eps_root, const int count);
38+
const torch::Tensor& new_nu,
39+
const pyfloat_t b1, const pyfloat_t b2,
40+
const pyfloat_t eps, const pyfloat_t eps_root,
41+
const pyuint_t count);
4042

4143
TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
4244
const torch::Tensor& updates,
43-
const torch::Tensor& mu, const float b1);
45+
const torch::Tensor& mu, const pyfloat_t b1);
4446

4547
TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
4648
const torch::Tensor& updates,
47-
const torch::Tensor& nu, const float b2);
49+
const torch::Tensor& nu, const pyfloat_t b2);
4850

4951
TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
5052
const torch::Tensor& updates,
5153
const torch::Tensor& new_mu,
52-
const torch::Tensor& new_nu, const float b1,
53-
const float b2, const int count);
54+
const torch::Tensor& new_nu,
55+
const pyfloat_t b1, const pyfloat_t b2,
56+
const pyuint_t count);
5457
} // namespace torchopt

include/adam_op/adam_op_impl_cpu.h

+15-14
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,36 @@
2121
#include "include/common.h"
2222

2323
namespace torchopt {
24-
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
25-
const torch::Tensor& mu,
26-
const torch::Tensor& nu, const float b1,
27-
const float b2, const float eps,
28-
const float eps_root, const int count);
24+
TensorArray<3> adamForwardInplaceCPU(
25+
const torch::Tensor& updates, const torch::Tensor& mu,
26+
const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2,
27+
const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
2928

3029
torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
31-
const torch::Tensor& mu, const float b1);
30+
const torch::Tensor& mu, const pyfloat_t b1);
3231

3332
torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
34-
const torch::Tensor& nu, const float b2);
33+
const torch::Tensor& nu, const pyfloat_t b2);
3534

3635
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
37-
const torch::Tensor& new_nu, const float b1,
38-
const float b2, const float eps,
39-
const float eps_root, const int count);
36+
const torch::Tensor& new_nu,
37+
const pyfloat_t b1, const pyfloat_t b2,
38+
const pyfloat_t eps,
39+
const pyfloat_t eps_root,
40+
const pyuint_t count);
4041

4142
TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
4243
const torch::Tensor& updates,
43-
const torch::Tensor& mu, const float b1);
44+
const torch::Tensor& mu, const pyfloat_t b1);
4445

4546
TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
4647
const torch::Tensor& updates,
47-
const torch::Tensor& nu, const float b2);
48+
const torch::Tensor& nu, const pyfloat_t b2);
4849

4950
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
5051
const torch::Tensor& updates,
5152
const torch::Tensor& new_mu,
5253
const torch::Tensor& new_nu,
53-
const float b1, const float b2,
54-
const int count);
54+
const pyfloat_t b1, const pyfloat_t b2,
55+
const pyuint_t count);
5556
} // namespace torchopt

include/adam_op/adam_op_impl_cuda.cuh

+14-14
Original file line numberDiff line numberDiff line change
@@ -21,36 +21,36 @@
2121
#include "include/common.h"
2222

2323
namespace torchopt {
24-
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
25-
const torch::Tensor &mu,
26-
const torch::Tensor &nu, const float b1,
27-
const float b2, const float eps,
28-
const float eps_root, const int count);
24+
TensorArray<3> adamForwardInplaceCUDA(
25+
const torch::Tensor &updates, const torch::Tensor &mu,
26+
const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2,
27+
const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
2928

3029
torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
31-
const torch::Tensor &mu, const float b1);
30+
const torch::Tensor &mu, const pyfloat_t b1);
3231

3332
torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
34-
const torch::Tensor &nu, const float b2);
33+
const torch::Tensor &nu, const pyfloat_t b2);
3534

3635
torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
3736
const torch::Tensor &new_nu,
38-
const float b1, const float b2,
39-
const float eps, const float eps_root,
40-
const int count);
37+
const pyfloat_t b1, const pyfloat_t b2,
38+
const pyfloat_t eps,
39+
const pyfloat_t eps_root,
40+
const pyuint_t count);
4141

4242
TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
4343
const torch::Tensor &updates,
44-
const torch::Tensor &mu, const float b1);
44+
const torch::Tensor &mu, const pyfloat_t b1);
4545

4646
TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
4747
const torch::Tensor &updates,
48-
const torch::Tensor &nu, const float b2);
48+
const torch::Tensor &nu, const pyfloat_t b2);
4949

5050
TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
5151
const torch::Tensor &updates,
5252
const torch::Tensor &new_mu,
5353
const torch::Tensor &new_nu,
54-
const float b1, const float b2,
55-
const int count);
54+
const pyfloat_t b1, const pyfloat_t b2,
55+
const pyuint_t count);
5656
} // namespace torchopt

include/common.h

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include <torch/extension.h>
1818

1919
#include <array>
20+
#include <cstddef>
21+
22+
using pyfloat_t = double;
23+
using pyuint_t = std::size_t;
2024

2125
namespace torchopt {
2226
template <size_t _Nm>

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ lint = [
6666
"mypy",
6767
"flake8",
6868
"flake8-bugbear",
69-
"doc8",
69+
"doc8 < 1.0.0a0",
7070
"pydocstyle",
7171
"pyenchant",
7272
"cpplint",

src/adam_op/adam_op.cpp

+15-12
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
namespace torchopt {
2727
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
2828
const torch::Tensor& mu,
29-
const torch::Tensor& nu, const float b1,
30-
const float b2, const float eps,
31-
const float eps_root, const int count) {
29+
const torch::Tensor& nu, const pyfloat_t b1,
30+
const pyfloat_t b2, const pyfloat_t eps,
31+
const pyfloat_t eps_root,
32+
const pyuint_t count) {
3233
#if defined(__CUDACC__)
3334
if (updates.device().is_cuda()) {
3435
return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root,
@@ -42,7 +43,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
4243
}
4344
}
4445
torch::Tensor adamForwardMu(const torch::Tensor& updates,
45-
const torch::Tensor& mu, const float b1) {
46+
const torch::Tensor& mu, const pyfloat_t b1) {
4647
#if defined(__CUDACC__)
4748
if (updates.device().is_cuda()) {
4849
return adamForwardMuCUDA(updates, mu, b1);
@@ -56,7 +57,7 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates,
5657
}
5758

5859
torch::Tensor adamForwardNu(const torch::Tensor& updates,
59-
const torch::Tensor& nu, const float b2) {
60+
const torch::Tensor& nu, const pyfloat_t b2) {
6061
#if defined(__CUDACC__)
6162
if (updates.device().is_cuda()) {
6263
return adamForwardNuCUDA(updates, nu, b2);
@@ -70,9 +71,10 @@ torch::Tensor adamForwardNu(const torch::Tensor& updates,
7071
}
7172

7273
torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
73-
const torch::Tensor& new_nu, const float b1,
74-
const float b2, const float eps,
75-
const float eps_root, const int count) {
74+
const torch::Tensor& new_nu,
75+
const pyfloat_t b1, const pyfloat_t b2,
76+
const pyfloat_t eps, const pyfloat_t eps_root,
77+
const pyuint_t count) {
7678
#if defined(__CUDACC__)
7779
if (new_mu.device().is_cuda()) {
7880
return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count);
@@ -87,7 +89,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
8789

8890
TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
8991
const torch::Tensor& updates,
90-
const torch::Tensor& mu, const float b1) {
92+
const torch::Tensor& mu, const pyfloat_t b1) {
9193
#if defined(__CUDACC__)
9294
if (dmu.device().is_cuda()) {
9395
return adamBackwardMuCUDA(dmu, updates, mu, b1);
@@ -102,7 +104,7 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
102104

103105
TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
104106
const torch::Tensor& updates,
105-
const torch::Tensor& nu, const float b2) {
107+
const torch::Tensor& nu, const pyfloat_t b2) {
106108
#if defined(__CUDACC__)
107109
if (dnu.device().is_cuda()) {
108110
return adamBackwardNuCUDA(dnu, updates, nu, b2);
@@ -118,8 +120,9 @@ TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
118120
TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
119121
const torch::Tensor& updates,
120122
const torch::Tensor& new_mu,
121-
const torch::Tensor& new_nu, const float b1,
122-
const float b2, const int count) {
123+
const torch::Tensor& new_nu,
124+
const pyfloat_t b1, const pyfloat_t b2,
125+
const pyuint_t count) {
123126
#if defined(__CUDACC__)
124127
if (dupdates.device().is_cuda()) {
125128
return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2,

0 commit comments

Comments
 (0)