 # %%
-import random
-import torch
-from math import pi, tan
+import scsfm
+import argparse
+import time
 import cv2
 import numpy as np
-from numpy.linalg import inv
-import matplotlib.pyplot as plt
-import pandas as pd
+import torch
+from skimage.transform import resize as imresize
+import random
+from math import pi, tan
+# from numpy.linalg import inv
+# import matplotlib.pyplot as plt
+# import pandas as pd
 
 ###########################
 # VISUALIZATION FUNCTIONS #
@@ -31,10 +35,12 @@ def draw(frame, imgpts):
     # corner = corners[0].ravel()
     imgpts = np.int32(imgpts).reshape(-1, 2)
     # draw ground floor in green
-    frame = cv2.drawContours(frame, [imgpts[:4]], -1, (0, 255, 0), -3)
+    frame = cv2.drawContours(frame, [imgpts[:4]],
+                             -1, (0, 255, 0), -3)
     # draw pillars in blue color
     for i, j in zip(range(4), range(4, 8)):
-        frame = cv2.line(frame, tuple(imgpts[i]), tuple(imgpts[j]), (255), 1)
+        frame = cv2.line(frame, tuple(imgpts[i]),
+                         tuple(imgpts[j]), (255), 1)
     # draw top layer in red color
     frame = cv2.drawContours(frame, [imgpts[4:]], -1, (0, 0, 255), 1)
     return frame
@@ -43,7 +49,7 @@ def draw(frame, imgpts):
 ####################
 # FRAME PARAMETERS #
 ####################
-frame = cv2.imread('data/blob/blob21.jpg')
+frame = cv2.imread('data/data.png')
 height, width, channels = frame.shape
 print(f'width: {width} height: {height} channels: {channels}')
 gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
@@ -112,8 +118,8 @@ def draw(frame, imgpts):
 
 # Filter by Area.
 blobParams.filterByArea = True
-blobParams.minArea = 10    # minArea may be adjusted to suit for your experiment
-blobParams.maxArea = 100   # maxArea may be adjusted to suit for your experiment
+blobParams.minArea = 1
+blobParams.maxArea = 100
 
 # Filter by Circularity
 blobParams.filterByCircularity = True
@@ -303,3 +309,70 @@ def find_board(keypoints):
             [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
 
 cv2.imwrite('res/blobs.png', blob_frame)
+
+
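+# run on the GPU when available, otherwise fall back to the CPU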
+device = torch.device(
+    "cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+def load_tensor_image(img, resize=(256, 320)):
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = img.astype(np.float32)
+
+    if resize:
+        img = imresize(img, resize)
+
+    img = np.transpose(img, (2, 0, 1))
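+    # scale to [0, 1], then normalise with mean 0.45 / std 0.225 (presumably
+    # the statistics the depth network was trained with)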
+    tensor_img = ((torch.from_numpy(img).unsqueeze(
+        0) / 255 - 0.45) / 0.225).to(device)
+    print(f'tensor shape {tensor_img.shape}')
+    return tensor_img
+
+
+def prediction_to_visual(output, shape=(360, 640)):
+    pred_disp = output.cpu().numpy()[0, 0]
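+    # the network predicts disparity; invert it to get a (relative) depth map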
+    img = 1 / pred_disp
+    img = imresize(img, shape).astype(np.float32)
+    return img
+
+
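+# inference only: disable gradient tracking to save memory and time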
+@torch.no_grad()
+def predict_depth():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--source', type=str,
+                        default='res/dataset1.mp4', help='source')
+    parser.add_argument('--output', type=str,
+                        default='res/scfm.csv', help='source')
+    opt = parser.parse_args()
+    print(opt)
+
+    ################
+    # Load DispNet #
+    ################
+    disp_net = scsfm.DispResNet(18, False).to(device)
+    weights = torch.load('data/weights/scfm-nyu2-test.pth.tar')
+    disp_net.load_state_dict(weights['state_dict'])
+    disp_net.eval()
+    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    # cap = cv2.VideoCapture('data/data2.mp4')
+    # while True:
+    then = time.time()
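+    # keep only columns 30-510 of the frame (a 480-pixel-wide crop)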
+    frame_rgb = frame_rgb[:, 30:510]
+    tgt_img = load_tensor_image(frame_rgb.copy())
+    print_duration(then, 'capture and convert')
+    then = time.time()
+    output = disp_net(tgt_img)
+    print_duration(then, 'inference')
+
+    cv2.imshow('frame', frame_rgb)
+    cv2.imshow('depth', prediction_to_visual(output))
+    cv2.waitKey(-1)
+
+
+def print_duration(then, prefix=''):
+    print(prefix, 'took %.2f ms' % ((time.time() - then) * 1000))
+
+
+predict_depth()