forked from xdit-project/xDiT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflux_example.py
97 lines (85 loc) · 3.61 KB
/
flux_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import logging
import time
import torch
import torch.distributed
from xfuser import xFuserFluxPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
)
def benchmark_torch_function(iters, f, *args, **kwargs):
    """Measure the average GPU time of ``f(*args, **kwargs)`` with CUDA events.

    ``f`` is called twice before timing starts (untimed, presumably as a
    warm-up), then ``iters`` times between a pair of CUDA events, and finally
    once more, untimed, to produce the returned output.

    Args:
        iters: Number of timed calls. Must be positive.
        f: Callable to benchmark.
        *args: Positional arguments forwarded to every call of ``f``.
        **kwargs: Keyword arguments forwarded to every call of ``f``.

    Returns:
        A tuple ``(avg_time_us, output)`` where ``avg_time_us`` is the mean
        time per timed call in microseconds and ``output`` is the result of
        the extra call made after timing has finished.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    # Fail fast, before any call to ``f``: zero would divide by zero below and
    # a negative count would time an empty loop and report a negative average.
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters!r}")

    # Untimed calls ahead of the measurement.
    f(*args, **kwargs)
    f(*args, **kwargs)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        f(*args, **kwargs)
    end_event.record()
    # Wait for the recorded events to complete before reading the timer.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds (with a resolution of about 0.5
    # microseconds); scale by 1000 so the average is returned in microseconds.
    avg_time_us = start_event.elapsed_time(end_event) * 1000 / iters
    return avg_time_us, f(*args, **kwargs)
def main():
    """Benchmark the xFuser Flux pipeline configured from the command line.

    Parses the xFuser CLI arguments, builds the pipeline in bfloat16, places
    it on this rank's GPU (or enables sequential CPU offload), runs it through
    ``benchmark_torch_function`` with 10 timed iterations, optionally saves
    the generated PIL images under ``./results/``, prints the average time
    and peak allocated memory on the last global rank, and finally tears down
    the distributed environment.
    """
    # Function-local import: only needed to create the output directory.
    import os

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank
    device = f"cuda:{local_rank}"

    pipe = xFuserFluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(device)

    pipe.prepare_run(input_config, max_sequence_length=256)

    # Reset the counter on the same device that is queried after the run, so
    # the reported peak covers exactly the benchmarked calls.
    torch.cuda.reset_peak_memory_stats(device=device)
    avg_et, output = benchmark_torch_function(
        10,
        pipe,
        height=input_config.height,
        # Use the configured width; passing the height here would silently
        # force square images regardless of the requested size.
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    peak_memory = torch.cuda.max_memory_allocated(device=device)

    # Tag embedded in the output file names to identify the parallel layout.
    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )

    if input_config.output_type == "pil":
        global_rank = get_world_group().rank
        # NOTE(review): the group index is derived as global rank divided by
        # the data-parallel world size, while get_data_parallel_rank is
        # imported by this file but never used — confirm this yields the
        # intended index for every parallel layout.
        dp_group_world_size = get_data_parallel_world_size()
        dp_group_index = global_rank // dp_group_world_size
        num_dp_groups = engine_config.parallel_config.dp_degree
        # Ceiling division: images per data-parallel group.
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            # Create the output directory up front; exist_ok tolerates several
            # ranks reaching this point at the same time.
            os.makedirs("./results", exist_ok=True)
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_path = f"./results/flux_result_{parallel_info}_{image_rank}.png"
                image.save(image_path)
                print(f"image {i} saved to {image_path}")

    if get_world_group().rank == get_world_group().world_size - 1:
        # avg_et is in microseconds (milliseconds * 1000), hence the 1e6 divisor.
        print(f"epoch time: {avg_et / 1e6:.2f} sec, memory: {peak_memory/1e9} GB")

    # Method name spelling ("destory") is kept exactly as in the original call.
    get_runtime_state().destory_distributed_env()
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()