
[RLLIB] Offline Training #51747

Open
dkupsh opened this issue Mar 27, 2025 · 1 comment
Labels
bug (Something that is supposed to be working; but isn't), P3 (Issue moderate in impact or severity), rllib (RLlib related issues), rllib-offline-rl (Offline RL problems)

Comments

dkupsh commented Mar 27, 2025

What happened + What you expected to happen

I have a few issues with offline learning, using both MARWIL and PPO.

  1. First, there's no documentation on saving external experiences using SingleAgentEpisode (they can't be saved directly either). I think it would be worthwhile to have an example like this in the documentation:
from ray.rllib.env.single_agent_episode import SingleAgentEpisode


def save_experiences(episodes: list[SingleAgentEpisode], filename: str):
    import msgpack_numpy as mnp
    from ray import data

    # Serialize each episode's state dict with msgpack before writing to Parquet.
    packed = [mnp.packb(ep.get_state(), default=mnp.encode) for ep in episodes]
    episodes_ds = data.from_items(packed)
    episodes_ds.write_parquet(
        filename,
        compression="gzip",
    )
    del episodes_ds
    episodes.clear()

I'm not sure if it's a bug, but you need code like this in order to get the episodes to load properly (e.g., you can't load a SingleAgentEpisode directly from the OfflinePreLearner class). Right now, you can only load compressed files, and there's no documentation or example of this; a loading sketch follows below.
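A minimal loading sketch that mirrors save_experiences above. It assumes that rows written with data.from_items come back under an "item" column and that SingleAgentEpisode.from_state() reconstructs an episode from its get_state() dict; treat both as assumptions about the current master API rather than documented behavior:

def load_experiences(filename: str) -> list[SingleAgentEpisode]:
    import msgpack_numpy as mnp
    from ray import data
    from ray.rllib.env.single_agent_episode import SingleAgentEpisode

    # Read the Parquet file back into a Ray Dataset and unpack each row
    # into a SingleAgentEpisode.
    episodes_ds = data.read_parquet(filename)
    return [
        SingleAgentEpisode.from_state(
            mnp.unpackb(row["item"], object_hook=mnp.decode)
        )
        for row in episodes_ds.iter_rows()
    ]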

  2. The OfflinePreLearner's learner connector calls GAE on the CPU during the initialization of the MARWIL algorithm. The call stack goes from offline_prelearner:249 to general_advantage_estimation:94. Here is the relevant offline config I'm using (alongside a custom model/env):
.learners(
    num_learners=1,
    num_gpus_per_learner=1,
).training(
    train_batch_size_per_learner=32,
).offline_data(
    input_=experiences,
    input_read_episodes=True,
    input_read_batch_size=32,
    map_batches_kwargs={"concurrency": 1},
    iter_batches_kwargs={"prefetch_batches": 1},
    dataset_num_iters_per_learner=1,
)
  3. When using the BC algorithm, the first "running" call of the model uses NumPy arrays (gathered from the offline experiences). It seems like the connectors didn't run on the batch (?). This uses the same config as above.

Versions / Dependencies

Ray, master branch

Reproduction script

import os

import gymnasium as gym
import ray
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.train import RunConfig
from ray.tune import CLIReporter, Tuner

env = gym.make("CartPole-v1")

base_config: MARWILConfig = (
    MARWILConfig()
    .environment(
        "CartPole-v1",
        action_space=env.action_space,
        observation_space=env.observation_space,
    ).learners(
        num_learners=1,
        num_gpus_per_learner=1
    ).training(
        train_batch_size_per_learner=32,
    ).offline_data(
        input_="tests/data/cartpole/cartpole-v1_large",
        map_batches_kwargs={"concurrency": 1},
        iter_batches_kwargs={"prefetch_batches": 1},
        dataset_num_iters_per_learner=1,
    )
)

callbacks = []

ray.init()
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
os.environ["RAY_verbose_spill_logs"] = "0"

Tuner(
    base_config.algo_class,
    param_space=base_config,
    run_config=RunConfig(
        stop={"training_iteration": 1000},
        callbacks=callbacks,
        progress_reporter=CLIReporter(),
    ),
).fit()

ray.shutdown()

Issue Severity

High: It blocks me from completing my task.

@dkupsh added the bug and triage (Needs triage: priority, bug/not-bug, owning component) labels on Mar 27, 2025.
@masoudcharkhabi added the rllib label on Mar 27, 2025.
dkupsh commented Mar 28, 2025

Okay, found the bug for #3 here.

In Learner (ray/rllib/core/learner/learner.py), the make_batch_if_necessary function assumes that the observation space is either a Box or a Discrete, and it doesn't work with more complex spaces (like Dict or Tuple):

# If we already have a `MultiAgentBatch` but with `numpy` arrays, convert to tensors.
elif (
    isinstance(training_data.batch, MultiAgentBatch)
    and training_data.batch.policy_batches
    and is_numpy(next(iter(training_data.batch.policy_batches.values()))["obs"])
):

Changing this last condition in the if statement to use a function like this:

import numpy


def is_numpy(x):
    # Recurse into nested observation structures (dicts/tuples) and check the first leaf.
    if isinstance(x, dict):
        return is_numpy(list(x.values())[0])
    elif isinstance(x, tuple):
        return is_numpy(x[0])
    return isinstance(x, numpy.ndarray)

should fix this problem and allow complex input spaces to be converted correctly.
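For illustration, here is a quick check of the helper on a made-up nested observation batch (keys and shapes are hypothetical, not from my environment):

import numpy

# Hypothetical batch from a Dict observation space: "obs" is a dict of arrays,
# not a single numpy array.
obs = {"camera": numpy.zeros((32, 84, 84, 3)), "state": numpy.zeros((32, 4))}

print(isinstance(obs, numpy.ndarray))  # False -> the numpy-to-tensor branch above is skipped.
print(is_numpy(obs))                   # True  -> the recursive helper finds the numpy leaf.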

@simonsays1980 added the P3 and rllib-offline-rl labels and removed the triage label on Apr 3, 2025.