I am trying to understand the heuristic algorithm used in the `memory` policy. However, I could not fully follow the logic, especially the following `if` statement:

```python
if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
```
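A minimal pure-Python sketch of one possible reading of this condition, run on small toy graphs. It is not the library's code. The graph format (each op produces one tensor with the same name), the helper names `backward_walk`, `forward_walk` and `is_bottleneck`, and both example graphs are hypothetical. The sketch assumes that, for a candidate tensor `t`, `b_inp` holds the candidate tensors consumed by ops at or before the op producing `t`, `f_inp` holds the candidate tensors consumed by ops strictly after it, and `ts_all` is the full candidate set.

```python
from typing import Dict, List, Set

# Hypothetical toy graph (assumption, not the TF graph): op name -> input tensor
# names; each op produces exactly one tensor that shares the op's name.
Graph = Dict[str, List[str]]


def backward_walk(graph: Graph, op: str) -> Set[str]:
    """Ops reachable by following inputs backwards from `op`, including `op`."""
    seen, stack = set(), [op]
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(graph[cur])
    return seen


def forward_walk(graph: Graph, op: str) -> Set[str]:
    """Ops that transitively consume `op`'s output, excluding `op` itself."""
    consumers = {n: [o for o, ins in graph.items() if n in ins] for n in graph}
    seen, stack = set(), list(consumers[op])
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(consumers[cur])
    return seen


def is_bottleneck(graph: Graph, t: str, ts_all: Set[str]) -> bool:
    # Candidate tensors consumed at/before t's op, and strictly after it.
    b_inp = {i for op in backward_walk(graph, t) for i in graph[op]} & ts_all
    f_inp = {i for op in forward_walk(graph, t) for i in graph[op]} & ts_all
    # Clause 1: no candidate is consumed on both sides (no shortcut jumps t).
    # Clause 2: every candidate is consumed on one side or the other.
    return not (b_inp & f_inp) and len(b_inp) + len(f_inp) >= len(ts_all)


# x -> a -> b -> c -> d -> e, plus a skip connection a -> d
skip = {"x": [], "a": ["x"], "b": ["a"], "c": ["b"], "d": ["c", "a"], "e": ["d"]}
cands = {"a", "b", "c", "d"}  # graph input x and output e left out
print(sorted(t for t in cands if is_bottleneck(skip, t, cands)))  # ['a', 'd']

# x -> p -> r and x -> q -> q2 -> r: p runs beside a parallel branch
par = {"x": [], "p": ["x"], "q": ["x"], "q2": ["q"], "r": ["p", "q2"], "y": ["r"]}
cands2 = {"p", "q", "q2", "r"}
print(is_bottleneck(par, "p", cands2))  # False
print(is_bottleneck(par, "r", cands2))  # True
```

Under this reading, the first clause rejects `b` and `c` because tensor `a` is consumed on both sides of them (the skip connection), and the second clause rejects `p` because `q` is consumed by an op that is neither upstream nor downstream of `p`, so `len(b_inp) + len(f_inp)` falls short of `len(ts_all)`.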
The `if` statement is at line 143 of `gradient-checkpointing/memory_saving_gradients.py` (commit 43444e0).

Any explanation or guidance would be highly appreciated.

Thanks.