 Can be used as a CLI script, or the `train_preference_comparisons` function
 can be called directly.
 """
-
 import functools
 import pathlib
 from typing import Any, Mapping, Optional, Type, Union
 
+import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
-from stable_baselines3.common import type_aliases
+from stable_baselines3.common import type_aliases, base_class, vec_env
 
 from imitation.algorithms import preference_comparisons
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.policies import serialize
+from imitation.rewards import reward_nets, reward_function
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
 from imitation.scripts.common import train
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+from imitation.util import logger as imit_logger
 
 
 def save_model(
@@ -57,6 +60,59 @@ def save_checkpoint(
         )
 
 
+@train_preference_comparisons_ex.capture
+def make_reward_function(
+    reward_net: reward_nets.RewardNet,
+    *,
+    pebble_enabled: bool = False,
+    pebble_nearest_neighbor_k: Optional[int] = None,
+):
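+    """Build the reward function used to relabel replay buffer rewards.
+
+    If PEBBLE is enabled, the learned reward is wrapped in a
+    ``PebbleStateEntropyReward``, with ``pebble_nearest_neighbor_k``
+    neighbors used for the state-entropy computation.
+    """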
+    relabel_reward_fn = functools.partial(
+        reward_net.predict_processed,
+        update_stats=False,
+    )
+    if pebble_enabled:
+        relabel_reward_fn = PebbleStateEntropyReward(
+            relabel_reward_fn, pebble_nearest_neighbor_k
+        )
+    return relabel_reward_fn
+
+
+@train_preference_comparisons_ex.capture
+def make_agent_trajectory_generator(
+    venv: vec_env.VecEnv,
+    agent: base_class.BaseAlgorithm,
+    reward_net: reward_nets.RewardNet,
+    relabel_reward_fn: reward_function.RewardFn,
+    rng: np.random.Generator,
+    custom_logger: Optional[imit_logger.HierarchicalLogger],
+    *,
+    exploration_frac: float,
+    pebble_enabled: bool,
+    trajectory_generator_kwargs: Mapping[str, Any],
+) -> preference_comparisons.AgentTrainer:
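+    """Construct the trajectory generator for preference comparisons.
+
+    When PEBBLE is enabled, returns a ``PebbleAgentTrainer`` driven by
+    ``relabel_reward_fn``; otherwise returns a plain ``AgentTrainer``
+    trained directly on ``reward_net``.
+    """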
+    if pebble_enabled:
+        return preference_comparisons.PebbleAgentTrainer(
+            algorithm=agent,
+            reward_fn=relabel_reward_fn,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+    else:
+        return preference_comparisons.AgentTrainer(
+            algorithm=agent,
+            reward_fn=reward_net,
+            venv=venv,
+            exploration_frac=exploration_frac,
+            rng=rng,
+            custom_logger=custom_logger,
+            **trajectory_generator_kwargs,
+        )
+
+
 @train_preference_comparisons_ex.main
 def train_preference_comparisons(
     total_timesteps: int,
@@ -83,7 +139,6 @@ def train_preference_comparisons(
     checkpoint_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
     unsupervised_agent_pretrain_frac: Optional[float],
-    pebble_nearest_neighbor_k: Optional[int],
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.
 
@@ -146,8 +201,6 @@ def train_preference_comparisons(
         unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
             agent will be trained without preference gathering (and reward model
             training)
-        pebble_nearest_neighbor_k: Parameter for state entropy computation (for PEBBLE
-            training only)
 
     Returns:
         Rollout statistics from trained policy.
@@ -160,10 +213,8 @@ def train_preference_comparisons(
 
     with common.make_venv() as venv:
         reward_net = reward.make_reward_net(venv)
-        relabel_reward_fn = functools.partial(
-            reward_net.predict_processed,
-            update_stats=False,
-        )
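+        # pebble_enabled and pebble_nearest_neighbor_k are injected by Sacred;
+        # with PEBBLE enabled the learned reward is wrapped in a state-entropy
+        # reward, otherwise this is the same relabeling function as before.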
+        relabel_reward_fn = make_reward_function(reward_net)
+
         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:
@@ -176,21 +227,17 @@ def train_preference_comparisons(
         if trajectory_path is None:
             # Setting the logger here is not necessary (PreferenceComparisons takes care
             # of it automatically) but it avoids creating unnecessary loggers.
-            agent_trainer = preference_comparisons.AgentTrainer(
-                algorithm=agent,
-                reward_fn=reward_net,
+            trajectory_generator = make_agent_trajectory_generator(
                 venv=venv,
-                exploration_frac=exploration_frac,
+                agent=agent,
+                reward_net=reward_net,
+                relabel_reward_fn=relabel_reward_fn,
                 rng=rng,
                 custom_logger=custom_logger,
-                **trajectory_generator_kwargs,
             )
             # Stable Baselines will automatically occupy GPU 0 if it is available.
             # Let's use the same device as the SB3 agent for the reward model.
-            reward_net = reward_net.to(agent_trainer.algorithm.device)
-            trajectory_generator: preference_comparisons.TrajectoryGenerator = (
-                agent_trainer
-            )
+            reward_net = reward_net.to(trajectory_generator.algorithm.device)
         else:
             if exploration_frac > 0:
                 raise ValueError(