
Commit 1dd0f46

Use gymnasium and reflect new API (#1152)

1 parent 508743b
2 files changed: +6 -6 lines

distributed/rpc/rl/main.py (+5 -5)

@@ -1,5 +1,5 @@
 import argparse
-import gym
+import gymnasium as gym
 import numpy as np
 import os
 from itertools import count
@@ -85,7 +85,7 @@ class Observer:
     def __init__(self):
         self.id = rpc.get_worker_info().id
         self.env = gym.make('CartPole-v1')
-        self.env.seed(args.seed)
+        self.env.reset(seed=args.seed)
 
     def run_episode(self, agent_rref, n_steps):
         r"""
@@ -95,18 +95,18 @@ def run_episode(self, agent_rref, n_steps):
             agent_rref (RRef): an RRef referencing the agent object.
             n_steps (int): number of steps in this episode
         """
-        state, ep_reward = self.env.reset(), 0
+        state, ep_reward = self.env.reset()[0], 0
         for step in range(n_steps):
             # send the state to the agent to get an action
             action = _remote_method(Agent.select_action, agent_rref, self.id, state)
 
             # apply the action to the environment, and get the reward
-            state, reward, done, _ = self.env.step(action)
+            state, reward, terminated, truncated, _ = self.env.step(action)
 
             # report the reward to the agent for training purpose
             _remote_method(Agent.report_reward, agent_rref, self.id, reward)
 
-            if done:
+            if terminated or truncated:
                 break
 
 class Agent:

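Taken together, the three changes to main.py track the gymnasium API: reset() now accepts the seed and returns an (observation, info) pair (hence the [0] index above), and step() now returns five values, splitting the old done flag into terminated and truncated. Below is a minimal, self-contained sketch of the same episode loop with the RPC Observer/Agent machinery stripped out; the random policy and the literal seed 0 are placeholders for Agent.select_action and args.seed.

import gymnasium as gym

env = gym.make('CartPole-v1')

# seeding moved from env.seed(seed) to env.reset(seed=seed);
# reset() now returns an (observation, info) pair
state, info = env.reset(seed=0)

ep_reward = 0
for step in range(500):
    # placeholder for the learned policy (Agent.select_action in main.py)
    action = env.action_space.sample()

    # step() returns five values: `terminated` (a terminal state was
    # reached) and `truncated` (e.g. a time limit was hit) replace the
    # old single `done` flag
    state, reward, terminated, truncated, info = env.step(action)
    ep_reward += reward

    if terminated or truncated:
        break

env.close()
print('episode reward:', ep_reward)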
distributed/rpc/rl/requirements.txt (+1 -1)

@@ -1,3 +1,3 @@
 torch
 numpy
-gym
+gymnasium
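For context: gymnasium is the maintained fork of OpenAI's gym, developed by the Farama Foundation. Importing it under the alias gym (import gymnasium as gym) keeps the rest of the example unchanged apart from the API differences patched above.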
