'CriticLossProcessor', 'ActorLossProcessor', 'DDPGLearner']

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 2
- # # Python native modules
- # import os
+ # Python native modules
from typing import Tuple,Optional,Callable,Union,Dict,Literal,List
from functools import partial
- # from typing_extensions import Literal
from copy import deepcopy
- # # Third party libs
+ # Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe
- # from torchdata.dataloader2.graph import DataPipe,traverse
from torch import nn
- # from torch.optim import AdamW,Adam
import torch
- # import pandas as pd
- # import numpy as np
- # # Local modules
+ # Local modules
from ..core import SimpleStep
from ..pipes.core import find_dp
from ..torch_core import Module
from ..memory.experience_replay import ExperienceReplay
- from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST
from ..learner.core import LearnerBase,LearnerHead,StepBatcher
- # from fastrl.pipes.core import *
- # from fastrl.data.block import *
- # from fastrl.data.dataloader2 import *
+ from ..loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.loggers.core import (
-     LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
+     ProgressBarLogger,
+     Record,
+     BatchCollector,
+     EpochCollector,
+     RollingTerminatedRewardCollector,
+     EpisodeCollector,
+     not_record,
+     _RECORD_CATCH_LIST
)
from fastrl.agents.core import (
    AgentHead,
    SimpleModelRunner,
    NumpyConverter
)
- # from fastrl.memory.experience_replay import ExperienceReplay
- # from fastrl.learner.core import *
- # from fastrl.loggers.core import *

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 6
def init_xavier_uniform_weights(m:Module,bias=0.01):
@@ -798,7 +793,7 @@ def DDPGLearner(
    critic:Critic,
    # A list of dls, where index=0 is the training dl.
    dls,
-     logger_bases:Optional[Callable]=None,
+     do_logging:bool=True,
    # The learning rate for the actor. Expected to learn slower than the critic
    actor_lr:float=1e-3,
    # The optimizer for the actor
@@ -830,11 +825,12 @@ def DDPGLearner(
    # Debug mode will output device moves
    debug:bool=False
) -> LearnerHead:
-     learner = LearnerBase(actor,dls[0])
+     learner = LearnerBase({'actor':actor,'critic':critic},dls[0])
    learner = BatchCollector(learner,batches=batches)
    learner = EpochCollector(learner)
-     if logger_bases:
-         learner = logger_bases(learner)
+     if do_logging:
+         learner = learner.dump_records()
+         learner = ProgressBarLogger(learner)
    learner = RollingTerminatedRewardCollector(learner)
    learner = EpisodeCollector(learner).catch_records()
    learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
@@ -847,12 +843,15 @@ def DDPGLearner(
    learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5)
    learner = LossCollector(learner,title='actor-loss').catch_records()
    learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False)
-     learner = LearnerHead(learner,(actor,critic))
-
-     # for dl in dls:
-     #     pipe_to_device(dl.datapipe,device,debug=debug)
-
-     return learner
+     learner = LearnerHead(learner)
+
+     if len(dls)==2:
+         val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()
+         val_learner = BatchCollector(val_learner,batches=batches)
+         val_learner = EpochCollector(val_learner).catch_records(drop=True)
+         return LearnerHead((learner,val_learner))
+     else:
+         return LearnerHead(learner)

DDPGLearner.__doc__ = """DDPG is a continuous action, actor-critic model, first created in
(Lillicrap et al., 2016). The critic estimates a Q value estimate, and the actor
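A rough usage sketch of the `DDPGLearner` entry point after this change. Only what the diff shows is taken as given: `do_logging` replaces the old `logger_bases` callable, and a second dataloader in `dls` triggers the new `len(dls)==2` validation branch. The `Actor`/`Critic` constructor arguments, the dataloader construction, and the `fit` call are assumptions for illustration, not part of this commit.

# Hedged sketch: everything except the DDPGLearner keyword arguments is assumed.
from fastrl.agents.ddpg import Actor, Critic, DDPGLearner

actor = Actor(3, 1)    # assumed: 3-dim observations, 1-dim actions
critic = Critic(3, 1)

# dls is assumed to be a list of dataloaders built with fastrl's gym datapipes
# (not shown in this diff): dls[0] trains; an optional dls[1] is validated
# through the new VSCode-visualized branch.
learner = DDPGLearner(
    actor,
    critic,
    dls,
    do_logging=True,   # replaces the old logger_bases callable
    actor_lr=1e-3,
)
learner.fit(1)         # assumed LearnerHead training entry point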