We design a bilevel-optimization updating scheme, which can be easily extended to realize various differentiable optimization processes.

<div align="center">
  <img src="image/diffmode.png" width="90%" />
</div>

As shown above, the scheme contains an outer level with parameters $\phi$ that can be learned end-to-end through the inner-level parameter solution $\theta^{\prime}(\phi)$ by using the best-response derivatives $\partial \theta^{\prime}(\phi) / \partial \phi$.
TorchOpt supports three differentiation modes.
The key component of this scheme is the computation of the best-response (BR) Jacobian.
From the BR-based perspective, existing gradient methods can be categorized into three groups: explicit gradient over unrolled optimization, implicit differentiation, and zero-order gradient differentiation.
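
For intuition, writing the outer objective as $\mathcal{L}^{\text{outer}}$ (a symbol introduced here only for illustration), the meta-gradient decomposes by the chain rule as

$$
\frac{\mathrm{d} \mathcal{L}^{\text{outer}}\bigl(\theta^{\prime}(\phi), \phi\bigr)}{\mathrm{d} \phi}
= \frac{\partial \mathcal{L}^{\text{outer}}}{\partial \phi}
+ \frac{\partial \mathcal{L}^{\text{outer}}}{\partial \theta^{\prime}} \, \frac{\partial \theta^{\prime}(\phi)}{\partial \phi} ,
$$

so the three modes differ only in how the BR Jacobian $\partial \theta^{\prime}(\phi) / \partial \phi$ is obtained or approximated.
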
TorchOpt also provides an OOP API compatible with the PyTorch programming style.
Refer to the example and the tutorial notebooks [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb) and [Stop Gradient](tutorials/4_Stop_Gradient.ipynb) for more guidance.

```python
# Define meta and inner parameters
...
loss.backward()
```
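
Since the snippet above is truncated in this excerpt, here is a minimal self-contained sketch of the same OOP meta-optimizer pattern. It uses `torchopt.MetaSGD`; the toy model, `meta_param`, and data are hypothetical and only for illustration:

```python
import torch
import torch.nn as nn
import torchopt

model = nn.Linear(4, 1)                            # inner-level parameters (theta)
meta_param = nn.Parameter(torch.zeros(1))          # outer-level parameter (phi)
inner_optimizer = torchopt.MetaSGD(model, lr=0.1)  # differentiable inner optimizer

x, y = torch.randn(8, 4), torch.randn(8, 1)

for _ in range(2):
    # Inner update: each step stays in the autograd graph.
    inner_loss = ((model(x) + meta_param - y) ** 2).mean()
    inner_optimizer.step(inner_loss)

outer_loss = ((model(x) - y) ** 2).mean()
outer_loss.backward()        # gradient flows back through the inner updates to phi
print(meta_param.grad)
```

Here the backward pass differentiates through the unrolled `MetaSGD` updates, which corresponds to the explicit gradient over unrolled optimization mentioned above.
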
### Implicit Gradient (IG)
By treating the solution $\theta^{\prime}$ as an implicit function of $\phi$, the idea of IG is to directly obtain the analytical best-response derivatives $\partial \theta^{\prime} (\phi) / \partial \phi$ via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem).
This is suitable for algorithms where the inner-level optimal solution is achieved, i.e., ${\left. \frac{\partial F (\theta, \phi)}{\partial \theta} \right\rvert}_{\theta=\theta^{\prime}} = 0$, or satisfies some stationary condition $F (\theta^{\prime}, \phi) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377).
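
As a reminder of the underlying identity (standard implicit-function-theorem algebra, assuming the inner Hessian is invertible), differentiating the optimality condition ${\left. \partial F / \partial \theta \right\rvert}_{\theta = \theta^{\prime}(\phi)} = 0$ with respect to $\phi$ gives

$$
\frac{\partial \theta^{\prime}(\phi)}{\partial \phi}
= - \left. \left( \frac{\partial^{2} F}{\partial \theta \, \partial \theta^{\top}} \right)^{-1}
\frac{\partial^{2} F}{\partial \theta \, \partial \phi^{\top}} \right\rvert_{\theta = \theta^{\prime}(\phi)} ,
$$

and conjugate-gradient or Neumann-series approaches approximate the inverse-Hessian-vector products in this expression rather than forming the inverse explicitly.
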
TorchOpt offers both functional and OOP APIs that support [conjugate gradient-based](https://arxiv.org/abs/1909.04630) and [Neumann series-based](https://arxiv.org/abs/1911.02590) IG methods.
Refer to the example [iMAML](https://github.com/waterhorse1/torchopt/tree/readme/examples/iMAML) and the notebook [Implicit Gradient](tutorials/5_Implicit_Differentiation.ipynb) for more guidance.

We take the optimizer as a whole instead of separating it into several basic operators (e.g., `sqrt` and `div`).
Therefore, by manually writing the forward and backward functions, we can perform the symbolic reduction.
In addition, we can store some intermediate data that can be reused during backpropagation.
We write the accelerated functions in C++ OpenMP and CUDA, bind them with [`pybind11`](https://github.com/pybind/pybind11) so that they can be called from Python, and then define the forward and backward behavior using `torch.autograd.Function`.
Users can enable them by simply setting the `use_accelerated_op` flag to `True`.
Refer to the corresponding sections in the tutorials [Functional Optimizer](tutorials/1_Functional_Optimizer.ipynb) and [Meta-Optimizer](tutorials/3_Meta_Optimizer.ipynb).
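
The sketch below illustrates both ideas in plain Python: a toy `torch.autograd.Function` with a hand-written backward that saves a reusable intermediate (TorchOpt's real accelerated operators do this in C++/CUDA), and the `use_accelerated_op` flag, which we assume is exposed as a keyword argument as in the tutorials:

```python
import torch
import torch.nn as nn
import torchopt


class FusedRSqrtScale(torch.autograd.Function):
    """Toy fused op y = x / sqrt(v + eps), written with a manual backward."""

    @staticmethod
    def forward(ctx, x, v, eps=1e-8):
        inv_sqrt = (v + eps).rsqrt()           # intermediate reused in backward
        ctx.save_for_backward(x, inv_sqrt)
        return x * inv_sqrt

    @staticmethod
    def backward(ctx, grad_out):
        x, inv_sqrt = ctx.saved_tensors
        grad_x = grad_out * inv_sqrt
        grad_v = grad_out * x * (-0.5) * inv_sqrt.pow(3)
        return grad_x, grad_v, None            # no gradient for eps


# Enabling TorchOpt's real fused operators (keyword argument assumed from the tutorials):
meta_optimizer = torchopt.MetaAdam(nn.Linear(4, 1), lr=1e-3, use_accelerated_op=True)
```
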
## Visualization
The complex gradient flows that arise in meta-learning make it challenging to manage the gradient flow and verify its correctness.
TorchOpt provides a visualization tool that draws variable (e.g., network parameters or meta-parameters) names on the gradient graph for easier analysis.
The visualization tool is modified from [`torchviz`](https://github.com/szagoruyko/pytorchviz).
Refer to the example [visualization code](examples/visualize.py) and the tutorial notebook [Visualization](tutorials/2_Visualization.ipynb) for more details.
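
As a quick sketch of how such a graph can be produced (assuming the `torchviz`-style `make_dot(variable, params=...)` interface carries over to `torchopt.visual`, and using a hypothetical toy model):

```python
import torch
import torch.nn as nn
import torchopt

model = nn.Linear(4, 1)
meta_param = nn.Parameter(torch.ones(1))
x, y = torch.randn(8, 4), torch.randn(8, 1)

loss = ((model(x) * meta_param - y) ** 2).mean()

# Attach human-readable names to the nodes of the gradient graph.
dot = torchopt.visual.make_dot(
    loss,
    params={'meta_param': meta_param, 'loss': loss, **dict(model.named_parameters())},
)
dot.render('gradient_graph', format='svg')  # requires Graphviz to be installed
```
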
In the [`examples`](examples) directory, we offer several functional optimizer examples and light-weight meta-learning examples built with TorchOpt.

## Citing TorchOpt
If you find TorchOpt useful, please cite it in your publications.
## The Team
TorchOpt is a work by [Jie Ren](https://github.com/JieRen98), [Xidong Feng](https://github.com/waterhorse1), [Bo Liu](https://github.com/Benjamin-eecs), [Xuehai Pan](https://github.com/XuehaiPan), [Luo Mai](https://luomai.github.io), and [Yaodong Yang](https://www.yangyaodong.com).
## License
TorchOpt is released under the Apache License, Version 2.0.