-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdisplay.py
89 lines (74 loc) · 3.03 KB
/
display.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Nov 21, 2020
@author: Thomas Bonald <[email protected]>
"""
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation as anim
# Ask matplotlib to represent animations in HTML output as "jshtml"
# (affects how the FuncAnimation objects returned below are displayed).
plt.rcParams["animation.html"] = "jshtml"
import seaborn as sns
# Shared seaborn 'colorblind' palette; plot_regret indexes it to pick one
# color per agent.
colors = sns.color_palette('colorblind')
def display_position(image, position=None, positions=None, marker='o', marker_size=200, marker_color='b', interval=200):
    """
    Show an image with a marker drawn on it, either static or animated.

    Parameters
    ----------
    image :
        Background passed to ``Axes.imshow``.
    position :
        A single ``(y, x)`` pair; used only when ``positions`` is None.
    positions :
        Indexable sequence of ``(y, x)`` pairs, one per animation frame.
        Takes precedence over ``position``.
    marker, marker_size, marker_color :
        Forwarded to ``Axes.scatter`` as ``marker``, ``s`` and ``c``.
    interval :
        Forwarded to ``FuncAnimation`` as the ``interval`` between frames.

    Returns
    -------
    The ``FuncAnimation`` when ``positions`` is given, otherwise None.
    """
    figure, axes = plt.subplots()
    axes.axis('off')
    axes.imshow(image)
    scatter_style = dict(marker=marker, s=marker_size, c=marker_color)

    # Static cases first: nothing to animate.
    if positions is None:
        if position is not None:
            static_y, static_x = position
            axes.scatter(static_x, static_y, **scatter_style)
        return None

    # Draw the marker at the first position, then move it frame by frame.
    first_y, first_x = positions[0]
    dot = axes.scatter(first_x, first_y, **scatter_style)

    def move_to_frame(frame):
        frame_y, frame_x = positions[frame]
        # scatter offsets are (x, y) rows, hence the swap and transpose.
        dot.set_offsets(np.vstack((frame_x, frame_y)).T)

    return anim.FuncAnimation(figure, move_to_frame, frames=len(positions), interval=interval, repeat=False)
def _player_cells(board, positive):
    """Return ``(x, y)`` index arrays of the cells whose sign matches one player."""
    grid = np.asarray(board)  # also accept nested lists, not only ndarrays
    y, x = np.where(grid > 0 if positive else grid < 0)
    return x, y


def display_board(image, board=None, boards=None, marker1='x', marker2='o', marker_size=200,
                  color1='b', color2='r', interval=200):
    """
    Show a board image with both players' pieces, either static or animated.

    Cells with a positive value are drawn with ``marker1``/``color1``, cells
    with a negative value with ``marker2``/``color2``; zero cells are left
    empty.

    Parameters
    ----------
    image :
        Background passed to ``Axes.imshow``.
    board :
        A single 2-D board; used only when ``boards`` is None.
    boards :
        Indexable sequence of 2-D boards, one per animation frame.
        Takes precedence over ``board``.
    marker1, marker2, color1, color2, marker_size :
        Marker shape, color and size forwarded to ``Axes.scatter``.
    interval :
        Forwarded to ``FuncAnimation`` as the ``interval`` between frames.

    Returns
    -------
    The ``FuncAnimation`` when ``boards`` is given, otherwise None.
    """
    fig, ax = plt.subplots()
    ax.axis('off')
    ax.imshow(image)
    # (sign selector, marker, color) for player 1 then player 2.
    players = ((True, marker1, color1), (False, marker2, color2))

    if boards is not None:
        # One scatter artist per player, initialised from the first board.
        artists = []
        for positive, marker, color in players:
            x, y = _player_cells(boards[0], positive)
            artists.append((positive, ax.scatter(x, y, marker=marker, s=marker_size, c=color)))

        def update(i):
            for positive, artist in artists:
                x_, y_ = _player_cells(boards[i], positive)
                # scatter offsets are (x, y) rows, hence the transpose.
                artist.set_offsets(np.vstack((x_, y_)).T)

        return anim.FuncAnimation(fig, update, frames=len(boards), interval=interval, repeat=False)

    if board is not None:
        for positive, marker, color in players:
            x, y = _player_cells(board, positive)
            ax.scatter(x, y, marker=marker, s=marker_size, c=color)
    return None
def plot_regret(regrets, logscale=False, lb=None):
    """
    Plot the cumulative regret of several agents on a single figure.

    For every agent the per-step regrets are accumulated along axis 1; the
    mean over axis 0 is drawn as a line and the 10th-90th percentile range
    as a shaded band of the same color.

    Parameters
    ----------
    regrets : dict
        ``{'agent_id': regret_table}`` where each table is 2-D
        (presumably runs x time steps -- confirm with the caller).
        The agent id is used as the legend label.
    logscale : bool
        If True, the x-axis is switched to a logarithmic scale.
    lb : array-like, optional
        Extra curve drawn in black with star markers (presumably a regret
        lower bound). It is plotted against the time steps of the last
        agent, or against ``len(lb)`` steps when ``regrets`` is empty.
    """
    reg_plot = plt.figure()
    horizon = None  # number of time steps of the last agent plotted
    for i, (agent_id, data) in enumerate(regrets.items()):
        data = np.asarray(data)
        _, horizon = data.shape
        steps = np.arange(horizon)
        cumdata = np.cumsum(data, axis=1)  # cumulative regret

        mean_reg = np.mean(cumdata, axis=0)
        q_reg = np.percentile(cumdata, 10, axis=0)
        Q_reg = np.percentile(cumdata, 90, axis=0)

        # Cycle through the palette so that more agents than palette
        # entries does not raise IndexError.
        color = colors[i % len(colors)]
        plt.plot(steps, mean_reg, color=color, label=agent_id)
        plt.fill_between(steps, q_reg, Q_reg, color=color, alpha=0.2)

    if logscale:
        plt.xscale('log')

    if lb is not None:
        if horizon is None:
            # No agent was plotted: take the time axis from lb itself.
            horizon = len(lb)
        plt.plot(np.arange(horizon), lb, color='black', marker='*', markevery=10)

    plt.xlabel('time steps')
    plt.ylabel('Cumulative Regret')
    plt.legend()
    reg_plot.show()