1
1
import logging
2
2
import numpy as np
3
3
from collections import namedtuple
4
+ import torch
5
+ import torch .nn .functional as F
4
6
5
7
TRACK_NAME = 'icy_soccer_field'
6
8
MAX_FRAMES = 1000
@@ -155,6 +157,25 @@ def _g(f):
155
157
return ray .get (f )
156
158
return f
157
159
160
+ def extract_peak (heatmap , max_pool_ks = 7 , min_score = - 5 , max_det = 1 ):
161
+ """
162
+ Your code here.
163
+ Extract local maxima (peaks) in a 2d heatmap.
164
+ @heatmap: H x W heatmap containing peaks (similar to your training heatmap)
165
+ @max_pool_ks: Only return points that are larger than a max_pool_ks x max_pool_ks window around the point
166
+ @min_score: Only return peaks greater than min_score
167
+ @return: List of peak locations only [(score, cx, cy), ...], where cx, cy are the position of a peak and score is the
168
+ heatmap value at the peak. Return no more than max_det peaks per image
169
+ """
170
+ max_cls = F .max_pool2d (heatmap [None , None ], kernel_size = max_pool_ks , padding = max_pool_ks // 2 , stride = 1 )[0 , 0 ]
171
+ possible_det = heatmap - (max_cls > heatmap ).float () * 1e5
172
+ if max_det > possible_det .numel ():
173
+ max_det = possible_det .numel ()
174
+ score , loc = torch .topk (possible_det .view (- 1 ), max_det )
175
+ return [(int (l ) % heatmap .size (1 ), int (l ) // heatmap .size (1 ))
176
+ for s , l in zip (score .cpu (), loc .cpu ()) if s > min_score ]
177
+
178
+
158
179
def _check (self , team1 , team2 , where , n_iter , timeout ):
159
180
_ , error , t1 = self ._g (self ._r (team1 .info )())
160
181
if error :
@@ -167,6 +188,11 @@ def _check(self, team1, team2, where, n_iter, timeout):
167
188
logging .debug ('timeout {} <? {} {}' .format (timeout , t1 , t2 ))
168
189
return t1 < timeout , t2 < timeout
169
190
191
+ @staticmethod
192
+ def _to_image (x , proj , view ):
193
+ p = proj @ view @ np .array (list (x ) + [1 ])
194
+ return np .clip (np .array ([p [0 ] / p [- 1 ], - p [1 ] / p [- 1 ]]), - 1 , 1 )
195
+
170
196
def run (self , team1 , team2 , num_player = 1 , max_frames = MAX_FRAMES , max_score = 3 , record_fn = None , timeout = 1e10 ,
171
197
initial_ball_location = [0 , 0 ], initial_ball_velocity = [0 , 0 ], verbose = False ):
172
198
RaceConfig = self ._pystk .RaceConfig
@@ -202,6 +228,7 @@ def run(self, team1, team2, num_player=1, max_frames=MAX_FRAMES, max_score=3, re
202
228
race .step ()
203
229
204
230
state = self ._pystk .WorldState ()
231
+
205
232
state .update ()
206
233
state .set_ball_location ((initial_ball_location [0 ], 1 , initial_ball_location [1 ]),
207
234
(initial_ball_velocity [0 ], 0 , initial_ball_velocity [1 ]))
@@ -214,10 +241,21 @@ def run(self, team1, team2, num_player=1, max_frames=MAX_FRAMES, max_score=3, re
214
241
team1_state = [to_native (p ) for p in state .players [0 ::2 ]]
215
242
team2_state = [to_native (p ) for p in state .players [1 ::2 ]]
216
243
soccer_state = to_native (state .soccer )
244
+ #print(soccer_state)
217
245
team1_images = team2_images = None
218
246
if self ._use_graphics :
219
247
team1_images = [np .array (race .render_data [i ].image ) for i in range (0 , len (race .render_data ), 2 )]
220
248
team2_images = [np .array (race .render_data [i ].image ) for i in range (1 , len (race .render_data ), 2 )]
249
+
250
+ for render in race .render_data :
251
+ r = render .instance .copy ()
252
+ r = r >> self ._pystk .object_type_shift
253
+ if 8 in r :
254
+ r [r != 8 ] = 0
255
+ print ('center of puck =' ,self .extract_peak (r ))
256
+ else :
257
+ print ('puck not in image' )
258
+ #print(8 in (render.instance >> self._pystk.object_type_shift))
221
259
222
260
# Have each team produce actions (in parallel)
223
261
if t1_can_act :
@@ -304,7 +342,9 @@ def wait(self, x):
304
342
recorder = recorder & utils .StateRecorder (args .record_state )
305
343
306
344
# Start the match
307
- match = Match (use_graphics = team1 .agent_type == 'image' or team2 .agent_type == 'image' )
345
+ #match = Match(use_graphics=team1.agent_type == 'image' or team2.agent_type == 'image')
346
+ #REMOVE THIS LINE BEFORE SUBMISSION
347
+ match = Match (use_graphics = True )
308
348
try :
309
349
result = match .run (team1 , team2 , args .num_players , args .num_frames , max_score = args .max_score ,
310
350
initial_ball_location = args .ball_location , initial_ball_velocity = args .ball_velocity ,
0 commit comments