Skip to content

Commit 088e9c5

Browse files
committed
add code to runner for dataset generation
1 parent 2d7cc30 commit 088e9c5

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

final/tournament/runner.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
import numpy as np
33
from collections import namedtuple
4+
import torch
5+
import torch.nn.functional as F
46

57
TRACK_NAME = 'icy_soccer_field'
68
MAX_FRAMES = 1000
@@ -155,6 +157,25 @@ def _g(f):
155157
return ray.get(f)
156158
return f
157159

160+
def extract_peak(heatmap, max_pool_ks=7, min_score=-5, max_det=1):
161+
"""
162+
Your code here.
163+
Extract local maxima (peaks) in a 2d heatmap.
164+
@heatmap: H x W heatmap containing peaks (similar to your training heatmap)
165+
@max_pool_ks: Only return points that are larger than a max_pool_ks x max_pool_ks window around the point
166+
@min_score: Only return peaks greater than min_score
167+
@return: List of peak locations only [(score, cx, cy), ...], where cx, cy are the position of a peak and score is the
168+
heatmap value at the peak. Return no more than max_det peaks per image
169+
"""
170+
max_cls = F.max_pool2d(heatmap[None, None], kernel_size=max_pool_ks, padding=max_pool_ks // 2, stride=1)[0, 0]
171+
possible_det = heatmap - (max_cls > heatmap).float() * 1e5
172+
if max_det > possible_det.numel():
173+
max_det = possible_det.numel()
174+
score, loc = torch.topk(possible_det.view(-1), max_det)
175+
return [(int(l) % heatmap.size(1), int(l) // heatmap.size(1))
176+
for s, l in zip(score.cpu(), loc.cpu()) if s > min_score]
177+
178+
158179
def _check(self, team1, team2, where, n_iter, timeout):
159180
_, error, t1 = self._g(self._r(team1.info)())
160181
if error:
@@ -167,6 +188,11 @@ def _check(self, team1, team2, where, n_iter, timeout):
167188
logging.debug('timeout {} <? {} {}'.format(timeout, t1, t2))
168189
return t1 < timeout, t2 < timeout
169190

191+
@staticmethod
192+
def _to_image(x, proj, view):
193+
p = proj @ view @ np.array(list(x) + [1])
194+
return np.clip(np.array([p[0] / p[-1], -p[1] / p[-1]]), -1, 1)
195+
170196
def run(self, team1, team2, num_player=1, max_frames=MAX_FRAMES, max_score=3, record_fn=None, timeout=1e10,
171197
initial_ball_location=[0, 0], initial_ball_velocity=[0, 0], verbose=False):
172198
RaceConfig = self._pystk.RaceConfig
@@ -202,6 +228,7 @@ def run(self, team1, team2, num_player=1, max_frames=MAX_FRAMES, max_score=3, re
202228
race.step()
203229

204230
state = self._pystk.WorldState()
231+
205232
state.update()
206233
state.set_ball_location((initial_ball_location[0], 1, initial_ball_location[1]),
207234
(initial_ball_velocity[0], 0, initial_ball_velocity[1]))
@@ -214,10 +241,21 @@ def run(self, team1, team2, num_player=1, max_frames=MAX_FRAMES, max_score=3, re
214241
team1_state = [to_native(p) for p in state.players[0::2]]
215242
team2_state = [to_native(p) for p in state.players[1::2]]
216243
soccer_state = to_native(state.soccer)
244+
#print(soccer_state)
217245
team1_images = team2_images = None
218246
if self._use_graphics:
219247
team1_images = [np.array(race.render_data[i].image) for i in range(0, len(race.render_data), 2)]
220248
team2_images = [np.array(race.render_data[i].image) for i in range(1, len(race.render_data), 2)]
249+
250+
for render in race.render_data:
251+
r = render.instance.copy()
252+
r = r >> self._pystk.object_type_shift
253+
if 8 in r:
254+
r[r != 8] = 0
255+
print('center of puck =',self.extract_peak(r))
256+
else:
257+
print('puck not in image')
258+
#print(8 in (render.instance >> self._pystk.object_type_shift))
221259

222260
# Have each team produce actions (in parallel)
223261
if t1_can_act:
@@ -304,7 +342,9 @@ def wait(self, x):
304342
recorder = recorder & utils.StateRecorder(args.record_state)
305343

306344
# Start the match
307-
match = Match(use_graphics=team1.agent_type == 'image' or team2.agent_type == 'image')
345+
#match = Match(use_graphics=team1.agent_type == 'image' or team2.agent_type == 'image')
346+
#REMOVE THIS LINE BEFORE SUBMISSION
347+
match = Match(use_graphics=True)
308348
try:
309349
result = match.run(team1, team2, args.num_players, args.num_frames, max_score=args.max_score,
310350
initial_ball_location=args.ball_location, initial_ball_velocity=args.ball_velocity,

homework1/solution/mlp.th

-4.69 MB
Binary file not shown.

homework3/homework/models.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -131,33 +131,33 @@ def forward(self, x):
131131
u = []
132132
x = normalize(x)
133133
og_size=list(x.size())
134-
print('og size',og_size)
134+
#print('og size',og_size)
135135
og_height=x.size(dim=2)
136136
og_width=x.size(dim=3)
137137
block = self.Levels[0]
138138
x1 = block(x)
139139
x1_height=x1.size(dim=2)
140140
x1_width=x1.size(dim=3)
141-
print('size after first conv',x1.size(),'for og size',og_size)
141+
#print('size after first conv',x1.size(),'for og size',og_size)
142142
block = self.Levels[1]
143143
x2 = block(x1)
144-
print('size after second conv',x2.size(),'for og size',og_size)
144+
#print('size after second conv',x2.size(),'for og size',og_size)
145145
block = self.Levels[2]
146146
x3 = block(x2)
147-
print('size after third conv',x3.size(),'for og size',og_size)
147+
#print('size after third conv',x3.size(),'for og size',og_size)
148148

149149
x3_u = self.Upconvs[0](x3)
150-
print('size after first deconv',x3_u.size(),'for og size',og_size)
150+
#print('size after first deconv',x3_u.size(),'for og size',og_size)
151151

152152
x2_u = self.Upconvs[1](x3_u)[:,:,:x1_height,:x1_width]
153-
print('size after second deconv',x2_u.size(),'for og size',og_size)
153+
#print('size after second deconv',x2_u.size(),'for og size',og_size)
154154
x1_u = torch.cat([x1,x2_u],dim=1)
155-
print('size after torch.cat',x1_u.size())
155+
#print('size after torch.cat',x1_u.size())
156156
x1_u = self.Upconvs[1](x1_u)
157-
print('size after third deconv',x1_u.size(),'for og size',og_size)
157+
#print('size after third deconv',x1_u.size(),'for og size',og_size)
158158
x1_u = self.Upconvs[2](x1_u)[:,:,:og_height,:og_width]
159159
x = torch.cat([x,x1_u],dim=1)
160-
print('size after torch.cat',x.size())
160+
#print('size after torch.cat',x.size())
161161
return self.conv1k(x)
162162

163163

0 commit comments

Comments
 (0)