Skip to content

Commit 6e73e79

Browse files
committed
Support CPU training for DMC and format examples
1 parent e1a3eac commit 6e73e79

20 files changed

+1071
-381
lines changed

Diff for: docs/toy-examples.md

+350-113
Large diffs are not rendered by default.

Diff for: examples/evaluate.py

+47-8
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,15 @@
44
import argparse
55

66
import rlcard
7-
from rlcard.agents import DQNAgent, RandomAgent
8-
from rlcard.utils import get_device, set_seed, tournament
7+
from rlcard.agents import (
8+
DQNAgent,
9+
RandomAgent,
10+
)
11+
from rlcard.utils import (
12+
get_device,
13+
set_seed,
14+
tournament,
15+
)
916

1017
def load_model(model_path, env=None, position=None, device=None):
1118
if os.path.isfile(model_path): # Torch model
@@ -49,12 +56,44 @@ def evaluate(args):
4956

5057
if __name__ == '__main__':
5158
parser = argparse.ArgumentParser("Evaluation example in RLCard")
52-
parser.add_argument('--env', type=str, default='leduc-holdem',
53-
choices=['blackjack', 'leduc-holdem', 'limit-holdem', 'doudizhu', 'mahjong', 'no-limit-holdem', 'uno', 'gin-rummy'])
54-
parser.add_argument('--models', nargs='*', default=['experiments/leduc_holdem_dqn_result/model.pth', 'random'])
55-
parser.add_argument('--cuda', type=str, default='')
56-
parser.add_argument('--seed', type=int, default=42)
57-
parser.add_argument('--num_games', type=int, default=10000)
59+
parser.add_argument(
60+
'--env',
61+
type=str,
62+
default='leduc-holdem',
63+
choices=[
64+
'blackjack',
65+
'leduc-holdem',
66+
'limit-holdem',
67+
'doudizhu',
68+
'mahjong',
69+
'no-limit-holdem',
70+
'uno',
71+
'gin-rummy',
72+
],
73+
)
74+
parser.add_argument(
75+
'--models',
76+
nargs='*',
77+
default=[
78+
'experiments/leduc_holdem_dqn_result/model.pth',
79+
'random',
80+
],
81+
)
82+
parser.add_argument(
83+
'--cuda',
84+
type=str,
85+
default='',
86+
)
87+
parser.add_argument(
88+
'--seed',
89+
type=int,
90+
default=42,
91+
)
92+
parser.add_argument(
93+
'--num_games',
94+
type=int,
95+
default=10000,
96+
)
5897

5998
args = parser.parse_args()
6099

Diff for: examples/human/blackjack_human.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@
88

99
# Make environment
1010
num_players = 2
11-
env = rlcard.make('blackjack', config={'game_num_players': num_players})
11+
env = rlcard.make(
12+
'blackjack',
13+
config={
14+
'game_num_players': num_players,
15+
},
16+
)
1217
human_agent = HumanAgent(env.num_actions)
1318
random_agent = RandomAgent(env.num_actions)
14-
env.set_agents([human_agent, random_agent])
19+
env.set_agents([
20+
human_agent,
21+
random_agent,
22+
])
1523

1624
print(">> Blackjack human agent")
1725

Diff for: examples/human/gin_rummy_human.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def make_gin_rummy_env() -> 'GinRummyEnv':
4242
# north_agent = RandomAgent(num_actions=gin_rummy_env.num_actions)
4343
north_agent = GinRummyNoviceRuleAgent()
4444
south_agent = HumanAgent(gin_rummy_env.num_actions)
45-
gin_rummy_env.set_agents([north_agent, south_agent])
45+
gin_rummy_env.set_agents([
46+
north_agent,
47+
south_agent
48+
])
4649
gin_rummy_env.game.judge.scorer = scorers.GinRummyScorer(get_payoff=scorers.get_payoff_gin_rummy_v0)
4750
return gin_rummy_env
4851

Diff for: examples/human/leduc_holdem_human.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
env = rlcard.make('leduc-holdem')
1111
human_agent = HumanAgent(env.num_actions)
1212
cfr_agent = models.load('leduc-holdem-cfr').agents[0]
13-
env.set_agents([human_agent, cfr_agent])
13+
env.set_agents([
14+
human_agent,
15+
cfr_agent,
16+
])
1417

1518
print(">> Leduc Hold'em pre-trained model")
1619

Diff for: examples/human/limit_holdem_human.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
env = rlcard.make('limit-holdem')
1111
human_agent = HumanAgent(env.num_actions)
1212
agent_0 = RandomAgent(num_actions=env.num_actions)
13-
env.set_agents([human_agent, agent_0])
13+
env.set_agents([
14+
human_agent,
15+
agent_0,
16+
])
1417

1518
print(">> Limit Hold'em random agent")
1619

Diff for: examples/human/uno_human.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
env = rlcard.make('uno')
1010
human_agent = HumanAgent(env.num_actions)
1111
cfr_agent = models.load('uno-rule-v1').agents[0]
12-
env.set_agents([human_agent, cfr_agent])
12+
env.set_agents([
13+
human_agent,
14+
cfr_agent,
15+
])
1316

1417
print(">> UNO rule model V1")
1518

Diff for: examples/pettingzoo/run_dmc.py

+77-30
Original file line numberDiff line numberDiff line change
@@ -35,42 +35,89 @@ def train(args):
3535
env.reset()
3636

3737
# Initialize the DMC trainer
38-
trainer = DMCTrainer(env,
39-
is_pettingzoo_env=True,
40-
load_model=args.load_model,
41-
xpid=args.xpid,
42-
savedir=args.savedir,
43-
save_interval=args.save_interval,
44-
num_actor_devices=args.num_actor_devices,
45-
num_actors=args.num_actors,
46-
training_device=args.training_device,
47-
total_frames=args.total_frames,
48-
)
38+
trainer = DMCTrainer(
39+
env,
40+
is_pettingzoo_env=True,
41+
load_model=args.load_model,
42+
xpid=args.xpid,
43+
savedir=args.savedir,
44+
save_interval=args.save_interval,
45+
num_actor_devices=args.num_actor_devices,
46+
num_actors=args.num_actors,
47+
training_device=args.training_device,
48+
total_frames=args.total_frames,
49+
)
4950

5051
# Train DMC Agents
5152
trainer.start()
5253

5354
if __name__ == '__main__':
5455
parser = argparse.ArgumentParser("DMC example in RLCard")
55-
parser.add_argument('--env', type=str, default='leduc-holdem',
56-
choices=['blackjack', 'leduc-holdem', 'limit-holdem', 'doudizhu', 'mahjong', 'no-limit-holdem', 'uno', 'gin-rummy'])
57-
parser.add_argument('--cuda', type=str, default='1')
58-
parser.add_argument('--load_model', action='store_true',
59-
help='Load an existing model')
60-
parser.add_argument('--xpid', default='doudizhu',
61-
help='Experiment id (default: doudizhu)')
62-
parser.add_argument('--savedir', default='experiments/dmc_result',
63-
help='Root dir where experiment data will be saved')
64-
parser.add_argument('--save_interval', default=30, type=int,
65-
help='Time interval (in minutes) at which to save the model')
66-
parser.add_argument('--num_actor_devices', default=1, type=int,
67-
help='The number of devices used for simulation')
68-
parser.add_argument('--num_actors', default=5, type=int,
69-
help='The number of actors for each simulation device')
70-
parser.add_argument('--total_frames', default=1e11, type=int,
71-
help='The total number of frames to train for')
72-
parser.add_argument('--training_device', default=0, type=int,
73-
help='The index of the GPU used for training models')
56+
parser.add_argument(
57+
'--env',
58+
type=str,
59+
default='leduc-holdem',
60+
choices=[
61+
'blackjack',
62+
'leduc-holdem',
63+
'limit-holdem',
64+
'doudizhu',
65+
'mahjong',
66+
'no-limit-holdem',
67+
'uno',
68+
'gin-rummy',
69+
]
70+
)
71+
parser.add_argument(
72+
'--cuda',
73+
type=str,
74+
default='',
75+
)
76+
parser.add_argument(
77+
'--load_model',
78+
action='store_true',
79+
help='Load an existing model',
80+
)
81+
parser.add_argument(
82+
'--xpid',
83+
default='leduc_holdem',
84+
help='Experiment id (default: leduc_holdem)',
85+
)
86+
parser.add_argument(
87+
'--savedir',
88+
default='experiments/dmc_result',
89+
help='Root dir where experiment data will be saved',
90+
)
91+
parser.add_argument(
92+
'--save_interval',
93+
default=30,
94+
type=int,
95+
help='Time interval (in minutes) at which to save the model',
96+
)
97+
parser.add_argument(
98+
'--num_actor_devices',
99+
default=1,
100+
type=int,
101+
help='The number of devices used for simulation',
102+
)
103+
parser.add_argument(
104+
'--num_actors',
105+
default=5,
106+
type=int,
107+
help='The number of actors for each simulation device',
108+
)
109+
parser.add_argument(
110+
'--total_frames',
111+
default=1e11,
112+
type=int,
113+
help='The total number of frames to train for',
114+
)
115+
parser.add_argument(
116+
'--training_device',
117+
default=0,
118+
type=int,
119+
help='The index of the GPU used for training models',
120+
)
74121

75122
args = parser.parse_args()
76123

Diff for: examples/pettingzoo/run_rl.py

+60-11
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
)
1818
from rlcard.agents.pettingzoo_agents import RandomAgentPettingZoo
1919
from rlcard.utils import (
20-
get_device, set_seed, Logger, plot_curve,
21-
run_game_pettingzoo, reorganize_pettingzoo, tournament_pettingzoo
20+
get_device,
21+
set_seed,
22+
Logger,
23+
plot_curve,
24+
run_game_pettingzoo,
25+
reorganize_pettingzoo,
26+
tournament_pettingzoo,
2227
)
2328

2429
env_name_to_env_func = {
@@ -104,15 +109,59 @@ def train(args):
104109

105110
if __name__ == '__main__':
106111
parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
107-
parser.add_argument('--env', type=str, default='leduc-holdem',
108-
choices=['leduc-holdem', 'limit-holdem', 'doudizhu', 'mahjong', 'no-limit-holdem', 'uno', 'gin-rummy'])
109-
parser.add_argument('--algorithm', type=str, default='dqn', choices=['dqn', 'nfsp'])
110-
parser.add_argument('--cuda', type=str, default='')
111-
parser.add_argument('--seed', type=int, default=42)
112-
parser.add_argument('--num_episodes', type=int, default=5000)
113-
parser.add_argument('--num_eval_games', type=int, default=2000)
114-
parser.add_argument('--evaluate_every', type=int, default=100)
115-
parser.add_argument('--log_dir', type=str, default='experiments/leduc_holdem_dqn_result/')
112+
parser.add_argument(
113+
'--env',
114+
type=str,
115+
default='leduc-holdem',
116+
choices=[
117+
'leduc-holdem',
118+
'limit-holdem',
119+
'doudizhu',
120+
'mahjong',
121+
'no-limit-holdem',
122+
'uno',
123+
'gin-rummy',
124+
],
125+
)
126+
parser.add_argument(
127+
'--algorithm',
128+
type=str,
129+
default='dqn',
130+
choices=[
131+
'dqn',
132+
'nfsp',
133+
],
134+
)
135+
parser.add_argument(
136+
'--cuda',
137+
type=str,
138+
default='',
139+
)
140+
parser.add_argument(
141+
'--seed',
142+
type=int,
143+
default=42,
144+
)
145+
parser.add_argument(
146+
'--num_episodes',
147+
type=int,
148+
default=5000,
149+
)
150+
parser.add_argument(
151+
'--num_eval_games',
152+
type=int,
153+
default=2000,
154+
)
155+
parser.add_argument(
156+
'--evaluate_every',
157+
type=int,
158+
default=100,
159+
)
160+
parser.add_argument(
161+
'--log_dir',
162+
type=str,
163+
default='experiments/leduc_holdem_dqn_result/',
164+
)
116165

117166
args = parser.parse_args()
118167

0 commit comments

Comments
 (0)