File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
src/imitation/algorithms/pebble Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -35,13 +35,14 @@ def __call__(
3535
3636 all_observations = self .replay_buffer_view .observations
3737 # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
38- all_observations = all_observations .reshape ((- 1 , * self .obs_shape ))
38+ all_observations = all_observations .reshape ((- 1 , * state .shape [1 :])) # TODO #625: fix self.obs_shape
39+ # TODO #625: deal with the conversion back and forth between np and torch
3940 entropies = util .compute_state_entropy (
40- state ,
41- all_observations ,
41+ th . tensor ( state ) ,
42+ th . tensor ( all_observations ) ,
4243 self .nearest_neighbor_k ,
4344 )
44- normalized_entropies = self .entropy_stats .forward (th . as_tensor ( entropies ) )
45+ normalized_entropies = self .entropy_stats .forward (entropies )
4546 return normalized_entropies .numpy ()
4647
4748 def __getstate__ (self ):
You can’t perform that action at this time.
0 commit comments