Skip to content

Commit 7093946

Browse files
committed
Update VF Picture
1 parent 0631425 commit 7093946

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

pictures/GridWorld4x5Value.png

2.92 KB
Loading

rlpy/domains/grid_world.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,13 @@ def _init_value_vis(self):
203203
self._init_arrow(name, *np.meshgrid(x, y))
204204
self.vf_fig.show()
205205

206+
def _vf_text(self, c, r, v):
207+
self.vf_texts.append(
208+
self.vf_ax.text(
209+
c - 0.2, r + 0.1, format(v, ".1f"), color="xkcd:bright blue"
210+
)
211+
)
212+
206213
def show_learning(self, representation):
207214
if self.vf_ax is None:
208215
self._init_value_vis()
@@ -233,14 +240,13 @@ def show_learning(self, representation):
233240
vmin, vmax = v.min(), v.max()
234241
for r, c in itertools.product(range(self.rows), range(self.cols)):
235242
if v[r, c] == vmin:
236-
self.vf_texts.append(
237-
self.vf_ax.text(c - 0.2, r + 0.1, format(vmin, ".1f"), color="w")
238-
)
243+
self._vf_text(c, r, vmin)
239244
elif v[r, c] == vmax:
240-
self.vf_texts.append(
241-
self.vf_ax.text(c - 0.2, r + 0.1, format(vmax, ".1f"), color="w")
242-
)
243-
v[r, c] = linear_map(v[r, c], vmin, vmax, -1, 1)
245+
self._vf_text(c, r, vmax)
246+
if v[r, c] < 0:
247+
v[r, c] = linear_map(v[r, c], min(vmin, self.MIN_RETURN), 0, -1, 0)
248+
else:
249+
v[r, c] = linear_map(v[r, c], 0, max(vmax, self.MAX_RETURN), 0, 1)
244250
# Show Value Function
245251
self.vf_img.set_data(v)
246252
# Show Policy for arrows

0 commit comments

Comments
 (0)