Skip to content

Commit a760c83

Browse files
waterhorse1XuehaiPanBenjamin-eecs
authored
docs(tutorials): update tutorials (#120)
Co-authored-by: Xuehai Pan <[email protected]> Co-authored-by: Benjamin-eecs <[email protected]>
1 parent b155b15 commit a760c83

13 files changed

+757
-1959
lines changed

README.md

+38-25
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,26 @@
66
<img src="https://github.com/metaopt/torchopt/raw/HEAD/image/logo-large.png" width="75%" />
77
</div>
88

9-
![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen.svg)
10-
[![PyPI](https://img.shields.io/pypi/v/torchopt?logo=pypi)](https://pypi.org/project/torchopt)
11-
![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/torchopt/Tests?label=tests&logo=github)
12-
[![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs)](https://torchopt.readthedocs.io)
13-
[![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads)](https://pepy.tech/project/torchopt)
14-
[![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github)](https://github.com/metaopt/torchopt/stargazers)
15-
[![License](https://img.shields.io/github/license/metaopt/torchopt?label=license)](#license)
9+
<div align="center">
10+
11+
<a>![Python 3.7+](https://img.shields.io/badge/Python-3.7%2B-brightgreen.svg)</a>
12+
<a href="https://pypi.org/project/torchopt">![PyPI](https://img.shields.io/pypi/v/torchopt?logo=pypi)</a>
13+
<a href="https://github.com/metaopt/torchopt/tree/HEAD/tests">![GitHub Workflow Status](https://img.shields.io/github/workflow/status/metaopt/torchopt/Tests?label=tests&logo=github)</a>
14+
<a href="https://torchopt.readthedocs.io">![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs)</a>
15+
<a href="https://pepy.tech/project/torchopt">![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads)</a>
16+
<a href="https://github.com/metaopt/torchopt/stargazers">![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github)</a>
17+
<a href="https://github.com/metaopt/torchopt/blob/HEAD/LICENSE">![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCIgd2lkdGg9IjI0IiBoZWlnaHQ9IjI0IiBmaWxsPSIjZmZmZmZmIj48cGF0aCBmaWxsLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xMi43NSAyLjc1YS43NS43NSAwIDAwLTEuNSAwVjQuNUg5LjI3NmExLjc1IDEuNzUgMCAwMC0uOTg1LjMwM0w2LjU5NiA1Ljk1N0EuMjUuMjUgMCAwMTYuNDU1IDZIMi4zNTNhLjc1Ljc1IDAgMTAwIDEuNUgzLjkzTC41NjMgMTUuMThhLjc2Mi43NjIgMCAwMC4yMS44OGMuMDguMDY0LjE2MS4xMjUuMzA5LjIyMS4xODYuMTIxLjQ1Mi4yNzguNzkyLjQzMy42OC4zMTEgMS42NjIuNjIgMi44NzYuNjJhNi45MTkgNi45MTkgMCAwMDIuODc2LS42MmMuMzQtLjE1NS42MDYtLjMxMi43OTItLjQzMy4xNS0uMDk3LjIzLS4xNTguMzEtLjIyM2EuNzUuNzUgMCAwMC4yMDktLjg3OEw1LjU2OSA3LjVoLjg4NmMuMzUxIDAgLjY5NC0uMTA2Ljk4NC0uMzAzbDEuNjk2LTEuMTU0QS4yNS4yNSAwIDAxOS4yNzUgNmgxLjk3NXYxNC41SDYuNzYzYS43NS43NSAwIDAwMCAxLjVoMTAuNDc0YS43NS43NSAwIDAwMC0xLjVIMTIuNzVWNmgxLjk3NGMuMDUgMCAuMS4wMTUuMTQuMDQzbDEuNjk3IDEuMTU0Yy4yOS4xOTcuNjMzLjMwMy45ODQuMzAzaC44ODZsLTMuMzY4IDcuNjhhLjc1Ljc1IDAgMDAuMjMuODk2Yy4wMTIuMDA5IDAgMCAuMDAyIDBhMy4xNTQgMy4xNTQgMCAwMC4zMS4yMDZjLjE4NS4xMTIuNDUuMjU2Ljc5LjRhNy4zNDMgNy4zNDMgMCAwMDIuODU1LjU2OCA3LjM0MyA3LjM0MyAwIDAwMi44NTYtLjU2OWMuMzM4LS4xNDMuNjA0LS4yODcuNzktLjM5OWEzLjUgMy41IDAgMDAuMzEtLjIwNi43NS43NSAwIDAwLjIzLS44OTZMMjAuMDcgNy41aDEuNTc4YS43NS43NSAwIDAwMC0xLjVoLTQuMTAyYS4yNS4yNSAwIDAxLS4xNC0uMDQzbC0xLjY5Ny0xLjE1NGExLjc1IDEuNzUgMCAwMC0uOTg0LS4zMDNIMTIuNzVWMi43NXpNMi4xOTMgMTUuMTk4YTUuNDE4IDUuNDE4IDAgMDAyLjU1Ny42MzUgNS40MTggNS40MTggMCAwMDIuNTU3LS42MzVMNC43NSA5LjM2OGwtMi41NTcgNS44M3ptMTQuNTEtLjAyNGMuMDgyLjA0LjE3NC4wODMuMjc1LjEyNi41My4yMjMgMS4zMDUuNDUgMi4yNzIuNDVhNS44NDYgNS44NDYgMCAwMDIuNTQ3LS41NzZMMTkuMjUgOS4zNjdsLTIuNTQ3IDUuODA3eiI+PC9wYXRoPjwvc3ZnPgo=)</a>
18+
19+
</div>
20+
21+
<p align="center">
22+
<a href="https://github.com/metaopt/torchopt#installation">Installation</a> |
23+
<a href="https://torchopt.readthedocs.io">Documentation</a> |
24+
<a href="https://github.com/metaopt/torchopt/tree/HEAD/tutorials">Tutorials</a> |
25+
<a href="https://github.com/metaopt/torchopt/tree/HEAD/examples">Examples</a> |
26+
<a href="https://arxiv.org/abs/2211.06934">Paper</a> |
27+
<a href="https://github.com/metaopt/torchopt#citing-torchopt">Citation</a>
28+
</p>
1629

1730
**TorchOpt** is an efficient library for differentiable optimization built upon [PyTorch](https://pytorch.org).
1831
TorchOpt is:
@@ -44,8 +57,8 @@ The README is organized as follows:
4457
- [Examples](#examples)
4558
- [Installation](#installation)
4659
- [Changelog](#changelog)
47-
- [The Team](#the-team)
4860
- [Citing TorchOpt](#citing-torchopt)
61+
- [The Team](#the-team)
4962
- [License](#license)
5063

5164
--------------------------------------------------------------------------------
@@ -136,7 +149,7 @@ We design a bilevel-optimization updating scheme, which can be easily extended t
136149
<img src="image/diffmode.png" width="90%" />
137150
</div>
138151

139-
As shown above, the scheme contains an outer level that has parameters $\phi$ that can be learned end-to-end through the inner level parameters solution $\theta^{\star}(\phi)$ by using the best-response derivatives $\partial \theta^{\star}(\phi) / \partial \phi$.
152+
As shown above, the scheme contains an outer level that has parameters $\phi$ that can be learned end-to-end through the inner level parameters solution $\theta^{\prime}(\phi)$ by using the best-response derivatives $\partial \theta^{\prime}(\phi) / \partial \phi$.
140153
TorchOpt supports three differentiation modes.
141154
It can be seen that the key component of this algorithm is to calculate the best-response (BR) Jacobian.
142155
From the BR-based perspective, existing gradient methods can be categorized into three groups: explicit gradient over unrolled optimization, implicit differentiation, and zero-order gradient differentiation.
@@ -176,7 +189,7 @@ meta_grads = torch.autograd.grad(loss, meta_params)
176189
#### OOP API <!-- omit in toc -->
177190

178191
TorchOpt also provides OOP API compatible with PyTorch programming style.
179-
Refer to the example and the tutorial notebook [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidances.
192+
Refer to the example and the tutorial notebook [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb), [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidances.
180193

181194
```python
182195
# Define meta and inner parameters
@@ -196,8 +209,8 @@ loss.backward()
196209

197210
### Implicit Gradient (IG)
198211

199-
By treating the solution $\theta^{\star}$ as an implicit function of $\phi$, the idea of IG is to directly get analytical best-response derivatives $\partial \theta^{\star} (\phi) / \partial \phi$ by [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem).
200-
This is suitable for algorithms when the inner-level optimal solution is achieved ${\left. \frac{\partial F (\theta, \phi)}{\partial \theta} \right\rvert}_{\theta^{\star}} = 0$ or reaches some stationary conditions $F (\theta^{\star}, \phi) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377).
212+
By treating the solution $\theta^{\prime}$ as an implicit function of $\phi$, the idea of IG is to directly get analytical best-response derivatives $\partial \theta^{\prime} (\phi) / \partial \phi$ by [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem).
213+
This is suitable for algorithms when the inner-level optimal solution is achieved ${\left. \frac{\partial F (\theta, \phi)}{\partial \theta} \right\rvert}_{\theta=\theta^{\prime}} = 0$ or reaches some stationary conditions $F (\theta^{\prime}, \phi) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377).
201214
TorchOpt offers both functional and OOP APIs for supporting both [conjugate gradient-based](https://arxiv.org/abs/1909.04630) and [Neumann series-based](https://arxiv.org/abs/1911.02590) IG methods.
202215
Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme/examples/iMAML) and the notebook [Implicit Gradient](tutorials/5_Implicit_Differentiation.ipynb) for more guidances.
203216

@@ -218,7 +231,7 @@ def solve(params, meta_params, data):
218231
# Forward optimization process for params
219232
return output
220233

221-
# Define params, meta params and get data
234+
# Define params, meta_params and get data
222235
params, meta_prams, data = ..., ..., ...
223236
optimal_params = solve(params, meta_params, data)
224237
loss = outer_loss(optimal_params)
@@ -262,10 +275,10 @@ class InnerNet(ImplicitMetaGradientModule, linear_solver):
262275
meta_params, data = ..., ...
263276
inner_net = InnerNet(meta_params)
264277

265-
# Solve for inner-loop process related with the meta parameters
278+
# Solve for inner-loop process related with the meta-parameters
266279
optimal_inner_net = inner_net.solve(data)
267280

268-
# Get outer loss and solve for meta gradient
281+
# Get outer loss and solve for meta-gradient
269282
loss = outer_loss(optimal_inner_net)
270283
meta_grads = torch.autograd.grad(loss, meta_params)
271284
```
@@ -301,10 +314,10 @@ def forward(params, batch, labels):
301314

302315
We take the optimizer as a whole instead of separating it into several basic operators (e.g., `sqrt` and `div`).
303316
Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction.
304-
In addition, we can store some intermediate data that can be reused during the back-propagation.
317+
In addition, we can store some intermediate data that can be reused during the backpropagation.
305318
We write the accelerated functions in C++ OpenMP and CUDA, bind them by [`pybind11`](https://github.com/pybind/pybind11) to allow they can be called by Python, and then we define the forward and backward behavior using `torch.autograd.Function`.
306319
Users can use by simply setting the `use_accelerated_op` flag as `True`.
307-
Refer to the corresponding sections in tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta Optimizer](tutorials/3_Meta_Optimizer.ipynb)
320+
Refer to the corresponding sections in tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb)
308321

309322
```python
310323
optimizer = torchopt.MetaAdam(model, lr, use_accelerated_op=True)
@@ -329,7 +342,7 @@ For more guidance and comparison results, please refer to our open source projec
329342
## Visualization
330343

331344
Complex gradient flow in meta-learning brings in a great challenge for managing the gradient flow and verifying the correctness of it.
332-
TorchOpt provides a visualization tool that draw variable (e.g., network parameters or meta parameters) names on the gradient graph for better analyzing.
345+
TorchOpt provides a visualization tool that draw variable (e.g., network parameters or meta-parameters) names on the gradient graph for better analyzing.
333346
The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz).
334347
Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details.
335348

@@ -346,10 +359,10 @@ Compared with [`torchviz`](https://github.com/szagoruyko/pytorchviz), TorchOpt f
346359

347360
In the [`examples`](examples) directory, we offer several examples of functional optimizer and light-weight meta-learning examples with TorchOpt.
348361

349-
- [Model Agnostic Meta Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017)
362+
- [Model-Agnostic Meta-Learning (MAML) - Supervised Learning](https://arxiv.org/abs/1703.03400) (ICML 2017)
350363
- [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050) (ICML 2018)
351-
- [Model Agnostic Meta Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017)
352-
- [Meta Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018)
364+
- [Model-Agnostic Meta-Learning (MAML) - Reinforcement Learning](https://arxiv.org/abs/1703.03400) (ICML 2017)
365+
- [Meta-Gradient Reinforcement Learning (MGRL)](https://arxiv.org/abs/1805.09801) (NeurIPS 2018)
353366
- [Learning through opponent learning process (LOLA)](https://arxiv.org/abs/1709.04326) (AAMAS 2018)
354367
- [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630) (NeurIPS 2019)
355368

@@ -408,10 +421,6 @@ See [CHANGELOG.md](CHANGELOG.md).
408421

409422
--------------------------------------------------------------------------------
410423

411-
## The Team
412-
413-
TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io) and [Yaodong Yang](https://www.yangyaodong.com).
414-
415424
## Citing TorchOpt
416425

417426
If you find TorchOpt useful, please cite it in your publications.
@@ -425,6 +434,10 @@ If you find TorchOpt useful, please cite it in your publications.
425434
}
426435
```
427436

437+
## The Team
438+
439+
TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com).
440+
428441
## License
429442

430443
TorchOpt is released under the Apache License, Version 2.0.

docs/source/examples/MAML.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Model-Agnostic Meta-Learning
22
============================
33

4-
Meta reinforcement learning has achieved significant successes in various applications.
4+
Meta-reinforcement learning has achieved significant successes in various applications.
55
**Model-Agnostic Meta-Learning** (MAML) :cite:`MAML` is the pioneer one.
66
In this tutorial, we will show how to train MAML on few-shot Omniglot classification with TorchOpt step by step.
77
The full script is at :gitcode:`examples/few-shot/maml_omniglot.py`.

docs/source/spelling_wordlist.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Pan
2626
Yao
2727
Fu
2828
Jupyter
29-
Colaboratory
29+
Colab
3030
Omniglot
3131
differentiable
3232
Dataset
@@ -97,6 +97,10 @@ KKT
9797
num
9898
posinf
9999
neginf
100+
backpropagated
101+
backpropagating
102+
backpropagation
103+
backprop
100104
fmt
101105
pragma
102106
broadcasted
+7-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
Get Started with Jupyter Notebook
22
=================================
33

4-
In this tutorial, we will use Google Colaboratory to show you the most basic usages of TorchOpt.
4+
In this tutorial, we will use Google Colab notebooks to show you the most basic usages of TorchOpt.
55

6-
- 1: `Functional Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb>`_
7-
- 2: `Visualization <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb>`_
8-
- 3: `Meta Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb>`_
9-
- 4: `Stop Gradient <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb>`_
10-
- 5: `Implicit Differentiation <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb>`_
6+
- 1: `Functional Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb>`_
7+
- 2: `Visualization <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/2_Visualization.ipynb>`_
8+
- 3: `Meta-Optimizer <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/3_Meta_Optimizer.ipynb>`_
9+
- 4: `Stop Gradient <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/4_Stop_Gradient.ipynb>`_
10+
- 5: `Implicit Differentiation <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb>`_
11+
- 6: `Zero-order Differentiation <https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation>`_

tests/test_zero_order.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2022 MetaOPT Team. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import functorch
17+
import torch
18+
import torch.nn as nn
19+
import torch.types
20+
21+
import helpers
22+
import torchopt
23+
24+
25+
BATCH_SIZE = 8
26+
NUM_UPDATES = 5
27+
28+
29+
class FcNet(nn.Module):
30+
def __init__(self, dim, out):
31+
super().__init__()
32+
self.fc = nn.Linear(in_features=dim, out_features=out, bias=True)
33+
nn.init.ones_(self.fc.weight)
34+
nn.init.zeros_(self.fc.bias)
35+
36+
def forward(self, x):
37+
return self.fc(x)
38+
39+
40+
@helpers.parametrize(
41+
dtype=[torch.float64, torch.float32],
42+
lr=[1e-2, 1e-3],
43+
method=['naive', 'forward', 'antithetic'],
44+
sigma=[0.01, 0.1, 1],
45+
)
46+
def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None:
47+
helpers.seed_everything(42)
48+
input_size = 32
49+
output_size = 1
50+
batch_size = BATCH_SIZE
51+
coef = 0.1
52+
num_iterations = NUM_UPDATES
53+
num_samples = 500
54+
55+
model = FcNet(input_size, output_size)
56+
57+
fmodel, params = functorch.make_functional(model)
58+
x = torch.randn(batch_size, input_size) * coef
59+
y = torch.randn(input_size) * coef
60+
distribution = torch.distributions.Normal(loc=0, scale=1)
61+
62+
@torchopt.diff.zero_order.zero_order(
63+
distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples
64+
)
65+
def forward_process(params, fn, x, y):
66+
y_pred = fn(params, x)
67+
loss = torch.mean((y - y_pred) ** 2)
68+
return loss
69+
70+
optimizer = torchopt.adam(lr=lr)
71+
opt_state = optimizer.init(params)
72+
73+
for i in range(num_iterations):
74+
opt_state = optimizer.init(params) # init optimizer
75+
loss = forward_process(params, fmodel, x, y) # compute loss
76+
77+
grads = torch.autograd.grad(loss, params) # compute gradients
78+
updates, opt_state = optimizer.update(grads, opt_state) # get updates
79+
params = torchopt.apply_updates(params, updates) # update network parameters

torchopt/transform/scale_by_adam.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def scale_by_adam(
9090
Term added to the denominator to improve numerical stability.
9191
eps_root: (default: :const:`0.0`)
9292
Term added to the denominator inside the square-root to improve
93-
numerical stability when back-propagating gradients through the rescaling.
93+
numerical stability when backpropagating gradients through the rescaling.
9494
moment_requires_grad: (default: :data:`False`)
9595
If :data:`True`, states will be created with flag `requires_grad = True`.
9696
@@ -214,7 +214,7 @@ def scale_by_accelerated_adam(
214214
Term added to the denominator to improve numerical stability.
215215
eps_root: (default: :const:`0.0`)
216216
Term added to the denominator inside the square-root to improve
217-
numerical stability when back-propagating gradients through the rescaling.
217+
numerical stability when backpropagating gradients through the rescaling.
218218
moment_requires_grad: (default: :data:`False`)
219219
If :data:`True`, states will be created with flag `requires_grad = True`.
220220

0 commit comments

Comments
 (0)