31 changes: 21 additions & 10 deletions notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
@@ -296,6 +296,17 @@
"unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
"unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
"\n",
"# Feature map height and width are dynamic\n",
"fm_height = torch.export.Dim(\"fm_height\", min=16, max=256)\n",
"fm_width = torch.export.Dim(\"fm_width\", min=16, max=256)\n",
"dim = torch.export.Dim(\"dim\", min=1, max=16)\n",
"fm_height = 16 * dim\n",
"fm_width = 16 * dim\n",
"\n",
"dynamic_shapes = {\"sample\": {2: fm_height, 3: fm_width}}\n",
"# iterate through the unet kwargs and set only hidden state kwarg to dynamic\n",
"dynamic_shapes_transformer = {key: (None if key != \"hidden_states\" else {2: fm_height, 3: fm_width}) for key in unet_kwargs.keys()}\n",
"\n",
"with torch.no_grad():\n",
" with disable_patching():\n",
" text_encoder = torch.export.export_for_training(\n",
@@ -308,10 +319,12 @@
" args=(text_encoder_input,),\n",
" kwargs=(text_encoder_kwargs),\n",
" ).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,)).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,)).module()\n",
" pipe.vae.decoder = torch.export.export_for_training(pipe.vae.decoder.eval(), args=(vae_decoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" pipe.vae.encoder = torch.export.export_for_training(pipe.vae.encoder.eval(), args=(vae_encoder_input,), dynamic_shapes=dynamic_shapes).module()\n",
" vae = pipe.vae\n",
" transformer = torch.export.export_for_training(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs)).module()\n",
" transformer = torch.export.export_for_training(\n",
" pipe.transformer.eval(), args=(), kwargs=(unet_kwargs), dynamic_shapes=dynamic_shapes_transformer\n",
" ).module()\n",
"models_dict = {}\n",
"models_dict[\"transformer\"] = transformer\n",
"models_dict[\"vae\"] = vae\n",
@@ -450,8 +463,6 @@
" ).shuffle(seed=42)\n",
"\n",
" transformer_config = dict(pipe.transformer.config)\n",
" if \"model\" in transformer_config:\n",
" del transformer_config[\"model\"]\n",
" wrapped_unet = UNetWrapper(pipe.transformer.model, transformer_config)\n",
" pipe.transformer = wrapped_unet\n",
" # Run inference for data collection\n",
@@ -517,10 +528,10 @@
"if to_quantize:\n",
" with disable_patching():\n",
" with torch.no_grad():\n",
" nncf.compress_weights(text_encoder)\n",
" nncf.compress_weights(text_encoder_2)\n",
" nncf.compress_weights(vae_encoder)\n",
" nncf.compress_weights(vae_decoder)\n",
" text_encoder = nncf.compress_weights(text_encoder)\n",
" text_encoder_2 = nncf.compress_weights(text_encoder_2)\n",
" vae_encoder = nncf.compress_weights(vae_encoder)\n",
" vae_decoder = nncf.compress_weights(vae_decoder)\n",
" quantized_transformer = nncf.quantize(\n",
" model=original_transformer,\n",
" calibration_dataset=nncf.Dataset(unet_calibration_data),\n",
@@ -766,7 +777,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},