Skip to content

Commit 1398bdb

Browse files
dmonitoringmodeld: follow same pattern as modeld (#36636)
* dmonitoringmodeld: follow same pattern as modeld * lint * oops * rename
1 parent b89c717 commit 1398bdb

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

selfdrive/modeld/dmonitoringmodeld.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
os.environ['DEV'] = 'QCOM' if TICI else 'CPU'
55
from tinygrad.tensor import Tensor
66
from tinygrad.dtype import dtypes
7-
import math
87
import time
98
import pickle
109
import numpy as np
@@ -18,7 +17,7 @@
1817
from openpilot.common.transformations.model import dmonitoringmodel_intrinsics
1918
from openpilot.common.transformations.camera import _ar_ox_fisheye, _os_fisheye
2019
from openpilot.selfdrive.modeld.models.commonmodel_pyx import CLContext, MonitoringModelFrame
21-
from openpilot.selfdrive.modeld.parse_model_outputs import sigmoid
20+
from openpilot.selfdrive.modeld.parse_model_outputs import sigmoid, safe_exp
2221
from openpilot.selfdrive.modeld.runners.tinygrad_helpers import qcom_tensor_from_opencl_address
2322

2423
PROCESS_NAME = "selfdrive.modeld.dmonitoringmodeld"
@@ -34,7 +33,7 @@ class ModelState:
3433
def __init__(self, cl_ctx):
3534
with open(METADATA_PATH, 'rb') as f:
3635
model_metadata = pickle.load(f)
37-
self.input_shapes = model_metadata['input_shapes']
36+
self.input_shapes = model_metadata['input_shapes']
3837
self.output_slices = model_metadata['output_slices']
3938

4039
self.frame = MonitoringModelFrame(cl_ctx)
@@ -65,32 +64,43 @@ def run(self, buf: VisionBuf, calib: np.ndarray, transform: np.ndarray) -> tuple
6564
t2 = time.perf_counter()
6665
return output, t2 - t1
6766

68-
69-
def fill_driver_state(msg, model_output, output_slices, ds_suffix):
70-
face_descs = model_output[output_slices[f'face_descs_{ds_suffix}']]
71-
face_descs_std = face_descs[-6:]
72-
msg.faceOrientation = [float(x) for x in face_descs[:3]]
73-
msg.faceOrientationStd = [math.exp(x) for x in face_descs_std[:3]]
74-
msg.facePosition = [float(x) for x in face_descs[3:5]]
75-
msg.facePositionStd = [math.exp(x) for x in face_descs_std[3:5]]
76-
msg.faceProb = float(sigmoid(model_output[output_slices[f'face_prob_{ds_suffix}']][0]))
77-
msg.leftEyeProb = float(sigmoid(model_output[output_slices[f'left_eye_prob_{ds_suffix}']][0]))
78-
msg.rightEyeProb = float(sigmoid(model_output[output_slices[f'right_eye_prob_{ds_suffix}']][0]))
79-
msg.leftBlinkProb = float(sigmoid(model_output[output_slices[f'left_blink_prob_{ds_suffix}']][0]))
80-
msg.rightBlinkProb = float(sigmoid(model_output[output_slices[f'right_blink_prob_{ds_suffix}']][0]))
81-
msg.sunglassesProb = float(sigmoid(model_output[output_slices[f'sunglasses_prob_{ds_suffix}']][0]))
82-
msg.phoneProb = float(sigmoid(model_output[output_slices[f'using_phone_prob_{ds_suffix}']][0]))
83-
84-
def get_driverstate_packet(model_output: np.ndarray, output_slices: dict[str, slice], frame_id: int, location_ts: int, exec_time: float, gpu_exec_time: float):
67+
def slice_outputs(model_outputs, output_slices):
68+
return {k: model_outputs[np.newaxis, v] for k,v in output_slices.items()}
69+
70+
def parse_model_output(model_output):
71+
parsed = {}
72+
parsed['wheel_on_right'] = sigmoid(model_output['wheel_on_right'])
73+
for ds_suffix in ['lhd', 'rhd']:
74+
face_descs = model_output[f'face_descs_{ds_suffix}']
75+
parsed[f'face_descs_{ds_suffix}'] = face_descs[:, :-6]
76+
parsed[f'face_descs_{ds_suffix}_std'] = safe_exp(face_descs[:, -6:])
77+
for key in ['face_prob', 'left_eye_prob', 'right_eye_prob','left_blink_prob', 'right_blink_prob', 'sunglasses_prob', 'using_phone_prob']:
78+
parsed[f'{key}_{ds_suffix}'] = sigmoid(model_output[f'{key}_{ds_suffix}'])
79+
return parsed
80+
81+
def fill_driver_data(msg, model_output, ds_suffix):
82+
msg.faceOrientation = model_output[f'face_descs_{ds_suffix}'][0, :3].tolist()
83+
msg.faceOrientationStd = model_output[f'face_descs_{ds_suffix}_std'][0, :3].tolist()
84+
msg.facePosition = model_output[f'face_descs_{ds_suffix}'][0, 3:5].tolist()
85+
msg.facePositionStd = model_output[f'face_descs_{ds_suffix}_std'][0, 3:5].tolist()
86+
msg.faceProb = model_output[f'face_prob_{ds_suffix}'][0, 0].item()
87+
msg.leftEyeProb = model_output[f'left_eye_prob_{ds_suffix}'][0, 0].item()
88+
msg.rightEyeProb = model_output[f'right_eye_prob_{ds_suffix}'][0, 0].item()
89+
msg.leftBlinkProb = model_output[f'left_blink_prob_{ds_suffix}'][0, 0].item()
90+
msg.rightBlinkProb = model_output[f'right_blink_prob_{ds_suffix}'][0, 0].item()
91+
msg.sunglassesProb = model_output[f'sunglasses_prob_{ds_suffix}'][0, 0].item()
92+
msg.phoneProb = model_output[f'using_phone_prob_{ds_suffix}'][0, 0].item()
93+
94+
def get_driverstate_packet(model_output, frame_id: int, location_ts: int, exec_time: float, gpu_exec_time: float):
8595
msg = messaging.new_message('driverStateV2', valid=True)
8696
ds = msg.driverStateV2
8797
ds.frameId = frame_id
8898
ds.modelExecutionTime = exec_time
8999
ds.gpuExecutionTime = gpu_exec_time
90-
ds.wheelOnRightProb = float(sigmoid(model_output[output_slices['wheel_on_right']][0]))
91-
ds.rawPredictions = model_output.tobytes() if SEND_RAW_PRED else b''
92-
fill_driver_state(ds.leftDriverData, model_output, output_slices, 'lhd')
93-
fill_driver_state(ds.rightDriverData, model_output, output_slices, 'rhd')
100+
ds.rawPredictions = model_output['raw_pred']
101+
ds.wheelOnRightProb = model_output['wheel_on_right'][0, 0].item()
102+
fill_driver_data(ds.leftDriverData, model_output, 'lhd')
103+
fill_driver_data(ds.rightDriverData, model_output, 'rhd')
94104
return msg
95105

96106

@@ -130,8 +140,11 @@ def main():
130140
t1 = time.perf_counter()
131141
model_output, gpu_execution_time = model.run(buf, calib, model_transform)
132142
t2 = time.perf_counter()
133-
134-
msg = get_driverstate_packet(model_output, model.output_slices, vipc_client.frame_id, vipc_client.timestamp_sof, t2 - t1, gpu_execution_time)
143+
raw_pred = model_output.tobytes() if SEND_RAW_PRED else b''
144+
model_output = slice_outputs(model_output, model.output_slices)
145+
model_output = parse_model_output(model_output)
146+
model_output['raw_pred'] = raw_pred
147+
msg = get_driverstate_packet(model_output, vipc_client.frame_id, vipc_client.timestamp_sof, t2 - t1, gpu_execution_time)
135148
pm.send("driverStateV2", msg)
136149

137150

0 commit comments

Comments
 (0)