1+ # from debug import debug, mark
2+ import torch
3+ try :
4+ from torchvision .utils import save_image
5+ except :
6+ pass
7+ import numpy as np
8+ import sys
9+ import os
10+ try :
11+ from PIL import Image
12+ except :
13+ pass
14+
15+ # 开关 #################################
16+ ON_DEBUG = True # debug总开关
17+ PLAIN = False # 开启则仅普通的打印(至终端或debug.log)
18+ MAX_LOG = - 1 # 0: 不debug, -1: 无限输出 NOTE: 无输出?可能这里设成了0,或者数量不够高、没到需要输出的变量!
19+ FULL = False # 是否输出完整的tensor内容,而不用...进行省略
20+ TO_FILE = True # 是否写入debug.log
21+ PRINT = True # 是否打印至终端
22+ BUGGY = True # 便捷地debug(出现bug则进入自动进入调试模式)
23+ PEEK_LAYER = 0 # 详细打印至第几层,标准为0,建议用3
24+ SAVE_IMAGE_NORM = False # 把tensor保存成图片时是否normalize
25+ # 控制是否打印细节:debug(True/False, xxx, xxx),False则只打印形状
26+
27+ # 教程 #################################
28+ # 功能1: debug(xxx) : 用黄色字体打印出xxx的形状及具体值,debug(False, xxx)则只打印形状,不打印具体值。更多控制开关见上方。
29+ # 功能2: mark(xxx) : 标记运行到了某个位置,若有输入,则用黄色字体打印出xxx值,若仅用mark()无输入,则打印mark()所在的位置
30+ # 功能3: 在出错时跳至ipdb界面,便捷debug
31+
32+ # 实现 #################################
33+
34+ debug_count = 1
35+ debug_file = None
36+ debug_path = "super_debug"
37+ if os .path .exists (debug_path ):
38+ os .system ("rm -r " + debug_path )
39+ os .mkdir (debug_path )
40+ log_path = os .path .join (debug_path , "debug.log" )
41+ os .system ("touch " + log_path )
42+ image_count = {}
43+
44+
45+ class ExceptionHook :
46+ instance = None
47+
48+ def __call__ (self , * args , ** kwargs ):
49+ if not BUGGY :
50+ return
51+ if self .instance is None :
52+ from IPython .core import ultratb
53+ self .instance = ultratb .FormattedTB (mode = 'Plain' ,
54+ color_scheme = 'Linux' , call_pdb = 1 )
55+ return self .instance (* args , ** kwargs )
56+
57+
58+ sys .excepthook = ExceptionHook ()
59+
60+
61+ def get_pos (level = 1 , end = "\n " ):
62+ position = """"{}", line {}, in {}""" .format (
63+ sys ._getframe (level ).f_code .co_filename , # 当前文件名
64+ sys ._getframe (level ).f_lineno , # 当前行号
65+ sys ._getframe (level ).f_code .co_name , # 当前函数/module名
66+ )
67+ return position
68+
69+
70+ def print_yellow (text , end = "\n " ):
71+ print (f"\033 [1;33m{ text } \033 [0m" , end = end )
72+
73+
74+ def normalize (tensor ):
75+ max_value = torch .max (tensor )
76+ min_value = torch .min (tensor )
77+ tensor = (tensor - min_value ) / (max_value - min_value )
78+ return tensor
79+
80+ def print_image (tensor , name , is_np = False ):
81+ if name not in image_count :
82+ image_count [name ] = 0
83+ file_path = os .path .join (debug_path , f"tensor_{ debug_count } _{ name } _{ image_count [name ]} .jpg" )
84+ normallized_file_path = os .path .join (debug_path , f"tensor_{ debug_count } _{ name } _{ image_count [name ]} _norm.jpg" )
85+ image_count [name ] += 1
86+ if type (tensor ) == Image .Image :
87+ tensor .save (file_path )
88+ else :
89+ if is_np :
90+ tensor = torch .Tensor (tensor )
91+ normalized_tensor = normalize (tensor )
92+ try :
93+ if SAVE_IMAGE_NORM :
94+ save_image (normalized_tensor , normallized_file_path )
95+ else :
96+ save_image (tensor , file_path )
97+ except Exception :
98+ pass
99+ def mark (marker = None ):
100+ if marker is not None :
101+ print_yellow (marker )
102+ else :
103+ print_yellow (get_pos (level = 2 ))
104+
105+
106+ def logging (* message , end = "\n " ):
107+ """同时输出到终端和debug.log"""
108+ message = " " .join ([str (_ ) for _ in message ])
109+ if debug_file :
110+ debug_file .write (message + end )
111+ if PRINT :
112+ print_yellow (message , end = end )
113+
114+
115+ def info (i , name = "" , detail = True , layer = 0 ):
116+ """递归打印变量"""
117+ global PEEK_LAYER
118+ sep = " "
119+ if type (i ) == int or type (i ) == float :
120+ logging (sep * layer , name , "num val:" , i )
121+ else :
122+ if type (i ) == str :
123+ logging (sep * layer , name , "str:" , i )
124+ elif type (i ) == bool :
125+ logging (sep * layer , name , "bool:" , i )
126+
127+ elif type (i ) == list :
128+ logging (sep * layer , name , "list size:" , len (i ), end = "" )
129+ if layer < PEEK_LAYER and len (i ) > 0 :
130+ logging ("" )
131+ info (i [0 ], "0th item:" , detail , layer + 1 )
132+ else :
133+ logging (" val:" , i if detail else "*" )
134+ elif type (i ) == dict :
135+ logging (sep * layer , name , "dict with keys" , list (i .keys ()))
136+ for key in i :
137+ info (i [key ], key , detail , layer + 1 )
138+ elif type (i ) == tuple :
139+ logging (sep * layer , name , "tuple size:" , len (i ), "" )
140+ if layer < PEEK_LAYER and len (i ) > 0 :
141+ for no , item in enumerate (i ):
142+ info (item , str (no ) + ". " , detail , layer + 1 )
143+ else :
144+ logging (" val:" , i if detail else "*" )
145+ elif type (i ) == torch .Tensor :
146+ logging (sep * layer , name , "Tensor size:" , i .shape ,
147+ "val:" , i if detail else "*" )
148+ print_image (i , name )
149+
150+ elif type (i ) == np .ndarray :
151+ logging (sep * layer , name , "ndarray size:" , i .shape ,
152+ "val:" , i if detail else "*" )
153+ print_image (i , name , True )
154+ elif type (i ) == Image .Image :
155+ print_image (i , name )
156+ else :
157+ try :
158+ j = float (i )
159+ except Exception :
160+ logging (sep * layer , name , str (type (i )) + " with val: " , i )
161+ else :
162+ logging (sep * layer , name , "num val:" , j , type (i ))
163+
164+
165+ def debug (* args , ** kwargs ):
166+ """debug打印主入口"""
167+ global ON_DEBUG
168+ global debug_count
169+ global debug_file
170+ global TO_FILE
171+ global PLAIN
172+ if not ON_DEBUG :
173+ return
174+ if TO_FILE :
175+ debug_file = open (log_path , "a" )
176+ if PLAIN :
177+ logging (* args , ** kwargs , end = "\n " )
178+ if TO_FILE :
179+ debug_file .close ()
180+ return
181+ global FULL
182+ if FULL :
183+ torch .set_printoptions (profile = "full" )
184+ np .set_printoptions (threshold = sys .maxsize )
185+ count = 0
186+ if MAX_LOG != - 1 and debug_count >= MAX_LOG :
187+ if debug_file :
188+ debug_file .close ()
189+ return
190+ detail = True
191+ if args and type (args [0 ]) is bool :
192+ detail = args [0 ]
193+ args = args [1 :]
194+ keys = list (kwargs .keys ())
195+ logging (
196+ f"DEBUG: { len (args ) + len (kwargs )} vars: { ['?' for _ in args ] + keys } , at { get_pos (level = 2 )} " )
197+ for i in args :
198+ logging (f"{ count } / { debug_count } ." , end = " " )
199+ info (i , detail = detail )
200+ debug_count += 1
201+ count += 1
202+ for i in keys :
203+ logging (f"{ count } / { debug_count } ." , end = " " )
204+ info (kwargs [i ], i , detail = detail )
205+ count += 1
206+ debug_count += 1
207+ logging ("-------------------------------------" )
208+ if TO_FILE :
209+ debug_file .close ()
0 commit comments