Skip to content

Commit e830a9a

Browse files
committed
Merge branch 'free_vram'
2 parents f208fae + 7c12078 commit e830a9a

File tree

3 files changed: +23 −7 lines changed

iw3/gui.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
 import sys
 import os
 from os import path
-import gc
 import traceback
 import functools
 from time import time
@@ -16,6 +15,7 @@
     create_parser, set_state_args, iw3_main,
     is_text, is_video, is_output_dir, is_yaml, make_output_filename,
     has_rembg_model)
+from nunif.initializer import gc_collect
 from nunif.device import mps_is_available, xpu_is_available
 from nunif.utils.image_loader import IMG_EXTENSIONS as LOADER_SUPPORTED_EXTENSIONS
 from nunif.utils.video import VIDEO_EXTENSIONS as KNOWN_VIDEO_EXTENSIONS, has_nvenc
@@ -1096,9 +1096,7 @@ def parse_args(self):
         self.depth_model = None
         self.depth_model_type = None
         self.depth_model_device_id = None
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        gc_collect()

         remove_bg = self.chk_rembg.GetValue()
         bg_model_type = self.cbo_bg_model.GetValue()
@@ -1241,9 +1239,7 @@ def on_exit_worker(self, result):
         self.update_start_button_state()

         # free vram
-        gc.collect()
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
+        gc_collect()

     def on_click_btn_cancel(self, event):
         self.suspend_event.set()

iw3/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
 import math
 from tqdm import tqdm
 from PIL import ImageDraw, Image
+from nunif.initializer import gc_collect
 from nunif.utils.image_loader import ImageLoader
 from nunif.utils.pil_io import load_image_simple
 from nunif.models import load_model  # , compile_model
@@ -1673,10 +1674,12 @@ def iw3_main(args):
         if not args.recursive:
             image_files = ImageLoader.listdir(args.input)
             process_images(image_files, args.output, args, depth_model, side_model, title="Images")
+            gc_collect()
             for video_file in VU.list_videos(args.input):
                 if args.state["stop_event"] is not None and args.state["stop_event"].is_set():
                     return args
                 process_video(video_file, args.output, args, depth_model, side_model)
+                gc_collect()
         else:
             subdirs = list_subdir(args.input, include_root=True, excludes=args.output)
             for input_dir in subdirs:
@@ -1685,10 +1688,12 @@ def iw3_main(args):
                 if image_files:
                     process_images(image_files, output_dir, args, depth_model, side_model,
                                    title=path.relpath(input_dir, args.input))
+                    gc_collect()
                 for video_file in VU.list_videos(input_dir):
                     if args.state["stop_event"] is not None and args.state["stop_event"].is_set():
                         return args
                     process_video(video_file, output_dir, args, depth_model, side_model)
+                    gc_collect()

     elif is_yaml(args.input):
         config = export_config.ExportConfig.load(args.input)
@@ -1712,6 +1717,7 @@ def iw3_main(args):
             if args.state["stop_event"] is not None and args.state["stop_event"].is_set():
                 return args
             process_video(video_file, args.output, args, depth_model, side_model)
+            gc_collect()
     elif is_video(args.input):
         process_video(args.input, args.output, args, depth_model, side_model)
     elif is_image(args.input):

nunif/initializer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 import random
 import numpy as np
 import secrets
+import gc


 def disable_image_lib_threads():
@@ -35,3 +36,16 @@ def set_seed(seed):
     torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
+
+
+def gc_collect():
+    gc.collect()
+
+    if hasattr(torch, "_dynamo") and hasattr(torch._dynamo, "reset"):
+        torch._dynamo.reset()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()

0 commit comments

Comments (0)