Skip to content

Commit f0712b6

Browse files
authored
Updated: Continuous Action API (#8)
- all learners for all agents now have fit and validation capabilities - cleaned up a ton of stuff - TRPO and PPO still don't work right. I'll probably try tackling this in another PR
1 parent b66b8f6 commit f0712b6

24 files changed

+507
-1120
lines changed

.github/workflows/fastrl-docker.yml

+16-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77
branches:
88
- main
99
- update_nbdev_docs
10-
- feature/ppo
10+
- refactor/bug-fix-api-update-and-stablize
1111

1212
jobs:
1313
build:
@@ -54,14 +54,27 @@ jobs:
5454
env:
5555
BUILD_TYPE: ${{ matrix.build_type }}
5656

57+
- name: Cache Docker layers
58+
if: always()
59+
uses: actions/cache@v2
60+
with:
61+
path: /tmp/.buildx-cache
62+
key: ${{ runner.os }}-buildx-${{ github.sha }}
63+
restore-keys: |
64+
${{ runner.os }}-buildx-
65+
66+
- name: Set up Docker Buildx
67+
uses: docker/setup-buildx-action@v3
68+
5769
- name: build and tag container
5870
run: |
5971
export DOCKER_BUILDKIT=1
6072
# We need to clear the previous docker images
6173
# docker system prune -fa
62-
docker pull ${IMAGE_NAME}:latest || true
74+
# docker pull ${IMAGE_NAME}:latest || true
6375
# docker build --build-arg BUILD=${BUILD_TYPE} \
64-
docker build --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \
76+
docker buildx create --use
77+
docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \
6578
-t ${IMAGE_NAME}:latest \
6679
-t ${IMAGE_NAME}:${VERSION} \
6780
-t ${IMAGE_NAME}:$(date +%F) \

fastrl/_modidx.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,9 @@
645645
'fastrl.loggers.vscode_visualizers.SimpleVSCodeVideoPlayer.show': ( '05_Logging/loggers.vscode_visualizers.html#simplevscodevideoplayer.show',
646646
'fastrl/loggers/vscode_visualizers.py'),
647647
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe',
648-
'fastrl/loggers/vscode_visualizers.py')},
648+
'fastrl/loggers/vscode_visualizers.py'),
649+
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe.__new__': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe.__new__',
650+
'fastrl/loggers/vscode_visualizers.py')},
649651
'fastrl.memory.experience_replay': { 'fastrl.memory.experience_replay.ExperienceReplay': ( '04_Memory/memory.experience_replay.html#experiencereplay',
650652
'fastrl/memory/experience_replay.py'),
651653
'fastrl.memory.experience_replay.ExperienceReplay.__init__': ( '04_Memory/memory.experience_replay.html#experiencereplay.__init__',

fastrl/agents/ddpg.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,32 @@
77
'CriticLossProcessor', 'ActorLossProcessor', 'DDPGLearner']
88

99
# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 2
10-
# # Python native modules
11-
# import os
10+
# Python native modules
1211
from typing import Tuple,Optional,Callable,Union,Dict,Literal,List
1312
from functools import partial
14-
# from typing_extensions import Literal
1513
from copy import deepcopy
16-
# # Third party libs
14+
# Third party libs
1715
from fastcore.all import add_docs
1816
import torchdata.datapipes as dp
1917
from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe
20-
# from torchdata.dataloader2.graph import DataPipe,traverse
2118
from torch import nn
22-
# from torch.optim import AdamW,Adam
2319
import torch
24-
# import pandas as pd
25-
# import numpy as np
26-
# # Local modules
20+
# Local modules
2721
from ..core import SimpleStep
2822
from ..pipes.core import find_dp
2923
from ..torch_core import Module
3024
from ..memory.experience_replay import ExperienceReplay
31-
from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST
3225
from ..learner.core import LearnerBase,LearnerHead,StepBatcher
33-
# from fastrl.pipes.core import *
34-
# from fastrl.data.block import *
35-
# from fastrl.data.dataloader2 import *
26+
from ..loggers.vscode_visualizers import VSCodeDataPipe
3627
from fastrl.loggers.core import (
37-
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
28+
ProgressBarLogger,
29+
Record,
30+
BatchCollector,
31+
EpochCollector,
32+
RollingTerminatedRewardCollector,
33+
EpisodeCollector,
34+
not_record,
35+
_RECORD_CATCH_LIST
3836
)
3937
from fastrl.agents.core import (
4038
AgentHead,
@@ -43,9 +41,6 @@
4341
SimpleModelRunner,
4442
NumpyConverter
4543
)
46-
# from fastrl.memory.experience_replay import ExperienceReplay
47-
# from fastrl.learner.core import *
48-
# from fastrl.loggers.core import *
4944

5045
# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 6
5146
def init_xavier_uniform_weights(m:Module,bias=0.01):
@@ -798,7 +793,7 @@ def DDPGLearner(
798793
critic:Critic,
799794
# A list of dls, where index=0 is the training dl.
800795
dls,
801-
logger_bases:Optional[Callable]=None,
796+
do_logging:bool=True,
802797
# The learning rate for the actor. Expected to learn slower than the critic
803798
actor_lr:float=1e-3,
804799
# The optimizer for the actor
@@ -830,11 +825,12 @@ def DDPGLearner(
830825
# Debug mode will output device moves
831826
debug:bool=False
832827
) -> LearnerHead:
833-
learner = LearnerBase(actor,dls[0])
828+
learner = LearnerBase({'actor':actor,'critic':critic},dls[0])
834829
learner = BatchCollector(learner,batches=batches)
835830
learner = EpochCollector(learner)
836-
if logger_bases:
837-
learner = logger_bases(learner)
831+
if do_logging:
832+
learner = learner.dump_records()
833+
learner = ProgressBarLogger(learner)
838834
learner = RollingTerminatedRewardCollector(learner)
839835
learner = EpisodeCollector(learner).catch_records()
840836
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
@@ -847,12 +843,15 @@ def DDPGLearner(
847843
learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5)
848844
learner = LossCollector(learner,title='actor-loss').catch_records()
849845
learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False)
850-
learner = LearnerHead(learner,(actor,critic))
851-
852-
# for dl in dls:
853-
# pipe_to_device(dl.datapipe,device,debug=debug)
854-
855-
return learner
846+
learner = LearnerHead(learner)
847+
848+
if len(dls)==2:
849+
val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()
850+
val_learner = BatchCollector(val_learner,batches=batches)
851+
val_learner = EpochCollector(val_learner).catch_records(drop=True)
852+
return LearnerHead((learner,val_learner))
853+
else:
854+
return LearnerHead(learner)
856855

857856
DDPGLearner.__doc__="""DDPG is a continuous action, actor-critic model, first created in
858857
(Lillicrap et al., 2016). The critic estimates a Q value estimate, and the actor

fastrl/agents/dqn/basic.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,19 @@
1010
from collections import deque
1111
from typing import Callable,Optional,List
1212
# Third party libs
13-
from fastcore.all import ifnone
1413
import torchdata.datapipes as dp
15-
from torchdata.dataloader2 import DataLoader2
1614
from torchdata.dataloader2.graph import traverse_dps,DataPipe
1715
import torch
18-
import torch.nn.functional as F
1916
from torch import optim
2017
from torch import nn
21-
import numpy as np
2218
# Local modules
2319
from ..core import AgentHead,AgentBase
2420
from ...pipes.core import find_dp
2521
from ...memory.experience_replay import ExperienceReplay
2622
from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter
2723
from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
2824
from fastrl.loggers.core import (
29-
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
25+
Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger
3026
)
3127
from ...learner.core import LearnerBase,LearnerHead,StepBatcher
3228
from ...torch_core import Module
@@ -156,7 +152,7 @@ def __iter__(self):
156152
def DQNLearner(
157153
model,
158154
dls,
159-
logger_bases:Optional[Callable]=None,
155+
do_logging:bool=True,
160156
loss_func=nn.MSELoss(),
161157
opt=optim.AdamW,
162158
lr=0.005,
@@ -169,25 +165,27 @@ def DQNLearner(
169165
learner = LearnerBase(model,dls[0])
170166
learner = BatchCollector(learner,batches=batches)
171167
learner = EpochCollector(learner)
172-
if logger_bases:
173-
learner = logger_bases(learner)
168+
if do_logging:
169+
learner = learner.dump_records()
170+
learner = ProgressBarLogger(learner)
174171
learner = RollingTerminatedRewardCollector(learner)
175172
learner = EpisodeCollector(learner)
176-
learner = learner.catch_records()
173+
learner = learner.catch_records(drop=not do_logging)
174+
177175
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True)
178176
learner = StepBatcher(learner,device=device)
179177
learner = QCalc(learner)
180178
learner = TargetCalc(learner,nsteps=nsteps)
181179
learner = LossCalc(learner,loss_func=loss_func)
182180
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
183-
if logger_bases:
181+
if do_logging:
184182
learner = LossCollector(learner).catch_records()
185183

186184
if len(dls)==2:
187-
val_learner = LearnerBase(model,dls[1])
185+
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
188186
val_learner = BatchCollector(val_learner,batches=batches)
189187
val_learner = EpochCollector(val_learner).dump_records()
190-
learner = LearnerHead((learner,val_learner),model)
188+
learner = LearnerHead((learner,val_learner))
191189
else:
192-
learner = LearnerHead(learner,model)
190+
learner = LearnerHead(learner)
193191
return learner

fastrl/agents/dqn/categorical.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def CategoricalDQNAgent(
261261
def DQNCategoricalLearner(
262262
model,
263263
dls,
264-
logger_bases:Optional[Callable]=None,
264+
do_logging:bool=True,
265265
loss_func=PartialCrossEntropy,
266266
opt=optim.AdamW,
267267
lr=0.005,
@@ -276,29 +276,28 @@ def DQNCategoricalLearner(
276276
learner = LearnerBase(model,dls=dls[0])
277277
learner = BatchCollector(learner,batches=batches)
278278
learner = EpochCollector(learner)
279-
if logger_bases:
280-
learner = logger_bases(learner)
279+
if do_logging:
280+
learner = learner.dump_records()
281+
learner = ProgressBarLogger(learner)
281282
learner = RollingTerminatedRewardCollector(learner)
282-
learner = EpisodeCollector(learner)
283-
learner = learner.catch_records()
283+
learner = EpisodeCollector(learner).catch_records()
284284
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
285285
learner = StepBatcher(learner,device=device)
286286
learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)
287287
# learner = TargetCalc(learner,nsteps=nsteps)
288288
learner = LossCalc(learner,loss_func=loss_func)
289289
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
290290
learner = TargetModelUpdater(learner,target_sync=target_sync)
291-
if logger_bases:
291+
if do_logging:
292292
learner = LossCollector(learner).catch_records()
293293

294294
if len(dls)==2:
295-
val_learner = LearnerBase(model,dls[1])
295+
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
296296
val_learner = BatchCollector(val_learner,batches=batches)
297297
val_learner = EpochCollector(val_learner).catch_records(drop=True)
298-
val_learner = VSCodeDataPipe(val_learner)
299-
return LearnerHead((learner,val_learner),model)
298+
return LearnerHead((learner,val_learner))
300299
else:
301-
return LearnerHead(learner,model)
300+
return LearnerHead(learner)
302301

303302
# %% ../../../nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb 49
304303
def show_q(cat_dist,title='Update Distributions'):

fastrl/agents/dqn/double.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ...loggers.core import BatchCollector,EpochCollector
1919
from ...learner.core import LearnerBase,LearnerHead
2020
from ...loggers.vscode_visualizers import VSCodeDataPipe
21+
from ...loggers.core import ProgressBarLogger
2122
from fastrl.agents.dqn.basic import (
2223
LossCollector,
2324
RollingTerminatedRewardCollector,
@@ -36,7 +37,7 @@
3637

3738
# %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5
3839
class DoubleQCalc(dp.iter.IterDataPipe):
39-
def __init__(self,source_datapipe=None):
40+
def __init__(self,source_datapipe):
4041
self.source_datapipe = source_datapipe
4142

4243
def __iter__(self):
@@ -53,7 +54,7 @@ def __iter__(self):
5354
def DoubleDQNLearner(
5455
model,
5556
dls,
56-
logger_bases:Optional[Callable]=None,
57+
do_logging:bool=True,
5758
loss_func=nn.MSELoss(),
5859
opt=optim.AdamW,
5960
lr=0.005,
@@ -67,27 +68,25 @@ def DoubleDQNLearner(
6768
learner = LearnerBase(model,dls=dls[0])
6869
learner = BatchCollector(learner,batches=batches)
6970
learner = EpochCollector(learner)
70-
if logger_bases:
71-
learner = logger_bases(learner)
71+
if do_logging:
72+
learner = learner.dump_records()
73+
learner = ProgressBarLogger(learner)
7274
learner = RollingTerminatedRewardCollector(learner)
73-
learner = EpisodeCollector(learner)
74-
learner = learner.catch_records()
75+
learner = EpisodeCollector(learner).catch_records()
7576
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
7677
learner = StepBatcher(learner,device=device)
77-
# learner = TargetModelQCalc(learner)
7878
learner = DoubleQCalc(learner)
7979
learner = TargetCalc(learner,nsteps=nsteps)
8080
learner = LossCalc(learner,loss_func=loss_func)
8181
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
8282
learner = TargetModelUpdater(learner,target_sync=target_sync)
83-
if logger_bases:
83+
if do_logging:
8484
learner = LossCollector(learner).catch_records()
8585

8686
if len(dls)==2:
87-
val_learner = LearnerBase(model,dls[1])
87+
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
8888
val_learner = BatchCollector(val_learner,batches=batches)
8989
val_learner = EpochCollector(val_learner).catch_records(drop=True)
90-
val_learner = VSCodeDataPipe(val_learner)
91-
return LearnerHead((learner,val_learner),model)
90+
return LearnerHead((learner,val_learner))
9291
else:
93-
return LearnerHead(learner,model)
92+
return LearnerHead(learner)

fastrl/agents/dqn/dueling.py

+10-29
Original file line numberDiff line numberDiff line change
@@ -5,48 +5,29 @@
55

66
# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2
77
# Python native modules
8-
from copy import deepcopy
9-
from typing import Optional,Callable,Tuple
108
# Third party libs
11-
import torchdata.datapipes as dp
12-
from torchdata.dataloader2.graph import traverse_dps,DataPipe
139
import torch
14-
from torch import nn,optim
15-
# Local modulesf
16-
from ...pipes.core import find_dp
17-
from ...memory.experience_replay import ExperienceReplay
18-
from ...loggers.core import BatchCollector,EpochCollector
19-
from ...learner.core import LearnerBase,LearnerHead
20-
from ...loggers.vscode_visualizers import VSCodeDataPipe
10+
from torch import nn
11+
# Local modules
2112
from fastrl.agents.dqn.basic import (
22-
LossCollector,
23-
RollingTerminatedRewardCollector,
24-
EpisodeCollector,
25-
StepBatcher,
26-
TargetCalc,
27-
LossCalc,
28-
ModelLearnCalc,
2913
DQN,
3014
DQNAgent
3115
)
32-
from fastrl.agents.dqn.target import (
33-
TargetModelUpdater,
34-
TargetModelQCalc,
35-
DQNTargetLearner
36-
)
16+
from .target import DQNTargetLearner
3717

3818
# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5
3919
class DuelingHead(nn.Module):
40-
def __init__(self,
41-
hidden:int, # Input into the DuelingHead, likely a hidden layer input
42-
n_actions:int, # Number/dim of actions to output
43-
lin_cls=nn.Linear
20+
def __init__(
21+
self,
22+
hidden: int, # Input into the DuelingHead, likely a hidden layer input
23+
n_actions: int, # Number/dim of actions to output
24+
lin_cls = nn.Linear
4425
):
4526
super().__init__()
4627
self.val = lin_cls(hidden,1)
4728
self.adv = lin_cls(hidden,n_actions)
4829

4930
def forward(self,xi):
50-
val,adv=self.val(xi),self.adv(xi)
51-
xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
31+
val,adv = self.val(xi),self.adv(xi)
32+
xi = val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
5233
return xi

0 commit comments

Comments
 (0)