@@ -203,6 +203,13 @@ def _init_value_vis(self):
203
203
self ._init_arrow (name , * np .meshgrid (x , y ))
204
204
self .vf_fig .show ()
205
205
206
+ def _vf_text (self , c , r , v ):
207
+ self .vf_texts .append (
208
+ self .vf_ax .text (
209
+ c - 0.2 , r + 0.1 , format (v , ".1f" ), color = "xkcd:bright blue"
210
+ )
211
+ )
212
+
206
213
def show_learning (self , representation ):
207
214
if self .vf_ax is None :
208
215
self ._init_value_vis ()
@@ -233,14 +240,13 @@ def show_learning(self, representation):
233
240
vmin , vmax = v .min (), v .max ()
234
241
for r , c in itertools .product (range (self .rows ), range (self .cols )):
235
242
if v [r , c ] == vmin :
236
- self .vf_texts .append (
237
- self .vf_ax .text (c - 0.2 , r + 0.1 , format (vmin , ".1f" ), color = "w" )
238
- )
243
+ self ._vf_text (c , r , vmin )
239
244
elif v [r , c ] == vmax :
240
- self .vf_texts .append (
241
- self .vf_ax .text (c - 0.2 , r + 0.1 , format (vmax , ".1f" ), color = "w" )
242
- )
243
- v [r , c ] = linear_map (v [r , c ], vmin , vmax , - 1 , 1 )
245
+ self ._vf_text (c , r , vmax )
246
+ if v [r , c ] < 0 :
247
+ v [r , c ] = linear_map (v [r , c ], min (vmin , self .MIN_RETURN ), 0 , - 1 , 0 )
248
+ else :
249
+ v [r , c ] = linear_map (v [r , c ], 0 , max (vmax , self .MAX_RETURN ), 0 , 1 )
244
250
# Show Value Function
245
251
self .vf_img .set_data (v )
246
252
# Show Policy for arrows
0 commit comments