diff --git a/.github/workflows/lightning.yml b/.github/workflows/lightning.yml index 73716b26..32a1a8ab 100644 --- a/.github/workflows/lightning.yml +++ b/.github/workflows/lightning.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0"] + lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0", "2.5.0"] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/pytorch.yml b/.github/workflows/pytorch.yml index f79f54b6..52f33714 100644 --- a/.github/workflows/pytorch.yml +++ b/.github/workflows/pytorch.yml @@ -17,10 +17,10 @@ jobs: fail-fast: false matrix: pytorch: [ + 'torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cpu', 'torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu', 'torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cpu', 'torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu', - 'torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu', ] steps: diff --git a/model_benchmark.py b/model_benchmark.py index 9fe844b2..32a82e9a 100644 --- a/model_benchmark.py +++ b/model_benchmark.py @@ -36,10 +36,7 @@ from ptlflow.utils.lightning.ptlflow_cli import PTLFlowCLI from ptlflow.utils.registry import RegisteredModel from ptlflow.utils.timer import Timer -from ptlflow.utils.utils import ( - count_parameters, - make_divisible, -) +from ptlflow.utils.utils import count_parameters NUM_COMMON_COLUMNS = 6 TABLE_KEYS_LEGENDS = { @@ -302,8 +299,8 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame: 1, 2, 3, - make_divisible(input_size[0], model.output_stride), - make_divisible(input_size[1], model.output_stride), + input_size[0], + input_size[1], ) } @@ -372,7 +369,7 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame: ) except Exception as e: # noqa: B902 logger.warning( - "Skipping model %s with datatype %s due to exception %s", + "Skipping model {} 
with datatype {} due to exception {}", mname, dtype_str, e, @@ -440,8 +437,8 @@ def estimate_inference_time( args.batch_size, 2, 3, - make_divisible(input_size[0], model.output_stride), - make_divisible(input_size[1], model.output_stride), + input_size[0], + input_size[1], ) } if torch.cuda.is_available(): diff --git a/plot_results.py b/plot_results.py index 90454e0d..4b23eb1b 100644 --- a/plot_results.py +++ b/plot_results.py @@ -19,9 +19,8 @@ import argparse import logging from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Union -import numpy as np import pandas as pd import plotly.express as px diff --git a/ptlflow/__init__.py b/ptlflow/__init__.py index 4b3ddc79..6140c9a4 100644 --- a/ptlflow/__init__.py +++ b/ptlflow/__init__.py @@ -236,7 +236,9 @@ def load_checkpoint(ckpt_path: str, model_ref: BaseModel) -> Dict[str, Any]: device = "cuda" if torch.cuda.is_available() else "cpu" if Path(ckpt_path).exists(): - ckpt = torch.load(ckpt_path, map_location=torch.device(device)) + ckpt = torch.load( + ckpt_path, map_location=torch.device(device), weights_only=True + ) else: model_dir = Path(hub.get_dir()) / "checkpoints" ckpt = hub.load_state_dict_from_url( @@ -244,6 +246,7 @@ def load_checkpoint(ckpt_path: str, model_ref: BaseModel) -> Dict[str, Any]: model_dir=model_dir, map_location=torch.device(device), check_hash=True, + weights_only=True, ) return ckpt diff --git a/ptlflow/data/datasets.py b/ptlflow/data/datasets.py index 30682d9d..0ae4bb8f 100644 --- a/ptlflow/data/datasets.py +++ b/ptlflow/data/datasets.py @@ -25,6 +25,7 @@ from loguru import logger import numpy as np import torch +import torch.nn.functional as F from torch.utils.data import Dataset from ptlflow.utils import flow_utils @@ -1662,7 +1663,6 @@ def __init__( # noqa: C901 reverse_only: bool = False, subsample: bool = False, is_image_4k: bool = False, - image_4k_split_dir_suffix: str = "_4k", ) -> None: """Initialize SintelDataset. 
@@ -1705,11 +1705,6 @@ def __init__( # noqa: C901 If False, and is_image_4k is True, then the groundtruth is returned in its original 4D-shaped 4K resolution, but the flow values are doubled. is_image_4k : bool, default False If True, assumes the input images will be provided in 4K resolution, instead of the original 2K. - image_4k_split_dir_suffix : str, default "_4k" - Only used when is_image_4k == True. It indicates the suffix to add to the split folder name where the 4k images are located. - For example, by default, the 4K images need to be located inside folders called "train_4k" and/or "test/4k". - The structure of these folders should be the same as the original "train" and "test". - The "*_4k" folders only need to contain the image directories, the groundtruth will still be loaded from the original locations. """ if isinstance(side_names, str): side_names = [side_names] @@ -1731,7 +1726,6 @@ def __init__( # noqa: C901 self.sequence_position = sequence_position self.subsample = subsample self.is_image_4k = is_image_4k - self.image_4k_split_dir_suffix = image_4k_split_dir_suffix if self.is_image_4k: assert not self.subsample @@ -1758,17 +1752,9 @@ def __init__( # noqa: C901 for side in side_names: for direcs in directions: rev = direcs[0] == "BW" - img_split_dir_name = ( - f"{split_dir}{self.image_4k_split_dir_suffix}" - if self.is_image_4k - else split_dir - ) image_paths = sorted( ( - Path(self.root_dir) - / img_split_dir_name - / seq_name - / f"frame_{side}" + Path(self.root_dir) / split_dir / seq_name / f"frame_{side}" ).glob("*.png"), reverse=rev, ) @@ -1883,12 +1869,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 if self.transform is not None: inputs = self.transform(inputs) elif self.is_image_4k: + inputs["images"] = [ + cv.resize(img, None, fx=2, fy=2, interpolation=cv.INTER_CUBIC) + for img in inputs["images"] + ] if self.transform is not None: inputs = self.transform(inputs) if "flows" in inputs: inputs["flows"] = 2 
* inputs["flows"] if self.get_backward: inputs["flows_b"] = 2 * inputs["flows_b"] + + process_keys = [("flows", "valids")] + if self.get_backward: + process_keys.append(("flows_b", "valids_b")) + + for flow_key, valid_key in process_keys: + flow = inputs[flow_key] + flow_stack = rearrange( + flow, "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2 + ) + flow_stack4 = flow_stack.repeat(1, 4, 1, 1, 1) + flow_stack4 = rearrange( + flow_stack4, "b (m n) c h w -> b m n c h w", m=4 + ) + diff = flow_stack[:, :, None] - flow_stack4 + diff = rearrange(diff, "b m n c h w -> b (m n) c h w") + diff = torch.sqrt(torch.pow(diff, 2).sum(2)) + max_diff, _ = diff.max(1) + max_diff = F.interpolate( + max_diff[:, None], scale_factor=2, mode="nearest" + ) + inputs[valid_key] = (max_diff < 1.0).float() else: if self.transform is not None: inputs = self.transform(inputs) diff --git a/ptlflow/data/flow_datamodule.py b/ptlflow/data/flow_datamodule.py index bab24366..a92c9baa 100644 --- a/ptlflow/data/flow_datamodule.py +++ b/ptlflow/data/flow_datamodule.py @@ -64,6 +64,8 @@ def __init__( tartanair_root_dir: Optional[str] = None, spring_root_dir: Optional[str] = None, kubric_root_dir: Optional[str] = None, + middlebury_st_root_dir: Optional[str] = None, + viper_root_dir: Optional[str] = None, dataset_config_path: str = "./datasets.yaml", ): super().__init__() @@ -89,6 +91,8 @@ def __init__( self.tartanair_root_dir = tartanair_root_dir self.spring_root_dir = spring_root_dir self.kubric_root_dir = kubric_root_dir + self.middlebury_st_root_dir = middlebury_st_root_dir + self.viper_root_dir = viper_root_dir self.dataset_config_path = dataset_config_path self.predict_dataset_parsed = None @@ -935,6 +939,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset: sequence_position = "first" reverse_only = False subsample = False + is_image_4k = False side_names = [] fbocc_transform = False for v in args: @@ -952,6 +957,8 @@ def _get_spring_dataset(self, is_train: bool, *args: 
str) -> Dataset: sequence_position = v.split("_")[1] elif v == "sub": subsample = True + elif v == "4k": + is_image_4k = True elif v == "left": side_names.append("left") elif v == "right": @@ -1012,6 +1019,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset: sequence_position=sequence_position, reverse_only=reverse_only, subsample=subsample, + is_image_4k=is_image_4k, ) return dataset diff --git a/ptlflow/models/__init__.py b/ptlflow/models/__init__.py index cb475c35..fabc5d5a 100644 --- a/ptlflow/models/__init__.py +++ b/ptlflow/models/__init__.py @@ -3,6 +3,7 @@ from .csflow import * from .dicl import * from .dip import * +from .dpflow import * from .fastflownet import * from .flow1d import * from .flowformer import * diff --git a/ptlflow/models/base_model/base_model.py b/ptlflow/models/base_model/base_model.py index 6c2b1c28..c931b67d 100644 --- a/ptlflow/models/base_model/base_model.py +++ b/ptlflow/models/base_model/base_model.py @@ -69,18 +69,25 @@ def __init__( lr: Optional[float] = None, wdecay: Optional[float] = None, warm_start: bool = False, + metric_interpolate_pred_to_target_size: bool = False, ) -> None: """Initialize BaseModel. Parameters ---------- - args : Namespace - A namespace with the required arguments. Typically, this can be gotten from add_model_specific_args(). - loss_fn : Callable - A function to be used to compute the loss for the training. The input of this function must match the output of the - forward() method. The output of this function must be a tensor with a single value. output_stride : int How many times the output of the network is smaller than the input. + loss_fn : Optional[Callable] + A function to be used to compute the loss for the training. The input of this function must match the output of the + forward() method. The output of this function must be a tensor with a single value. + lr : Optional[float] + The learning rate to be used for training the model. If not provided, it will be set as 1e-4. 
+ wdecay : Optional[float] + The weight decay to be used for training the model. If not provided, it will be set as 1e-4. + warm_start : bool, default False + If True, use warm start to initialize the flow prediction. The warm_start strategy was presented by the RAFT method and forward interpolates the prediction from the last frame. + metric_interpolate_pred_to_target_size : bool, default False + If True, the prediction is bilinearly interpolated to match the target size during metric calculation, if their sizes are different. """ super(BaseModel, self).__init__() @@ -89,13 +96,19 @@ def __init__( self.lr = lr self.wdecay = wdecay self.warm_start = warm_start + self.metric_interpolate_pred_to_target_size = ( + metric_interpolate_pred_to_target_size + ) self.train_size = None self.train_avg_length = None self.extra_params = None - self.train_metrics = FlowMetrics(prefix="train/") + self.train_metrics = FlowMetrics( + prefix="train/", + interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size, + ) self.val_metrics = nn.ModuleList() self.val_dataset_names = [] @@ -132,6 +145,7 @@ def add_extra_param(self, name, value): def preprocess_images( self, images: torch.Tensor, + stride: Optional[int] = None, bgr_add: Union[float, Tuple[float, float, float], np.ndarray, torch.Tensor] = 0, bgr_mult: Union[ float, Tuple[float, float, float], np.ndarray, torch.Tensor @@ -201,7 +215,7 @@ def preprocess_images( if bgr_to_rgb: images = torch.flip(images, [-3]) - stride = self.output_stride + stride = self.output_stride if stride is None else stride if target_size is not None: stride = None @@ -371,7 +385,10 @@ def validation_step( """ if len(self.val_metrics) <= dataloader_idx: self.val_metrics.append( - FlowMetrics(prefix="val/").to(device=batch["flows"].device) + FlowMetrics( + prefix="val/", + interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size, + ).to(device=batch["flows"].device) ) self.val_dataset_names.append(None) diff --git 
a/ptlflow/models/dpflow/LICENSE b/ptlflow/models/dpflow/LICENSE new file mode 100644 index 00000000..c640737d --- /dev/null +++ b/ptlflow/models/dpflow/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2025 Henrique Morimitsu + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/ptlflow/models/dpflow/README.md b/ptlflow/models/dpflow/README.md new file mode 100644 index 00000000..293432ab --- /dev/null +++ b/ptlflow/models/dpflow/README.md @@ -0,0 +1,182 @@ +# DPFlow + +Source code for the CVPR 2025 paper: + +> DPFlow: Adaptive Optical Flow Estimation with a Dual-Pyramid Framework.
+> Henrique Morimitsu, Xiaobin Zhu, Roberto M. Cesar Jr, Xiangyang Ji, and Xu-Cheng Yin. + +The code and download links for the Kubric-NK dataset are available at [https://github.com/hmorimitsu/kubric-nk](https://github.com/hmorimitsu/kubric-nk). + +![Teaser for the paper. Showing qualitative and quantitative optical flow results at resolution varying from 1K to 8K.](assets/teaser.jpg) + +## Abstract + +Optical flow estimation is essential for video processing tasks, such as restoration and action recognition. +The quality of videos is constantly increasing, with current standards reaching 8K resolution. +However, optical flow methods are usually designed for low resolution and do not generalize to large inputs due to their rigid architectures. +They adopt downscaling or input tiling to reduce the input size, causing a loss of details and global information. +There is also a lack of optical flow benchmarks to judge the actual performance of existing methods on high-resolution samples. +Previous works only conducted qualitative high-resolution evaluations on hand-picked samples. +This paper fills this gap in optical flow estimation in two ways. +We propose DPFlow, an adaptive optical flow architecture capable of generalizing up to 8K resolution inputs while trained with only low-resolution samples. +We also introduce Kubric-NK, a new benchmark for evaluating optical flow methods with input resolutions ranging from 1K to 8K. +Our high-resolution evaluation pushes the boundaries of existing methods and reveals new insights about their generalization capabilities. +Extensive experimental results show that DPFlow achieves state-of-the-art results on the MPI-Sintel, KITTI 2015, Spring, and other high-resolution benchmarks. + +## Installation + +Follow the [PTLFlow installation instructions](https://ptlflow.readthedocs.io/en/latest/starting/installation.html). + +This model can be called using the following names: `dpflow`. 
+ +The exact versions of the packages we used for our tests are listed in [requirements.txt](requirements.txt). + +## Data + +Our model uses the following datasets. Download and unpack them according to their respective instructions and then configure the paths in `datasets.yaml` (see [PTLFlow installation instructions](https://ptlflow.readthedocs.io/en/latest/starting/installation.html)). + +### Training datasets + +- [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html) +- [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) +- [MPI-Sintel](http://sintel.is.tue.mpg.de) +- [KITTI 2015](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) +- [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) +- [Spring](https://spring-benchmark.org/) + +### Validation/test datasets + +- [MPI-Sintel](http://sintel.is.tue.mpg.de) +- [KITTI 2015](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) +- [Spring](https://spring-benchmark.org/) +- [Middlebury-ST](https://vision.middlebury.edu/stereo/data/scenes2014/) +- [VIPER](https://playing-for-benchmarks.org/) +- [Kubric-NK](https://github.com/hmorimitsu/kubric-nk) + +## Training + +Follow the [PTLFlow training instructions](https://ptlflow.readthedocs.io/en/latest/starting/training.html). + +We train our model in four stages as follows. 
+ +### Stage 1: FlyingChairs + +```bash +python train.py --config ptlflow/models/dpflow/configs/dpflow-train1-chairs.yaml +``` + +### Stage 2: FlyingThings3D + +```bash +python train.py --config ptlflow/models/dpflow/configs/dpflow-train2-things.yaml +``` + +### Stage 3: FlyingThings3D+Sintel+KITTI+HD1K +```bash +python train.py --config ptlflow/models/dpflow/configs/dpflow-train3-sintel.yaml +``` + +### Stage 4a: KITTI 2015 +```bash +python train.py --config ptlflow/models/dpflow/configs/dpflow-train4a-kitti.yaml +``` + +### Stage 4b: Spring +```bash +python train.py --config ptlflow/models/dpflow/configs/dpflow-train4b-spring.yaml +``` + +## Validation + +To validate our model on the training sets of Sintel and KITTI, use the following command at the root folder of PTLFlow: + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt things --data.val_dataset sintel-clean+sintel-final+kitti-2015 +``` + +It should generate the following results: + +| Dataset | EPE | Outlier | px1 | WAUC | |--------------|------|---------|------|------| | Sintel clean | 1.02 | 3.16 | 7.85 | 90.0 | | Sintel final | 2.27 | 6.46 | 13.0 | 85.2 | | KITTI 2015 | 3.39 | 11.1 | 29.1 | 70.2 | + +### Middlebury-ST + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt things --data.val_dataset middlebury_st +``` + +### VIPER + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt sintel --data.val_dataset viper +``` + +### Spring + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt sintel --data.val_dataset spring-left +``` + +### Spring (4k) + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt sintel --data.val_dataset spring-left-4k +``` + +### Kubric-NK + +```bash +python validate.py --config ptlflow/models/dpflow/configs/dpflow-validate.yaml --ckpt sintel --data.val_dataset 
kubric --kubric_root_dir /path/to/kubric-nk/1k +``` + +To validate on other resolutions, just replace the path in `--kubric_root_dir` to the respective folder containing the data at another resolution. + +## Test + +The results submitted to the public benchmarks are generated with the respective commands below. + +### MPI-Sintel + +```bash +python test.py --config ptlflow/models/dpflow/configs/dpflow-test.yaml --ckpt sintel --data.test_dataset sintel +``` + +### KITTI 2015 + +```bash +python test.py --config ptlflow/models/dpflow/configs/dpflow-test.yaml --ckpt kitti --data.test_dataset kitti-2015 +``` + +### Spring + +```bash +python test.py --config ptlflow/models/dpflow/configs/dpflow-test-spring-zeroshot.yaml +``` + +```bash +python test.py --config ptlflow/models/dpflow/configs/dpflow-test-spring-finetune.yaml +``` + +## Code license + +The source code is released under the [Apache 2.0 LICENSE](LICENSE). + +## Pretrained weights license + +Based on the licenses of the datasets used for training the models, our weights are released strictly for academic and research purposes only. + +## Citation + +If you use this model, please consider citing the paper: + +``` +@InProceedings{Morimitsu2025DPFlow, + author = {Morimitsu, Henrique and Zhu, Xiaobin and Cesar-Jr., Roberto M. 
and Ji, Xiangyang and Yin, Xu-Cheng}, + booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + title = {{DPFlow}: Adaptive Optical Flow Estimation with a Dual-Pyramid Framework}, + year = {2025}, +} +``` \ No newline at end of file diff --git a/ptlflow/models/dpflow/__init__.py b/ptlflow/models/dpflow/__init__.py new file mode 100644 index 00000000..2e9853c9 --- /dev/null +++ b/ptlflow/models/dpflow/__init__.py @@ -0,0 +1 @@ +from .dpflow import dpflow diff --git a/ptlflow/models/dpflow/assets/teaser.jpg b/ptlflow/models/dpflow/assets/teaser.jpg new file mode 100644 index 00000000..a6bca31a Binary files /dev/null and b/ptlflow/models/dpflow/assets/teaser.jpg differ diff --git a/ptlflow/models/dpflow/cgu.py b/ptlflow/models/dpflow/cgu.py new file mode 100644 index 00000000..93777772 --- /dev/null +++ b/ptlflow/models/dpflow/cgu.py @@ -0,0 +1,412 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Code based on VAN: https://github.com/Visual-Attention-Network/VAN-Classification/blob/main/models/van.py +# ============================================================================= + +import math +from functools import partial + +import torch +import torch.nn as nn + +from .local_timm.drop import DropPath +from .local_timm.layer_helpers import to_2tuple +from .norm import GroupNorm2d +from .local_timm.weight_init import trunc_normal_ +from .conv import Conv2dBlock +from .utils import get_activation + + +class DWConv(nn.Module): + def __init__(self, dim=768, kernel_size=3): + super(DWConv, self).__init__() + self.dwconv = Conv2dBlock( + dim, dim, kernel_size, 1, kernel_size // 2, bias=True, groups=dim + ) + + def forward(self, x): + x = self.dwconv(x) + return x + + +class ActGLU(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation_function=None, + drop=0.0, + mlp_use_dw_conv=True, + mlp_dw_kernel_size=3, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_g = Conv2dBlock( + in_features, + hidden_features, + mlp_in_kernel_size, + padding=mlp_in_kernel_size // 2, + ) + self.fc1_x = Conv2dBlock( + in_features, + hidden_features, + mlp_in_kernel_size, + padding=mlp_in_kernel_size // 2, + ) + self.dwconv_g = None + self.dwconv_x = None + if mlp_use_dw_conv: + self.dwconv_g = DWConv(hidden_features, mlp_dw_kernel_size) + self.dwconv_x = DWConv(hidden_features, mlp_dw_kernel_size) + act = ( + get_activation("gelu") + if activation_function is None + else activation_function + ) + self.act = act(inplace=True) + self.fc2 = Conv2dBlock( + hidden_features, + out_features, + mlp_out_kernel_size, + padding=mlp_out_kernel_size // 2, + ) + self.drop = nn.Dropout(drop) + + self.in_hid_factor = float(hidden_features) / in_features + self.hid_out_factor = float(out_features) / 
hidden_features + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x_gate = self.fc1_g(x) + x = self.fc1_x(x) + if self.dwconv_g is not None: + x_gate = self.dwconv_g(x_gate) + x = self.dwconv_x(x) + x = self.act(x_gate) * x + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CrossActGLU(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation_function=None, + drop=0.0, + mlp_use_dw_conv=True, + mlp_dw_kernel_size=3, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.merge_fc_g = Conv2dBlock(2 * in_features, in_features, 1) + + self.fc1_g = Conv2dBlock( + in_features, + hidden_features, + mlp_in_kernel_size, + padding=mlp_in_kernel_size // 2, + ) + self.fc1_y = Conv2dBlock( + in_features, + hidden_features, + mlp_in_kernel_size, + padding=mlp_in_kernel_size // 2, + ) + self.dwconv_g = None + self.dwconv_y = None + if mlp_use_dw_conv: + self.dwconv_g = DWConv(hidden_features, mlp_dw_kernel_size) + self.dwconv_y = DWConv(hidden_features, mlp_dw_kernel_size) + act = ( + get_activation("gelu") + if activation_function is None + else activation_function + ) + self.act = act(inplace=True) + self.fc2 = Conv2dBlock( + hidden_features, + out_features, + mlp_out_kernel_size, + padding=mlp_out_kernel_size // 2, + 
) + self.drop = nn.Dropout(drop) + + self.in_hid_factor = float(hidden_features) / in_features + self.hid_out_factor = float(out_features) / hidden_features + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, y): + xy = self.merge_fc_g(torch.cat([x, y], 1)) + xy_gate = self.fc1_g(xy) + y = self.fc1_y(y) + if self.dwconv_g is not None: + xy_gate = self.dwconv_g(xy_gate) + y = self.dwconv_y(y) + x = self.act(xy_gate) * y + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class LayerTransition(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, patch_size=3, stride=2, in_chans=64, embed_dim=64): + super().__init__() + patch_size = to_2tuple(patch_size) + self.proj = Conv2dBlock( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + 
m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + return x + + +class CGU(nn.Module): + def __init__( + self, + dim, + drop=0.0, + drop_path=0.0, + activation_function=None, + norm_layer=partial(GroupNorm2d, num_groups=8), + use_cross=False, + mlp_ratio=4, + mlp_use_dw_conv=True, + mlp_dw_kernel_size=7, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + layer_scale_init_value=1e-2, + ): + super().__init__() + self.use_cross = use_cross + + self.norm = norm_layer(num_channels=dim) + hidden_dim = int(dim * mlp_ratio) + self.conv_self = ActGLU( + in_features=dim, + hidden_features=hidden_dim, + activation_function=activation_function, + drop=drop, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + ) + if use_cross: + self.conv_cross = CrossActGLU( + in_features=dim, + hidden_features=hidden_dim, + activation_function=activation_function, + drop=drop, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if layer_scale_init_value < 1e-4: + self.layer_scale = None + else: + self.layer_scale = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + 
+ def forward(self, x, y=None): + if self.use_cross: + x_short = x.clone() + y_short = y.clone() + x = self.norm(x) + y = self.norm(y) + + x = self.conv_self(x) + y = self.conv_self(y) + + x = self.conv_cross(x, y) + if self.layer_scale is not None: + x = x * self.layer_scale[: x.shape[1]].unsqueeze(-1).unsqueeze(-1) + x = self.drop_path(x) + x = x + x_short + + y = self.conv_cross(y, x) + if self.layer_scale is not None: + y = y * self.layer_scale[: y.shape[1]].unsqueeze(-1).unsqueeze(-1) + y = self.drop_path(y) + y = y + y_short + else: + x_short = x.clone() + x = self.norm(x) + x = self.conv_self(x) + x = x * self.layer_scale[: x.shape[1]].unsqueeze(-1).unsqueeze(-1) + x = self.drop_path(x) + x = x + x_short + return x, y + + +class CGUStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + drop=0.0, + drop_path=0.0, + activation_function=None, + norm_layer=partial(GroupNorm2d, num_groups=8), + depth=2, + use_cross=False, + mlp_ratio=4, + mlp_use_dw_conv=True, + mlp_dw_kernel_size=7, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + layer_scale_init_value=1e-2, + ): + super(CGUStage, self).__init__() + self.conv_transition = None + self.use_cross = use_cross + if stride > 1 or in_chs != out_chs: + patch_size = 1 + if stride > 1: + patch_size = 3 + self.conv_transition = LayerTransition( + patch_size=patch_size, stride=stride, in_chans=in_chs, embed_dim=out_chs + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append( + CGU( + dim=out_chs, + drop=drop, + drop_path=drop_path, + activation_function=activation_function, + norm_layer=norm_layer, + use_cross=use_cross, + mlp_ratio=mlp_ratio, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + layer_scale_init_value=layer_scale_init_value, + ) + ) + self.norm = norm_layer(num_channels=out_chs) + + def forward(self, x, y=None, skip_transition=False): + if 
self.conv_transition is not None and not skip_transition: + x = self.conv_transition(x) + if self.use_cross: + y = self.conv_transition(y) + for blk in self.blocks: + x, y = blk(x, y) + x = self.norm(x) + if self.use_cross: + y = self.norm(y) + return x, y + return x diff --git a/ptlflow/models/dpflow/cgu_bidir_dual_encoder.py b/ptlflow/models/dpflow/cgu_bidir_dual_encoder.py new file mode 100644 index 00000000..ec2ae3e8 --- /dev/null +++ b/ptlflow/models/dpflow/cgu_bidir_dual_encoder.py @@ -0,0 +1,314 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cgu import CGUStage +from .local_timm.weight_init import trunc_normal_ +from .res_stem import ResStem +from .conv import Conv2dBlock, ConvTranspose2dBlock +from .norm import GroupNorm2d +from .update import ConvGRU + + +class CGUBidirDualEncoder(nn.Module): + def __init__( + self, + pyramid_levels: Optional[int], + hidden_chs: int, + out_1x1_abs_chs: int, + out_1x1_factor: Optional[float], + num_out_stages: int = 0, + gru_mode: str = "gru", + activation_function: Optional[nn.Module] = None, + norm_layer=GroupNorm2d, + depth: int = 2, + mlp_ratio: float = 4, + mlp_use_dw_conv: bool = False, + mlp_dw_kernel_size: int = 7, + mlp_in_kernel_size: int = 1, + mlp_out_kernel_size: int = 1, + cgu_layer_scale_init_value: float = 1e-2, + ): + super().__init__() + + self.pyramid_levels = pyramid_levels + self.hidden_chs = hidden_chs + self.out_1x1_abs_chs = out_1x1_abs_chs + self.out_1x1_factor = out_1x1_factor + self.num_out_stages = num_out_stages + self.gru_mode = gru_mode + + self.forward_gru = ConvGRU(hidden_chs[-1], hidden_chs[-1]) + self.down_gru = Conv2dBlock( + hidden_chs[-1], hidden_chs[-1], 3, stride=2, padding=1, bias=True + ) + + self.backward_gru = ConvGRU(hidden_chs[-1], hidden_chs[-1]) + self.up_gru = ConvTranspose2dBlock( + hidden_chs[-1], hidden_chs[-1], 4, stride=2, padding=1, bias=True + ) + + self.stem = self._make_stem( + [hidden_chs[0], hidden_chs[1], 2 * hidden_chs[2]], + norm_layer=norm_layer, + ) + + self.lowres_stem = self._make_stem( + hidden_chs, + norm_layer=norm_layer, + ) + + if self.out_1x1_abs_chs > 0: + self.out_1x1 = self._make_out_1x1_layer( + hidden_chs[-1], self.out_1x1_abs_chs + ) + + self.rec_stage = self._make_stage( + hidden_chs[-1], + out_chs=hidden_chs[-1], + activation_function=activation_function, + norm_layer=norm_layer, + 
depth=depth, + mlp_ratio=mlp_ratio, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + cgu_layer_scale_init_value=cgu_layer_scale_init_value, + ) + + self.back_stage = self._make_stage( + hidden_chs[-1], + out_chs=hidden_chs[-1], + activation_function=activation_function, + norm_layer=norm_layer, + depth=depth, + mlp_ratio=mlp_ratio, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + stride=1, + cgu_layer_scale_init_value=cgu_layer_scale_init_value, + ) + + if self.num_out_stages > 0: + self.out_merge_conv = Conv2dBlock( + 3 * hidden_chs[-1], hidden_chs[-1], kernel_size=1, pre_act_fn=nn.ReLU() + ) + self.out_stages = self._make_out_stages( + self.num_out_stages, + hidden_chs[-1], + out_chs=None, + activation_function=activation_function, + norm_layer=norm_layer, + depth=depth, + mlp_ratio=mlp_ratio, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + cgu_layer_scale_init_value=cgu_layer_scale_init_value, + ) + + self._init_weights() + + def _make_stem(self, hidden_chs: int, norm_layer): + if not isinstance(hidden_chs, (list, tuple)): + return + + return ResStem([hidden_chs[0], hidden_chs[1], hidden_chs[2]], norm_layer) + + def _make_stage( + self, + hidden_chs: int, + out_chs=None, + activation_function=None, + norm_layer=GroupNorm2d, + depth=2, + mlp_ratio=4, + mlp_use_dw_conv=True, + mlp_dw_kernel_size=7, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + stride=2, + cgu_layer_scale_init_value=1e-2, + ): + if out_chs is None: + out_chs = hidden_chs + + return CGUStage( + hidden_chs, + out_chs, + stride=stride, + activation_function=activation_function, + norm_layer=norm_layer, + depth=depth, + use_cross=True, + mlp_ratio=mlp_ratio, + 
mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + layer_scale_init_value=cgu_layer_scale_init_value, + ) + + def _make_out_stages( + self, + num_out_stages: int, + hidden_chs: int, + out_chs=None, + activation_function=None, + norm_layer=GroupNorm2d, + depth=2, + mlp_dw_kernel_size=7, + mlp_ratio=4, + mlp_use_dw_conv=True, + mlp_in_kernel_size=1, + mlp_out_kernel_size=1, + cgu_layer_scale_init_value=1e-2, + ): + if out_chs is None: + out_chs = hidden_chs + return CGUStage( + hidden_chs, + out_chs, + stride=1, + activation_function=activation_function, + norm_layer=norm_layer, + depth=num_out_stages * depth, + use_cross=True, + mlp_ratio=mlp_ratio, + mlp_use_dw_conv=mlp_use_dw_conv, + mlp_dw_kernel_size=mlp_dw_kernel_size, + mlp_in_kernel_size=mlp_in_kernel_size, + mlp_out_kernel_size=mlp_out_kernel_size, + layer_scale_init_value=cgu_layer_scale_init_value, + ) + + def _make_out_1x1_layer(self, hidden_chs: int, out_chs: int): + return Conv2dBlock(hidden_chs, out_chs, kernel_size=1) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x: torch.Tensor, y: torch.Tensor, pyr_levels: int): + input_x = x + input_y = y + + x_pyramid = [] + y_pyramid = [] + + pyr_iters = pyr_levels + 1 + for i in range(pyr_iters): + if i == 0: + x = self.stem(x) + y = self.stem(y) + + x, hx = torch.split(x, [x.shape[1] // 2, x.shape[1] // 
2], 1) + y, hy = torch.split(y, [y.shape[1] // 2, y.shape[1] // 2], 1) + hx = torch.tanh(hx) + hy = torch.tanh(hy) + else: + hx = self.forward_gru(hx, x) + hy = self.forward_gru(hy, y) + + x, y = self.rec_stage(hx, hy) + x = x.contiguous() + y = y.contiguous() + if i < (pyr_iters - 1): + hx = self.down_gru(hx) + hx = torch.tanh(hx) + hy = self.down_gru(hy) + hy = torch.tanh(hy) + + if i >= 1: + x_pyramid.append(x) + y_pyramid.append(y) + + hx = torch.zeros_like(x_pyramid[-1]) + hy = torch.zeros_like(y_pyramid[-1]) + for i in range(len(x_pyramid) - 1, -1, -1): + x = x_pyramid[i] + y = y_pyramid[i] + + hx = self.backward_gru(hx, x) + hy = self.backward_gru(hy, y) + + x2, y2 = self.back_stage(hx, hy) + + input_x_lowres = F.interpolate( + input_x, + scale_factor=(1.0 / 2.0 ** (i + 1)), + mode="bilinear", + align_corners=True, + ) + x_lowres = self.lowres_stem(input_x_lowres) + + input_y_lowres = F.interpolate( + input_y, + scale_factor=(1.0 / 2.0 ** (i + 1)), + mode="bilinear", + align_corners=True, + ) + y_lowres = self.lowres_stem(input_y_lowres) + + x_pyramid[i] = torch.cat([x, x2, x_lowres], 1) + y_pyramid[i] = torch.cat([y, y2, y_lowres], 1) + + if i > 0: + hx = self.up_gru(hx) + hx = torch.tanh(hx) + hy = self.up_gru(hy) + hy = torch.tanh(hy) + + for i, (x, y) in enumerate(zip(x_pyramid, y_pyramid)): + if self.num_out_stages > 0: + x = self.out_merge_conv(x) + y = self.out_merge_conv(y) + x, y = self.out_stages(x, y) + if self.out_1x1_abs_chs > 0: + if self.out_1x1_factor is None: + x = self.out_1x1(x) + y = self.out_1x1(y) + else: + x = self.out_1x1(x, int(self.out_1x1_factor * x.shape[1])) + y = self.out_1x1(y, int(self.out_1x1_factor * y.shape[1])) + x_pyramid[i] = x + y_pyramid[i] = y + + return x_pyramid[::-1], y_pyramid[::-1] diff --git a/ptlflow/models/dpflow/configs/dpflow-test.yaml b/ptlflow/models/dpflow/configs/dpflow-test.yaml new file mode 100644 index 00000000..ec14eb79 --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-test.yaml @@ -0,0 +1,59 
@@ +# lightning.pytorch==2.4.0 +output_path: outputs/test +show: false +max_forward_side: null +scale_factor: null +max_show_side: 1000 +save_viz: true +model: + class_path: ptlflow.models.dpflow + init_args: + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.8 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: true diff --git a/ptlflow/models/dpflow/configs/dpflow-train1-chairs.yaml b/ptlflow/models/dpflow/configs/dpflow-train1-chairs.yaml new file mode 100644 index 00000000..1ab0fc53 --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-train1-chairs.yaml @@ -0,0 +1,70 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +lr: 2.5e-4 +wdecay: 1.0e-4 +trainer: + max_epochs: 45 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + devices: 0,1 +model: + class_path: ptlflow.models.dpflow + init_args: + pyramid_levels: 3 + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + 
corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.8 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false +data: + train_dataset: chairs + val_dataset: sintel-final-val+kitti-2015-val + train_batch_size: 5 + train_num_workers: 5 + train_crop_size: [352, 480] + train_transform_cuda: false + train_transform_fp16: false diff --git a/ptlflow/models/dpflow/configs/dpflow-train2-things.yaml b/ptlflow/models/dpflow/configs/dpflow-train2-things.yaml new file mode 100644 index 00000000..f4e3edbd --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-train2-things.yaml @@ -0,0 +1,71 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +ckpt_path: chairs # Change to the ckpt resulting from dpflow-train1-chairs +lr: 1.25e-4 +wdecay: 1.0e-4 +trainer: + max_epochs: 80 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + devices: 0,1 +model: + class_path: ptlflow.models.dpflow + init_args: + pyramid_levels: 3 + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + 
enc_network: cgu_bidir_dual + enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.8 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false +data: + train_dataset: things-train + val_dataset: sintel-clean+sintel-final+kitti-2015 + train_batch_size: 3 + train_num_workers: 3 + train_crop_size: [384, 704] + train_transform_cuda: false + train_transform_fp16: false diff --git a/ptlflow/models/dpflow/configs/dpflow-train3-sintel.yaml b/ptlflow/models/dpflow/configs/dpflow-train3-sintel.yaml new file mode 100644 index 00000000..4a3bb7a4 --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-train3-sintel.yaml @@ -0,0 +1,71 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +ckpt_path: things # Change to the ckpt resulting from dpflow-train2-things +lr: 1.25e-4 +wdecay: 1.0e-5 +trainer: + max_epochs: 25 + accumulate_grad_batches: 4 + gradient_clip_val: 1.0 + devices: 0,1 +model: + class_path: ptlflow.models.dpflow + init_args: + pyramid_levels: 3 + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + 
enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.85 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false +data: + train_dataset: sintel-searaft_split + val_dataset: sintel-final-val+kitti-2015-val + train_batch_size: 3 + train_num_workers: 3 + train_crop_size: [368, 768] + train_transform_cuda: false + train_transform_fp16: false diff --git a/ptlflow/models/dpflow/configs/dpflow-train4a-kitti.yaml b/ptlflow/models/dpflow/configs/dpflow-train4a-kitti.yaml new file mode 100644 index 00000000..122722dd --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-train4a-kitti.yaml @@ -0,0 +1,72 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +ckpt_path: sintel # Change to the ckpt resulting from dpflow-train3-sintel +lr: 1.0e-4 +wdecay: 1.0e-5 +trainer: + max_epochs: 250 + check_val_every_n_epoch: 50 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + devices: 0,1 +model: + class_path: ptlflow.models.dpflow + init_args: + pyramid_levels: 3 + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + 
enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.85 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false +data: + train_dataset: kitti-2015 + val_dataset: sintel-final-val+kitti-2015-val + train_batch_size: 3 + train_num_workers: 3 + train_crop_size: [288, 960] + train_transform_cuda: false + train_transform_fp16: false diff --git a/ptlflow/models/dpflow/configs/dpflow-train4b-spring.yaml b/ptlflow/models/dpflow/configs/dpflow-train4b-spring.yaml new file mode 100644 index 00000000..14f6c2d1 --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-train4b-spring.yaml @@ -0,0 +1,71 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +ckpt_path: sintel # Change to the ckpt resulting from dpflow-train3-sintel +lr: 1.0e-4 +wdecay: 1.0e-5 +trainer: + max_epochs: 100 + accumulate_grad_batches: 4 + gradient_clip_val: 1.0 + devices: 0,1 +model: + class_path: ptlflow.models.dpflow + init_args: + pyramid_levels: 4 + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + enc_norm_type: group + enc_depth: 4 + 
enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.85 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false +data: + train_dataset: spring-sub-rev + val_dataset: sintel-final-val+kitti-2015-val + train_batch_size: 2 + train_num_workers: 2 + train_crop_size: [540, 960] + train_transform_cuda: false + train_transform_fp16: false diff --git a/ptlflow/models/dpflow/configs/dpflow-validate.yaml b/ptlflow/models/dpflow/configs/dpflow-validate.yaml new file mode 100644 index 00000000..075bd37b --- /dev/null +++ b/ptlflow/models/dpflow/configs/dpflow-validate.yaml @@ -0,0 +1,70 @@ +# lightning.pytorch==2.4.0 +all: false +select: null +exclude: null +output_path: outputs/validate +write_outputs: false +show: false +flow_format: original +max_forward_side: null +scale_factor: null +max_show_side: 1000 +max_samples: null +reversed: false +fp16: false +seq_val_mode: all +write_individual_metrics: false +epe_clip: 5.0 +seed_everything: true +model: + class_path: ptlflow.models.dpflow + init_args: + iters_per_level: 4 + detach_flow: true + use_norm_affine: false + group_norm_num_groups: 8 + corr_mode: allpairs + corr_levels: 1 + corr_range: 4 + activation_function: orig + enc_network: cgu_bidir_dual + 
enc_norm_type: group + enc_depth: 4 + enc_mlp_ratio: 2.0 + enc_mlp_in_kernel_size: 1 + enc_mlp_out_kernel_size: 1 + enc_hidden_chs: + - 64 + - 96 + - 128 + enc_num_out_stages: 1 + enc_out_1x1_chs: '384' + dec_gru_norm_type: layer + dec_gru_iters: 1 + dec_gru_depth: 4 + dec_gru_mlp_ratio: 2.0 + dec_gru_mlp_in_kernel_size: 1 + dec_gru_mlp_out_kernel_size: 1 + dec_net_chs: 128 + dec_inp_chs: 128 + dec_motion_chs: 128 + dec_flow_kernel_size: 7 + dec_flow_head_chs: 256 + dec_motenc_corr_hidden_chs: 256 + dec_motenc_corr_out_chs: 192 + dec_motenc_flow_hidden_chs: 128 + dec_motenc_flow_out_chs: 64 + use_upsample_mask: true + upmask_gradient_scale: 1.0 + cgu_mlp_dw_kernel_size: 7 + cgu_fusion_gate_activation: gelu + cgu_mlp_use_dw_conv: true + cgu_mlp_activation_function: gelu + cgu_layer_scale_init_value: 0.01 + loss: laplace + gamma: 0.8 + max_flow: 400.0 + use_var: true + var_min: 0.0 + var_max: 10.0 + warm_start: false diff --git a/ptlflow/models/dpflow/conv.py b/ptlflow/models/dpflow/conv.py new file mode 100644 index 00000000..886f1160 --- /dev/null +++ b/ptlflow/models/dpflow/conv.py @@ -0,0 +1,291 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .norm import GroupNorm2d, BatchNorm2d, LayerNorm2d + + +def conv2d( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + pre_act_fn=None, + post_act_fn=None, + pre_norm=None, + post_norm=None, + is_transpose: bool = False, +) -> torch.Tensor: + if pre_norm is not None: + x = pre_norm(x) + + if pre_act_fn is not None: + x = pre_act_fn(x) + + if is_transpose: + x = F.conv_transpose2d( + x, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + else: + x = F.conv2d( + x, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + if post_norm is not None: + x = post_norm(x) + + if post_act_fn is not None: + x = post_act_fn(x) + + return x + + +class ConvBase(nn.Module): + def __init__( + self, + is_transpose, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + pre_norm=None, + post_norm=None, + group_norm_groups=8, + pre_act_fn=None, + post_act_fn=None, + ) -> None: + super().__init__() + self.is_transpose = is_transpose + self.in_channels = in_channels + self.out_channels = out_channels + if not isinstance(kernel_size, (list, tuple)): + kernel_size = (kernel_size, kernel_size) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.padding_mode = padding_mode + self.device = device + self.dtype = dtype + self.pre_norm = pre_norm + self.post_norm = post_norm + self.pre_act_fn = pre_act_fn + self.post_act_fn = post_act_fn + + if self.is_transpose: + 
self.register_parameter( + "weight", + nn.Parameter( + torch.zeros( + in_channels // groups, + out_channels, + kernel_size[0], + kernel_size[1], + device=device, + dtype=dtype, + ) + ), + ) + else: + self.register_parameter( + "weight", + nn.Parameter( + torch.zeros( + out_channels, + in_channels // groups, + kernel_size[0], + kernel_size[1], + device=device, + dtype=dtype, + ) + ), + ) + + if bias: + self.register_parameter( + "bias", + nn.Parameter(torch.zeros(out_channels, device=device, dtype=dtype)), + ) + else: + self.bias = None + + if isinstance(pre_norm, str): + if pre_norm == "instance": + self.pre_norm = nn.InstanceNorm2d(1, track_running_stats=False) + elif pre_norm == "group": + self.pre_norm = GroupNorm2d( + group_norm_groups, out_channels, affine=False + ) + elif pre_norm == "batch": + self.pre_norm = BatchNorm2d(out_channels) + elif pre_norm == "layer": + self.pre_norm = LayerNorm2d(out_channels) + else: + self.pre_norm = pre_norm + + if isinstance(post_norm, str): + if post_norm == "instance": + self.post_norm = nn.InstanceNorm2d(1, track_running_stats=False) + elif post_norm == "group": + self.post_norm = GroupNorm2d( + group_norm_groups, out_channels, affine=False + ) + elif post_norm == "batch": + self.post_norm = BatchNorm2d(out_channels) + elif post_norm == "layer": + self.post_norm = LayerNorm2d(out_channels) + else: + self.post_norm = post_norm + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + return conv2d( + x=x, + weight=self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + pre_act_fn=self.pre_act_fn, + post_act_fn=self.post_act_fn, + pre_norm=self.pre_norm, + post_norm=self.post_norm, + is_transpose=self.is_transpose, + ) + + +class Conv2dBlock(ConvBase): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + pre_norm=None, + post_norm=None, + group_norm_groups=8, + pre_act_fn=None, + post_act_fn=None, + ) -> None: + super().__init__( + is_transpose=False, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + pre_norm=pre_norm, + post_norm=post_norm, + group_norm_groups=group_norm_groups, + pre_act_fn=pre_act_fn, + post_act_fn=post_act_fn, + ) + + +class ConvTranspose2dBlock(ConvBase): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + pre_norm=None, + post_norm=None, + group_norm_groups=8, + pre_act_fn=None, + post_act_fn=None, + ) -> None: + super().__init__( + is_transpose=True, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + pre_norm=pre_norm, + post_norm=post_norm, + group_norm_groups=group_norm_groups, + pre_act_fn=pre_act_fn, + post_act_fn=post_act_fn, + ) diff --git a/ptlflow/models/dpflow/corr.py b/ptlflow/models/dpflow/corr.py new file mode 100644 index 00000000..da835e27 --- /dev/null +++ b/ptlflow/models/dpflow/corr.py @@ -0,0 
+1,185 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import math + +from einops import rearrange +import torch +import torch.nn.functional as F + +try: + from spatial_correlation_sampler import SpatialCorrelationSampler +except ModuleNotFoundError: + from ptlflow.utils.correlation import ( + IterSpatialCorrelationSampler as SpatialCorrelationSampler, + ) +from ptlflow.utils.correlation import ( + IterTranslatedSpatialCorrelationSampler as TranslatedSpatialCorrelationSampler, +) + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + +from .utils import bilinear_sampler + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + if min(corr.shape[2:4]) > 2 * radius + 1: + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, f1 = coords.shape + + corr_full = [] + for j in range(0, f1, 
2): + corr_pyramid = [] + for i in range(self.num_levels): + corr_raw = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, dtype=coords.dtype) + dy = torch.linspace(-r, r, 2 * r + 1, dtype=coords.dtype) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + + centroid_lvl = ( + coords[..., j : j + 2].reshape(batch * h1 * w1, 1, 1, 2) / 2**i + ) + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr_raw, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + corr_pyramid.append(corr) + + corr_pyramid = torch.cat(corr_pyramid, dim=-1) + corr_full.append(corr_pyramid) + + corr_full = torch.cat(corr_full, -1) + return corr_full.permute(0, 3, 1, 2).contiguous() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim)) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + if coords.dtype == torch.float16: + fmap1_i = fmap1_i.float() + fmap2_i = fmap2_i.float() + coords_i = coords_i.float() + (corr,) = alt_cuda_corr.forward(fmap1_i, 
fmap2_i, coords_i, r) + if coords.dtype == torch.float16: + corr = corr.half() + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / math.sqrt(dim) + + +class LocalCorrBlock: + def __init__(self, num_levels=1, radius=4, use_translated_correlation=False): + self.num_levels = num_levels + self.radius = radius + self.use_translated_correlation = use_translated_correlation + self.side = 2 * radius + 1 + + if self.use_translated_correlation: + self.corr = TranslatedSpatialCorrelationSampler( + kernel_size=1, patch_size=2 ** (num_levels - 1) * self.side, padding=0 + ) + else: + self.corr = SpatialCorrelationSampler( + kernel_size=1, patch_size=2 ** (num_levels - 1) * self.side, padding=0 + ) + + def __call__(self, fmap1, fmap2, flow=None): + if self.use_translated_correlation: + out_corr = self.corr(fmap1, fmap2, flow) + else: + out_corr = self.corr(fmap1, fmap2) + + out_corr = out_corr / math.sqrt(fmap1.shape[1]) + + b, h2, w2, h1, w1 = out_corr.shape + out_corr = rearrange(out_corr[:, None], "b c h2 w2 h1 w1 -> (b h1 w1) c h2 w2") + corr_pyr = [] + for i in range(self.num_levels): + if i > 0: + out_corr = F.avg_pool2d(out_corr, 2, stride=2) + + hm, wm = out_corr.shape[2] // 2, out_corr.shape[3] // 2 + corr_pyr.append( + out_corr[ + :, + :, + hm - self.radius : hm + self.radius + 1, + wm - self.radius : wm + self.radius + 1, + ] + ) + for i in range(len(corr_pyr)): + corr_pyr[i] = rearrange( + corr_pyr[i], "(b h1 w1) c h2 w2 -> b (c w2 h2) h1 w1", h1=h1, w1=w1 + ) + out_corr = torch.cat(corr_pyr, dim=1) + return out_corr diff --git a/ptlflow/models/dpflow/dpflow.py b/ptlflow/models/dpflow/dpflow.py new file mode 100644 index 00000000..3806fa91 --- /dev/null +++ b/ptlflow/models/dpflow/dpflow.py @@ -0,0 +1,548 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); 
+# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import math +from typing import Optional + +from loguru import logger +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..base_model.base_model import BaseModel +from .corr import CorrBlock, AlternateCorrBlock +from .pwc_modules import rescale_flow, upsample2d_as +from .cgu_bidir_dual_encoder import CGUBidirDualEncoder +from .update import UpdateBlock +from .utils import ( + compute_pyramid_levels, + get_activation, + get_norm, +) +from ptlflow.utils.utils import forward_interpolate_batch +from ptlflow.utils.registry import register_model, trainable, ptlflow_trained + +try: + import alt_cuda_corr +except: + alt_cuda_corr = None + + +class SequenceLoss(nn.Module): + def __init__(self, loss: str, max_flow: float, gamma: float): + super().__init__() + self.loss = loss + self.max_flow = max_flow + self.gamma = gamma + + def forward(self, outputs, inputs): + """Loss function defined over sequence of flow predictions""" + + flow_preds = outputs["flow_preds"] + flow_gt = inputs["flows"][:, 0] + valid = inputs["valids"][:, 0] + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exclude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1, keepdim=True).sqrt() + valid = (valid >= 0.5) & (mag < self.max_flow) + + for i in range(n_predictions): + pred = flow_preds[i] + if ( + pred.shape[-2] != flow_gt.shape[-2] + or pred.shape[-1] != flow_gt.shape[-1] + 
): + pred = F.interpolate( + pred, size=flow_gt.shape[-2:], mode="bilinear", align_corners=True + ) + i_weight = self.gamma ** (n_predictions - i - 1) + + if self.loss == "l1" or outputs["nf_preds"][i] is None: + diff = pred - flow_gt + i_loss = (diff).abs() + valid_loss = valid * i_loss + flow_loss += i_weight * valid_loss.mean() + elif self.loss == "laplace": + loss_i = outputs["nf_preds"][i] + final_mask = ( + (~torch.isnan(loss_i.detach())) + & (~torch.isinf(loss_i.detach())) + & valid + ) + flow_loss += i_weight * ((final_mask * loss_i).sum() / final_mask.sum()) + + return flow_loss + + +class DPFlow(BaseModel): + pretrained_checkpoints = { + "chairs": "https://github.com/hmorimitsu/ptlflow/releases/download/weights1/dpflow-chairs-f94e717a.ckpt", + "kitti": "https://github.com/hmorimitsu/ptlflow/releases/download/weights1/dpflow-kitti-4e97eac6.ckpt", + "sintel": "https://github.com/hmorimitsu/ptlflow/releases/download/weights1/dpflow-sintel-b44b072c.ckpt", + "spring": "https://github.com/hmorimitsu/ptlflow/releases/download/weights1/dpflow-spring-69bac7fa.ckpt", + "things": "https://github.com/hmorimitsu/ptlflow/releases/download/weights1/dpflow-things-2012b5d6.ckpt", + } + + def __init__( + self, + pyramid_levels: Optional[int] = None, + iters_per_level: int = 4, + detach_flow: bool = True, + use_norm_affine: bool = False, + group_norm_num_groups: int = 8, + corr_mode: str = "allpairs", # "allpairs" or "local" + corr_levels: int = 1, + corr_range: int = 4, + activation_function: str = "orig", # "orig", "relu", "gelu", "silu", or "mish" + enc_network: str = "cgu_bidir_dual", # "cgu", "cgu_bidir", "cgu_bidir_dual", "cgu_dual", "next_bidir_dual", "swin" + enc_norm_type: str = "group", # "none", "group", "layer", or "batch" + enc_depth: int = 4, + enc_mlp_ratio: float = 2.0, + enc_mlp_in_kernel_size: int = 1, + enc_mlp_out_kernel_size: int = 1, + enc_hidden_chs: list[int] = (64, 96, 128), + enc_num_out_stages: int = 1, + enc_out_1x1_chs: str = "384", + 
dec_gru_norm_type: str = "layer", # "none", "group", "layer", or "batch" + dec_gru_iters: int = 1, + dec_gru_depth: int = 4, + dec_gru_mlp_ratio: float = 2.0, + dec_gru_mlp_in_kernel_size: int = 1, + dec_gru_mlp_out_kernel_size: int = 1, + dec_net_chs: int = 128, + dec_inp_chs: int = 128, + dec_motion_chs: int = 128, + dec_flow_kernel_size: int = 7, + dec_flow_head_chs: int = 256, + dec_motenc_corr_hidden_chs: int = 256, + dec_motenc_corr_out_chs: int = 192, + dec_motenc_flow_hidden_chs: int = 128, + dec_motenc_flow_out_chs: int = 64, + use_upsample_mask: bool = True, + upmask_gradient_scale: float = 1.0, + cgu_mlp_dw_kernel_size: int = 7, + cgu_fusion_gate_activation: str = "gelu", # "linear", "sigmoid", "relu", "gelu", "silu", or "mish" + cgu_mlp_use_dw_conv: bool = True, + cgu_mlp_activation_function: str = "gelu", # "linear", "sigmoid", "relu", "gelu", "silu", or "mish" + cgu_layer_scale_init_value: float = 0.01, + loss: str = "laplace", # "l1" or "laplace" + gamma: float = 0.8, + max_flow: float = 400.0, + use_var: bool = True, + var_min: float = 0.0, + var_max: float = 10.0, + **kwargs, + ): + if pyramid_levels is not None: + assert pyramid_levels > 2, "Only --model.pyramid_levels >= 3 is supported." + output_stride = int(2 ** (pyramid_levels + 2)) + if enc_network == "swin_bidir_dual": + output_stride *= 2 + else: + logger.info( + f"DPFlow: --model.pyramid_levels is not set, the number of pyramid levels will be inferred from the input size." 
+ ) + output_stride = None + self.extra_output_stride = 1 if enc_network == "swin_bidir_dual" else 0 + + super(DPFlow, self).__init__( + loss_fn=SequenceLoss(loss=loss, max_flow=max_flow, gamma=gamma), + output_stride=output_stride, + **kwargs, + ) + + self.pyramid_levels = pyramid_levels + self.iters_per_level = iters_per_level + self.corr_mode = corr_mode + self.corr_range = corr_range + self.corr_levels = corr_levels + self.detach_flow = detach_flow + self.loss = loss + self.use_var = use_var + self.var_min = var_min + self.var_max = var_max + + activation_function = get_activation(activation_function) + + enc_out_1x1_chs = ( + float(enc_out_1x1_chs) + if (isinstance(enc_out_1x1_chs, str) and "." in enc_out_1x1_chs) + else int(enc_out_1x1_chs) + ) + + if isinstance(enc_out_1x1_chs, float): + out_1x1_factor = enc_out_1x1_chs + out_1x1_abs_chs = int(enc_out_1x1_chs * enc_hidden_chs[-1]) + else: + out_1x1_factor = None + out_1x1_abs_chs = enc_out_1x1_chs + + self.max_feat_chs = max( + enc_hidden_chs[-1], + out_1x1_abs_chs, + ) + + net_chs = dec_net_chs + inp_chs = dec_inp_chs + if net_chs is None or inp_chs is None: + base_chs = out_1x1_abs_chs + if base_chs < 1: + base_chs = enc_hidden_chs[-1] + + base_chs = base_chs // 3 * 2 + + if net_chs is None and inp_chs is None: + net_chs = inp_chs = base_chs // 2 + elif net_chs is None and inp_chs is not None: + net_chs = base_chs - inp_chs + elif net_chs is not None and inp_chs is None: + inp_chs = base_chs - net_chs + net_chs_fixed = net_chs + inp_chs_fixed = inp_chs + + enc_norm_layer = get_norm( + enc_norm_type, + affine=use_norm_affine, + num_groups=group_norm_num_groups, + ) + self.fnet = CGUBidirDualEncoder( + pyramid_levels=pyramid_levels, + hidden_chs=enc_hidden_chs, + out_1x1_abs_chs=out_1x1_abs_chs, + out_1x1_factor=out_1x1_factor, + num_out_stages=enc_num_out_stages, + activation_function=activation_function, + norm_layer=enc_norm_layer, + depth=enc_depth, + mlp_ratio=enc_mlp_ratio, + 
mlp_use_dw_conv=cgu_mlp_use_dw_conv, + mlp_dw_kernel_size=cgu_mlp_dw_kernel_size, + mlp_in_kernel_size=enc_mlp_in_kernel_size, + mlp_out_kernel_size=enc_mlp_out_kernel_size, + cgu_layer_scale_init_value=cgu_layer_scale_init_value, + ) + + self.dim_corr = (corr_range * 2 + 1) ** 2 * corr_levels + + dec_gru_norm_layer = get_norm( + dec_gru_norm_type, + affine=use_norm_affine, + num_groups=group_norm_num_groups, + ) + self.update_block = UpdateBlock( + dec_motenc_corr_hidden_chs=dec_motenc_corr_hidden_chs, + dec_motenc_corr_out_chs=dec_motenc_corr_out_chs, + dec_motenc_flow_hidden_chs=dec_motenc_flow_hidden_chs, + dec_motenc_flow_out_chs=dec_motenc_flow_out_chs, + corr_levels=corr_levels, + corr_range=corr_range, + dec_flow_kernel_size=dec_flow_kernel_size, + dec_motion_chs=dec_motion_chs, + activation_function=activation_function, + net_chs_fixed=net_chs_fixed, + inp_chs_fixed=inp_chs_fixed, + dec_gru_norm_layer=dec_gru_norm_layer, + dec_gru_depth=dec_gru_depth, + dec_gru_iters=dec_gru_iters, + dec_gru_mlp_ratio=dec_gru_mlp_ratio, + cgu_mlp_use_dw_conv=cgu_mlp_use_dw_conv, + cgu_mlp_dw_kernel_size=cgu_mlp_dw_kernel_size, + dec_gru_mlp_in_kernel_size=dec_gru_mlp_in_kernel_size, + dec_gru_mlp_out_kernel_size=dec_gru_mlp_out_kernel_size, + cgu_layer_scale_init_value=cgu_layer_scale_init_value, + dec_flow_head_chs=dec_flow_head_chs, + loss=loss, + use_upsample_mask=use_upsample_mask, + upmask_gradient_scale=upmask_gradient_scale, + ) + + act = nn.ReLU if activation_function is None else activation_function + self.input_act = act(inplace=True) + + self.current_output_stride = output_stride + + self.has_shown_input_message = False + self.has_shown_altcuda_message = False + + def coords_grid(self, batch, ht, wd): + coords = torch.meshgrid( + torch.arange(ht, dtype=self.dtype, device=self.device), + torch.arange(wd, dtype=self.dtype, device=self.device), + indexing="ij", + ) + coords = torch.stack(coords[::-1], dim=0).to(dtype=self.dtype) + return coords[None].repeat(batch, 
1, 1, 1) + + def upsample_flow(self, flow, mask, factor, ch=2): + """Upsample flow field [H/f, W/f, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, factor, factor, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(flow, [3, 3], padding=1) + up_flow = up_flow.view(N, ch, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, ch, factor * H, factor * W) + + def _show_input_message(self, images): + pyr_levels = compute_pyramid_levels(images) + recommended_pyr_levels = pyr_levels # 3 for 1K, 4 for 2K, etc. + + logger.info( + f"DPFlow: Using {self.pyramid_levels} pyramid levels and {self.iters_per_level} iterations per level." + ) + logger.info( + f"DPFlow: Processing inputs of resolution {images.shape[-1]} x {images.shape[-2]}" + ) + logger.info(f"DPFlow: Correlation mode: {self.corr_mode}") + + if recommended_pyr_levels != self.pyramid_levels: + logger.info( + "DPFlow: For this input size, you may get better results by setting --pyramid_levels {}", + recommended_pyr_levels, + ) + + def _show_altcuda_message(self): + if self.corr_mode == "local" and alt_cuda_corr is None: + logger.warning( + f"DPFlow: You are running with --corr_mode local, but alt_cuda_corr is not installed. Please install alt_cuda_corr to increase the speed." + ) + + def forward(self, inputs): + try: + return self.forward_flow(inputs) + except torch.OutOfMemoryError: + if self.corr_mode == "allpairs": + logger.warning( + "DPFlow: CUDA out of memory error with input size {}. DPFlow will set --model.corr_mode to 'local' and re-attempt inference. This decreases memory consumption, but it is also slower.", + list(inputs["images"].shape[-2:]), + ) + self.corr_mode = "local" + try: + return self.forward_flow(inputs) + except torch.OutOfMemoryError: + logger.error( + "DPFlow: CUDA out of memory error even after setting --model.corr_mode to 'local'. 
DPFlow cannot process this input size: {} on this device.", + list(inputs["images"].shape[-2:]), + ) + else: + logger.error( + "DPFlow: CUDA out of memory error even with --model.corr_mode set to 'local'. DPFlow cannot process this input size: {} on this device.", + list(inputs["images"].shape[-2:]), + ) + + def forward_flow(self, inputs): + if self.corr_mode == "local" and not self.has_shown_altcuda_message: + self._show_altcuda_message() + self.has_shown_altcuda_message = True + + if self.pyramid_levels is not None and not self.has_shown_input_message: + self._show_input_message(inputs["images"]) + self.has_shown_input_message = True + + if self.pyramid_levels is None: + pyr_levels = compute_pyramid_levels(inputs["images"]) + output_stride = 2 ** (pyr_levels + 2 + self.extra_output_stride) + + if output_stride != self.current_output_stride: + logger.info( + "DPFlow: Detected change in input size. The number of pyramid levels will change to {}, corresponding to output stride {}.", + pyr_levels, + output_stride, + ) + self.current_output_stride = output_stride + else: + pyr_levels = self.pyramid_levels + output_stride = self.output_stride + + images, image_resizer = self.preprocess_images( + inputs["images"], + stride=output_stride, + bgr_add=-0.5, + bgr_mult=2.0, + bgr_to_rgb=True, + resize_mode="pad", + pad_mode="replicate", + pad_two_side=True, + ) + image1 = images[:, 0] + image2 = images[:, 1] + + flow_init = None + if ( + inputs.get("prev_preds") is not None + and inputs["prev_preds"].get("flow_small") is not None + ): + flow_init = inputs["prev_preds"]["flow_small"] + + flow_predictions, flow_small, flow_up, info_predictions = self.predict( + image1, + image2, + pyr_levels=pyr_levels, + image_resizer=image_resizer, + flow_init=flow_init, + ) + + nf_predictions = [] + if self.training and self.loss == "laplace": + # exlude invalid pixels and extremely large diplacements + for i in range(len(info_predictions)): + if not self.use_var: + var_max = var_min = 0 + 
else: + var_max = self.var_max + var_min = self.var_min + + if info_predictions[i] is None: + nf_predictions.append(None) + else: + raw_b = info_predictions[i][:, 2:] + log_b = torch.zeros_like(raw_b) + weight = info_predictions[i][:, :2] + # Large b Component + log_b[:, 0] = torch.clamp(raw_b[:, 0], min=0, max=var_max) + # Small b Component + log_b[:, 1] = torch.clamp(raw_b[:, 1], min=var_min, max=0) + # term2: [N, 2, m, H, W] + term2 = ( + (inputs["flows"][:, 0] - flow_predictions[i]).abs().unsqueeze(2) + ) * (torch.exp(-log_b).unsqueeze(1)) + # term1: [N, m, H, W] + term1 = weight - math.log(2) - log_b + nf_loss = torch.logsumexp( + weight, dim=1, keepdim=True + ) - torch.logsumexp(term1.unsqueeze(1) - term2, dim=2) + nf_predictions.append(nf_loss) + + outputs = {"flows": flow_up[:, None], "flow_small": flow_small} + + if self.training: + outputs["flow_preds"] = flow_predictions + outputs["nf_preds"] = nf_predictions + + return outputs + + def predict(self, x1_raw, x2_raw, pyr_levels, image_resizer, flow_init=None): + b, _, height_im, width_im = x1_raw.size() + + x1_pyramid, x2_pyramid = self.fnet(x1_raw, x2_raw, pyr_levels=pyr_levels) + + # outputs + flows = [] + infos = [] + + # init + ( + b_size, + _, + h_x1, + w_x1, + ) = x1_pyramid[0].size() + init_device = x1_pyramid[0].device + + if flow_init is not None: + flow = flow_init + flow = rescale_flow( + flow, + x1_pyramid[0].shape[-1], + x1_pyramid[0].shape[-2], + to_local=False, + ) + flow = upsample2d_as(flow, x1_pyramid[0], mode="bilinear") + flow = forward_interpolate_batch(flow) + else: + flow = torch.zeros( + b_size, 2, h_x1, w_x1, dtype=self.dtype, device=init_device + ) + + net = None + for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): + # Split feature channels into matching (x) and context (c) + xh = x1.shape[1] + ch = xh // 3 + x1, cn1 = torch.split(x1, [xh - ch, ch], dim=1) + x2, cn2 = torch.split(x2, [xh - ch, ch], dim=1) + halfch = ch // 2 + i1, n1 = torch.split(cn1, [ch - halfch, 
halfch], dim=1) + i2, n2 = torch.split(cn2, [ch - halfch, halfch], dim=1) + inp = torch.cat([i1, i2], 1) + inp = self.input_act(inp) + net_tmp = torch.cat([n1, n2], 1) + + coords0 = self.coords_grid(x1.shape[0], x1.shape[2], x1.shape[3]) + + if self.corr_mode == "allpairs": + corr_fn = CorrBlock(x1, x2, self.corr_levels, self.corr_range) + else: + corr_fn = AlternateCorrBlock(x1, x2, self.corr_levels, self.corr_range) + + if l > 0: + flow = rescale_flow(flow, x1.shape[-1], x1.shape[-2], to_local=False) + flow = upsample2d_as(flow, x1, mode="bilinear") + + net = torch.tanh(net_tmp) + + for it in range(self.iters_per_level): + if self.detach_flow: + flow = flow.detach() + + # correlation + out_corr = corr_fn(coords0 + flow) + + flow_res, net, mask = self.update_block(net, inp, out_corr, flow) + + info = None + if self.loss == "laplace": + info = flow_res[:, 2:] + flow_res = flow_res[:, :2] + + flow = flow + flow_res + + if self.training or ( + l == len(x1_pyramid) - 1 and it == self.iters_per_level - 1 + ): + out_flow = rescale_flow(flow, width_im, height_im, to_local=False) + if mask is not None: + out_flow = self.upsample_flow(out_flow, mask, factor=8) + out_flow = upsample2d_as(out_flow, x1_raw, mode="bilinear") + out_flow = self.postprocess_predictions( + out_flow, image_resizer, is_flow=True + ) + flows.append(out_flow) + + out_info = None + if info is not None: + if mask is not None: + out_info = self.upsample_flow(info, mask, factor=8, ch=4) + out_info = upsample2d_as(out_info, x1_raw, mode="bilinear") + out_info = self.postprocess_predictions( + out_info, image_resizer, is_flow=False + ) + infos.append(out_info) + + return flows, flow, out_flow, infos + + +@register_model +@trainable +@ptlflow_trained +class dpflow(DPFlow): + pass diff --git a/ptlflow/models/dpflow/local_timm/activations.py b/ptlflow/models/dpflow/local_timm/activations.py new file mode 100644 index 00000000..24f4edca --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/activations.py @@ -0,0 
+1,174 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941""" + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681""" + + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 
3.0).div_(6.0) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) + else: + return F.relu6(x + 3.0) / 6.0 + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg)""" + + def __init__( + self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False + ) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)""" + + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) + + +def gelu_tanh(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x, approximate="tanh") + 
+ +class GELUTanh(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)""" + + def __init__(self, inplace: bool = False): + super(GELUTanh, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input, approximate="tanh") + + +def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + +class QuickGELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)""" + + def __init__(self, inplace: bool = False): + super(QuickGELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return quick_gelu(input) diff --git a/ptlflow/models/dpflow/local_timm/activations_me.py b/ptlflow/models/dpflow/local_timm/activations_me.py new file mode 100644 index 00000000..a139705e --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/activations_me.py @@ -0,0 +1,220 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use +basic versions of the activations. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish_fwd(x): + return x.mul(torch.sigmoid(x)) + + +def swish_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishAutoFn(torch.autograd.Function): + """optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + + @staticmethod + def symbolic(g, x): + return g.op("Mul", x, g.op("Sigmoid", x)) + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishAutoFn.apply(x) + + +def mish_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +def mish_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishAutoFn(torch.autograd.Function): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient variant of Mish + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishAutoFn.apply(x) + + +def hard_sigmoid_fwd(x, inplace: bool = 
False): + return (x + 3).clamp(min=0, max=6).div(6.0) + + +def hard_sigmoid_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.0) & (x <= 3.0)) / 6.0 + return grad_output * m + + +class HardSigmoidAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidAutoFn.apply(x) + + +def hard_swish_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.0) + + +def hard_swish_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.0) + m = torch.where((x >= -3.0) & (x <= 3.0), x / 3.0 + 0.5, m) + return grad_output * m + + +class HardSwishAutoFn(torch.autograd.Function): + """A memory efficient HardSwish activation""" + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_bwd(x, grad_output) + + @staticmethod + def symbolic(g, self): + input = g.op( + "Add", self, g.op("Constant", value_t=torch.tensor(3, dtype=torch.float)) + ) + hardtanh_ = g.op( + "Clip", + input, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)), + g.op("Constant", value_t=torch.tensor(6, dtype=torch.float)), + ) + hardtanh_ = g.op( + "Div", + hardtanh_, + g.op("Constant", value_t=torch.tensor(6, dtype=torch.float)), + ) + return g.op("Mul", self, hardtanh_) + + +def hard_swish_me(x, inplace=False): + return HardSwishAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishAutoFn.apply(x) + + 
+def hard_mish_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +def hard_mish_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.0) + m = torch.where((x >= -2.0) & (x <= 0.0), x + 1.0, m) + return grad_output * m + + +class HardMishAutoFn(torch.autograd.Function): + """A memory efficient variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishAutoFn.apply(x) diff --git a/ptlflow/models/dpflow/local_timm/cond_conv2d.py b/ptlflow/models/dpflow/local_timm/cond_conv2d.py new file mode 100755 index 00000000..46db84b6 --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/cond_conv2d.py @@ -0,0 +1,174 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .layer_helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if ( + len(weight.shape) != 2 + or weight.shape[0] != num_experts + or weight.shape[1] != 
num_params + ): + raise ( + ValueError( + "CondConv variables must have shape [num_experts, num_params]" + ) + ) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + + return condconv_initializer + + +class CondConv2d(nn.Module): + """Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + + __constants__ = ["in_channels", "out_channels", "dynamic_padding"] + + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding="", + dilation=1, + groups=1, + bias=False, + num_experts=4, + ): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation + ) + self.dynamic_padding = ( + is_padding_dynamic # if in forward to work with torchscript + ) + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = ( + self.out_channels, + self.in_channels // self.groups, + ) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter( + torch.Tensor(self.num_experts, weight_num_param) + ) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter( + torch.Tensor(self.num_experts, self.out_channels) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), + self.num_experts, + 
self.weight_shape, + ) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), + self.num_experts, + self.bias_shape, + ) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = ( + B * self.out_channels, + self.in_channels // self.groups, + ) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups * B, + ) + else: + out = F.conv2d( + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups * B, + ) + out = out.permute([1, 0, 2, 3]).view( + B, self.out_channels, out.shape[-2], out.shape[-1] + ) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/ptlflow/models/dpflow/local_timm/config.py 
b/ptlflow/models/dpflow/local_timm/config.py new file mode 100644 index 00000000..1d6436fb --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/config.py @@ -0,0 +1,163 @@ +""" Model / Layer Config singleton state +""" + +import os +import warnings +from typing import Any, Optional + +import torch + +__all__ = [ + "is_exportable", + "is_scriptable", + "is_no_jit", + "use_fused_attn", + "set_exportable", + "set_scriptable", + "set_no_jit", + "set_layer_config", + "set_fused_attn", +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +# use torch.scaled_dot_product_attention where possible +_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention") +if "TIMM_FUSED_ATTN" in os.environ: + _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"]) +else: + _USE_FUSED_ATTN = ( + 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) + ) + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE 
+ _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None, + ): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def use_fused_attn(experimental: bool = False) -> bool: + # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 + if not _HAS_FUSED_ATTN or _EXPORTABLE: + return False + if experimental: + return _USE_FUSED_ATTN > 1 + return _USE_FUSED_ATTN > 0 + + +def set_fused_attn(enable: bool = True, experimental: bool = False): + global _USE_FUSED_ATTN + if not _HAS_FUSED_ATTN: + warnings.warn( + "This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored." 
+ ) + return + if experimental and enable: + _USE_FUSED_ATTN = 2 + elif enable: + _USE_FUSED_ATTN = 1 + else: + _USE_FUSED_ATTN = 0 diff --git a/ptlflow/models/dpflow/local_timm/conv2d_same.py b/ptlflow/models/dpflow/local_timm/conv2d_same.py new file mode 100755 index 00000000..a394ca58 --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/conv2d_same.py @@ -0,0 +1,64 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + dilation: Tuple[int, int] = (1, 1), + groups: int = 1, +): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """Tensorflow like 'SAME' convolution wrapper for 2D convolutions""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias + ) + + def forward(self, x): + return conv2d_same( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop("padding", "") + kwargs.setdefault("bias", False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) diff --git a/ptlflow/models/dpflow/local_timm/create_act.py b/ptlflow/models/dpflow/local_timm/create_act.py new file mode 100644 index 
00000000..b3e2f6ec --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/create_act.py @@ -0,0 +1,139 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" + +from typing import Union, Callable, Type + +from .activations import * +from .activations_me import * +from .config import is_exportable, is_scriptable + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = "silu" in dir(torch.nn.functional) +_has_hardswish = "hardswish" in dir(torch.nn.functional) +_has_hardsigmoid = "hardsigmoid" in dir(torch.nn.functional) +_has_mish = "mish" in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + gelu_tanh=gelu_tanh, + quick_gelu=quick_gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault("hardsigmoid", a.get("hard_sigmoid")) + a.setdefault("hardswish", a.get("hard_swish")) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, 
+ leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + gelu_tanh=GELUTanh, + quick_gelu=QuickGELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, + identity=nn.Identity, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault("hardsigmoid", a.get("hard_sigmoid")) + a.setdefault("hardswish", a.get("hard_swish")) + + +def get_act_fn(name: Union[Callable, str] = "relu"): + """Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, Callable): + return name + name = name.lower() + if not (is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = "relu"): + """Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
+ """ + if name is None: + return None + if not isinstance(name, str): + # callable, module, etc + return name + if not name: + return None + name = name.lower() + if not (is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + if inplace is None: + return act_layer(**kwargs) + try: + return act_layer(inplace=inplace, **kwargs) + except TypeError: + # recover if act layer doesn't have inplace arg + return act_layer(**kwargs) diff --git a/ptlflow/models/dpflow/local_timm/create_conv2d.py b/ptlflow/models/dpflow/local_timm/create_conv2d.py new file mode 100755 index 00000000..63740d65 --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/create_conv2d.py @@ -0,0 +1,42 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert ( + "num_experts" not in kwargs + ) # MixNet + CondConv combo not supported currently + if "groups" in kwargs: + groups = kwargs.pop("groups") + if groups == in_channels: + kwargs["depthwise"] = True + else: + assert groups == 1 + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
+ m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop("depthwise", False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop("groups", 1) + if "num_experts" in kwargs and kwargs["num_experts"] > 0: + m = CondConv2d( + in_channels, out_channels, kernel_size, groups=groups, **kwargs + ) + else: + m = create_conv2d_pad( + in_channels, out_channels, kernel_size, groups=groups, **kwargs + ) + return m diff --git a/ptlflow/models/dpflow/local_timm/drop.py b/ptlflow/models/dpflow/local_timm/drop.py new file mode 100644 index 00000000..3edd7154 --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/drop.py @@ -0,0 +1,57 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch.nn as nn + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" diff --git a/ptlflow/models/dpflow/local_timm/gelu.py b/ptlflow/models/dpflow/local_timm/gelu.py new file mode 100644 index 00000000..cf040c1f --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/gelu.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)""" + + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/ptlflow/models/dpflow/local_timm/grn.py b/ptlflow/models/dpflow/local_timm/grn.py new file mode 100644 index 00000000..0db1953b --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/grn.py @@ -0,0 +1,41 @@ +""" Global Response Normalization Module + +Based on the GRN layer presented in +`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 + +This implementation +* works for both 
NCHW and NHWC tensor layouts +* uses affine param names matching existing torch norm layers +* slightly improves eager mode performance via fused addcmul + +Hacked together by / Copyright 2023 Ross Wightman +""" + +import torch +from torch import nn as nn + + +class GlobalResponseNorm(nn.Module): + """Global Response Normalization layer""" + + def __init__(self, dim, eps=1e-6, channels_last=True): + super().__init__() + self.eps = eps + if channels_last: + self.spatial_dim = (1, 2) + self.channel_dim = -1 + self.wb_shape = (1, 1, 1, -1) + else: + self.spatial_dim = (2, 3) + self.channel_dim = 1 + self.wb_shape = (1, -1, 1, 1) + + self.weight = nn.Parameter(torch.zeros(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) + + def forward(self, x): + x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) + x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps) + return x + torch.addcmul( + self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n + ) diff --git a/ptlflow/models/dpflow/local_timm/helpers.py b/ptlflow/models/dpflow/local_timm/helpers.py new file mode 100644 index 00000000..59f4ad38 --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/helpers.py @@ -0,0 +1,80 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from itertools import chain + +import torch +from torch.utils.checkpoint import checkpoint + + +def checkpoint_seq( + functions, x, every=1, flatten=False, skip_last=False, preserve_rng_state=True +): + r"""A helper function for checkpointing sequential models. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a sequence into segments + and checkpoint each segment. All segments except run in :func:`torch.no_grad` + manner, i.e., not storing the intermediate activations. The inputs of each + checkpointed segment will be saved for re-running the segment in the backward pass. 
+ + See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. + + .. warning:: + Checkpointing currently only supports :func:`torch.autograd.backward` + and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` + is not supported. + + .. warning: + At least one of the inputs needs to have :code:`requires_grad=True` if + grads are needed for model inputs, otherwise the checkpointed part of the + model won't have gradients. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. + x: A Tensor that is input to :attr:`functions` + every: checkpoint every-n functions (default: 1) + flatten (bool): flatten nn.Sequential of nn.Sequentials + skip_last (bool): skip checkpointing the last function in the sequence if True + preserve_rng_state (bool, optional, default=True): Omit stashing and restoring + the RNG state during each checkpoint. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> model = nn.Sequential(...) 
+ >>> input_var = checkpoint_seq(model, input_var, every=2) + """ + + def run_function(start, end, functions): + def forward(_x): + for j in range(start, end + 1): + _x = functions[j](_x) + return _x + + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = functions.children() + if flatten: + functions = chain.from_iterable(functions) + if not isinstance(functions, (tuple, list)): + functions = tuple(functions) + + num_checkpointed = len(functions) + if skip_last: + num_checkpointed -= 1 + end = -1 + for start in range(0, num_checkpointed, every): + end = min(start + every - 1, num_checkpointed - 1) + x = checkpoint( + run_function(start, end, functions), + x, + preserve_rng_state=preserve_rng_state, + ) + if skip_last: + return run_function(end + 1, len(functions) - 1, functions)(x) + return x diff --git a/ptlflow/models/dpflow/local_timm/layer_helpers.py b/ptlflow/models/dpflow/local_timm/layer_helpers.py new file mode 100644 index 00000000..c15e8fda --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/layer_helpers.py @@ -0,0 +1,45 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=0.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < round_limit * v: + new_v += divisor + return new_v + + +def extend_tuple(x, n): + # pdas a tuple to specified n by padding with last value + if not isinstance(x, (tuple, list)): + x = (x,) + else: + x = tuple(x) + pad_n = n - len(x) + if pad_n <= 0: + return x[:n] + return x + (x[-1],) * pad_n diff --git a/ptlflow/models/dpflow/local_timm/mixed_conv2d.py b/ptlflow/models/dpflow/local_timm/mixed_conv2d.py new file mode 100755 index 00000000..144355dd --- /dev/null +++ b/ptlflow/models/dpflow/local_timm/mixed_conv2d.py @@ -0,0 +1,70 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding="", + dilation=1, + depthwise=False, + **kwargs, + ): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate( + zip(kernel_size, in_splits, out_splits) + ): + conv_groups = in_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, + out_ch, + k, + stride=stride, + padding=padding, + 
""" MLP module w/ dropout and configurable activation layer

Hacked together by / Copyright 2020 Ross Wightman
"""

from functools import partial

from torch import nn as nn

from .grn import GlobalResponseNorm
from .layer_helpers import to_2tuple


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        # Hidden/output widths default to the input width when not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # bias/drop may be a single value or a (fc1, fc2) pair.
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # use_conv selects 1x1 Conv2d (NCHW input) vs nn.Linear (channels-last).
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class GluMlp(nn.Module):
    """MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.Sigmoid,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # fc1 output is chunked in two, so it must split evenly.
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        # Chunk along channels (dim 1) for conv, last dim for linear.
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # use second half of width for gate

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features // 2)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        # Split into value/gate halves; gate_last picks which half is gated.
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# SwiGLU expressed through GluMlp: SiLU activation, gate on the first half.
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)


class SwiGLU(nn.Module):
    """SwiGLU
    NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.SiLU,
        norm_layer=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        # Separate gate (fc1_g) and value (fc1_x) projections.
        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class GatedMlp(nn.Module):
    """MLP as used in gMLP"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        gate_layer=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            # Gate layer halves the width (e.g. spatial gating unit in gMLP).
            hidden_features = (
                hidden_features // 2
            )  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class ConvMlp(nn.Module):
    """MLP using 1x1 convs that keeps spatial dims"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.ReLU,
        norm_layer=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        # NOTE: single dropout here (after act), unlike the two-dropout Mlp above.
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
""" Padding Helpers

Hacked together by / Copyright 2020 Ross Wightman
"""

import math
from typing import List, Tuple

import torch.nn.functional as F


def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Symmetric PyTorch-style padding for a convolution.

    NOTE: annotated as int, but a tuple kernel_size is also handled (returns a
    per-dim tuple in that case).
    """

    def _pad(ks):
        return ((stride - 1) + dilation * (ks - 1)) // 2

    if isinstance(kernel_size, tuple):
        return tuple(_pad(ks) for ks in kernel_size)
    return _pad(kernel_size)


def get_same_padding(x: int, k: int, s: int, d: int):
    """Total TF-'SAME' padding needed along one dimension of size `x`."""
    out_size = math.ceil(x / s)
    return max((out_size - 1) * s + (k - 1) * d + 1 - x, 0)


def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """True when 'SAME' padding is input-size independent and symmetric."""
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Dynamically pad NCHW input `x` with 'SAME' padding for kernel k, stride s, dilation d."""
    ih, iw = x.size()[-2:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    if pad_h == 0 and pad_w == 0:
        return x
    # F.pad ordering: (left, right, top, bottom); extra pixel goes on the
    # right/bottom, matching TensorFlow's asymmetric 'SAME'.
    return F.pad(
        x,
        [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        value=value,
    )


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec to (padding_value, is_dynamic).

    `padding` may be an explicit number/tuple (passed through), or one of the
    strings '' (PyTorch symmetric), 'same' (TF SAME), 'valid' (no padding).
    """
    if not isinstance(padding, str):
        # Explicit numeric padding: use as-is, never dynamic.
        return padding, False

    padding = padding.lower()
    if padding == "same":
        if is_static_pad(kernel_size, **kwargs):
            # Static case: no runtime overhead.
            return get_padding(kernel_size, **kwargs), False
        # Dynamic 'SAME': padding depends on input size; has runtime/GPU
        # memory overhead (handled by the *Same module wrappers).
        return 0, True
    if padding == "valid":
        return 0, False
    # Default: PyTorch-style 'same'-ish symmetric padding.
    return get_padding(kernel_size, **kwargs), False
""" AvgPool2d w/ Same Padding

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch.nn as nn
import torch.nn.functional as F
from typing import List

from .layer_helpers import to_2tuple
from .padding import pad_same, get_padding_value


def avg_pool2d_same(
    x,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int] = (0, 0),
    ceil_mode: bool = False,
    count_include_pad: bool = True,
):
    # Functional average pool with TF-style dynamic 'SAME' padding applied up
    # front; the pool itself then runs with zero padding.
    # NOTE(review): the `padding` arg is accepted but unused here.
    # FIXME how to deal with count_include_pad vs not for external padding?
    x = pad_same(x, kernel_size, stride)
    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)


class AvgPool2dSame(nn.AvgPool2d):
    """Tensorflow like 'SAME' wrapper for 2D average pooling"""

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        ceil_mode=False,
        count_include_pad=True,
    ):
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        # Base class is constructed with zero padding; padding is applied
        # dynamically in forward() via pad_same (the `padding` arg is ignored).
        super(AvgPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), ceil_mode, count_include_pad
        )

    def forward(self, x):
        x = pad_same(x, self.kernel_size, self.stride)
        return F.avg_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
        )


def max_pool2d_same(
    x,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int] = (0, 0),
    dilation: List[int] = (1, 1),
    ceil_mode: bool = False,
):
    # Pad with -inf so padded positions never win the max.
    x = pad_same(x, kernel_size, stride, value=-float("inf"))
    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)


class MaxPool2dSame(nn.MaxPool2d):
    """Tensorflow like 'SAME' wrapper for 2D max pooling"""

    def __init__(
        self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False
    ):
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)
        # Zero padding at construction; dynamic 'SAME' padding in forward().
        super(MaxPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), dilation, ceil_mode
        )

    def forward(self, x):
        x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
        return F.max_pool2d(
            x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode
        )


def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
    """Factory for avg/max pooling, picking the dynamic-'SAME' wrapper when needed.

    `padding` (popped from kwargs) may be '', 'same', 'valid', or explicit; see
    get_padding_value for the resolution rules.
    """
    stride = stride or kernel_size
    padding = kwargs.pop("padding", "")
    padding, is_dynamic = get_padding_value(
        padding, kernel_size, stride=stride, **kwargs
    )
    if is_dynamic:
        # Input-size-dependent padding: use the *Same modules.
        if pool_type == "avg":
            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
        elif pool_type == "max":
            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
        else:
            assert False, f"Unsupported pool type {pool_type}"
    else:
        if pool_type == "avg":
            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
        elif pool_type == "max":
            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
        else:
            assert False, f"Unsupported pool type {pool_type}"
" + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) diff --git a/ptlflow/models/dpflow/norm.py b/ptlflow/models/dpflow/norm.py new file mode 100644 index 00000000..7f23689e --- /dev/null +++ b/ptlflow/models/dpflow/norm.py @@ -0,0 +1,122 @@ +""" Normalization layers and wrappers + +Norm layer definitions that support fast norm and consistent channel arg order (always first arg). + +Hacked together by / Copyright 2022 Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm2d(nn.LayerNorm): + """LayerNorm for channels of '2D' spatial NCHW tensors""" + + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) + if self.weight is not None: + x = F.layer_norm( + x, + (x.shape[-1],), + self.weight, + self.bias, + self.eps, + ) + else: + x = F.layer_norm(x, (x.shape[-1],), eps=self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +class GroupNorm2d(nn.GroupNorm): + """GroupNorm for channels of '2D' spatial NCHW tensors""" + + def __init__(self, num_groups, num_channels, eps=1e-6, affine=True): + super().__init__(num_groups, num_channels, eps=eps, affine=affine) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.weight is not None: + x = F.group_norm( + x, + self.num_groups, + self.weight, + self.bias, + self.eps, + ) + else: + x = F.group_norm(x, self.num_groups, eps=self.eps) + return x + + +class BatchNorm2d(nn.BatchNorm2d): + def __init__( + self, + num_channels: int, + eps: float = 0.00001, + momentum: float = 0.1, + affine: bool = True, + 
track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__( + num_channels, eps, momentum, affine, track_running_stats, device, dtype + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). 
+ """ + return F.batch_norm( + input, + # If buffers are not to be tracked, ensure that they won't be updated + ( + self.running_mean[: input.shape[1]] + if not self.training or self.track_running_stats + else None + ), + ( + self.running_var[: input.shape[1]] + if not self.training or self.track_running_stats + else None + ), + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) diff --git a/ptlflow/models/dpflow/pwc_modules.py b/ptlflow/models/dpflow/pwc_modules.py new file mode 100644 index 00000000..01169182 --- /dev/null +++ b/ptlflow/models/dpflow/pwc_modules.py @@ -0,0 +1,27 @@ +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def upsample2d_as(inputs, target_as, mode="bilinear"): + _, _, h, w = target_as.size() + if inputs.shape[-2] != h or inputs.shape[-1] != w: + inputs = F.interpolate(inputs, [h, w], mode=mode, align_corners=True) + return inputs + + +def rescale_flow(flow, width_im, height_im, to_local=True): + if to_local: + u_scale = float(flow.size(3) / width_im) + v_scale = float(flow.size(2) / height_im) + else: + u_scale = float(width_im / flow.size(3)) + v_scale = float(height_im / flow.size(2)) + + u, v = flow.chunk(2, dim=1) + u = u * u_scale + v = v * v_scale + + return torch.cat([u, v], dim=1) diff --git a/ptlflow/models/dpflow/requirements.txt b/ptlflow/models/dpflow/requirements.txt new file mode 100644 index 00000000..f226024f --- /dev/null +++ b/ptlflow/models/dpflow/requirements.txt @@ -0,0 +1,136 @@ +absl-py==2.1.0 +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 +aiosignal==1.3.1 +alabaster==1.0.0 +antlr4-python3-runtime==4.9.3 +attrs==24.2.0 +babel==2.16.0 +bitsandbytes==0.44.1 +black==24.10.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +click==8.1.7 +coloredlogs==15.0.1 +contourpy==1.3.0 +correlation==0.0.0 +cycler==0.12.1 +docker-pycreds==0.4.0 +docstring_parser==0.16 +docutils==0.21.2 +einops==0.8.0 
+filelock==3.16.1 +flatbuffers==24.3.25 +fonttools==4.54.1 +frozenlist==1.5.0 +fsspec==2024.10.0 +GANet==0.0.0 +gitdb==4.0.11 +GitPython==3.1.43 +grpcio==1.67.1 +h5py==3.12.1 +huggingface-hub==0.26.2 +humanfriendly==10.0 +hydra-core==1.3.2 +idna==3.10 +imagesize==1.4.1 +importlib_resources==6.4.5 +iniconfig==2.0.0 +Jinja2==3.1.4 +jsonargparse==4.34.0 +kaleido==0.2.1 +kiwisolver==1.4.7 +lightning==2.4.0 +lightning-utilities==0.11.8 +loguru==0.7.2 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.1.0 +mypy-extensions==1.0.0 +networkx==3.4.2 +numpy==2.1.3 +numpydoc==1.8.0 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 +onnx==1.17.0 +onnxruntime==1.20.1 +opencv-python==4.10.0.84 +packaging==24.2 +pandas==2.2.3 +pathspec==0.12.1 +pillow==11.0.0 +platformdirs==4.3.6 +plotly==5.24.1 +pluggy==1.5.0 +propcache==0.2.0 +protobuf==5.28.3 +psutil==6.1.0 +ptlflow==0.2.5 +Pygments==2.18.0 +pynvml==11.5.3 +pyparsing==3.2.0 +pypng==0.20220715.0 +pytest==8.3.3 +python-dateutil==2.9.0.post0 +pytorch-lightning==2.4.0 +pytz==2024.2 +PyYAML==6.0.2 +quadtree_attention_package==0.0.0 +requests==2.32.3 +rich==13.9.4 +safetensors==0.4.5 +scipy==1.14.1 +sentry-sdk==2.18.0 +setproctitle==1.3.3 +setuptools==75.1.0 +six==1.16.0 +smmap==5.0.1 +snowballstemmer==2.2.0 +Sphinx==8.1.3 +sphinx-rtd-theme==3.0.2 +sphinxcontrib-applehelp==2.0.0 +sphinxcontrib-devhelp==2.0.0 +sphinxcontrib-htmlhelp==2.1.0 +sphinxcontrib-jquery==4.1 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==2.0.0 +sphinxcontrib-serializinghtml==2.0.0 +sympy==1.13.1 +tabulate==0.9.0 +tenacity==9.0.0 +tensorboard==2.18.0 
+tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +tensorrt-cu12==10.3.0 +tensorrt-cu12-bindings==10.3.0 +tensorrt-cu12-libs==10.3.0 +timm==1.0.11 +torch==2.5.1 +torch_tensorrt==2.5.0 +torchmetrics==1.5.2 +torchvision==0.20.1 +tqdm==4.67.0 +triton==3.1.0 +typeshed_client==2.7.0 +typing_extensions==4.12.2 +tzdata==2024.2 +urllib3==2.2.3 +wandb==0.18.6 +Werkzeug==3.1.3 +wheel==0.44.0 +yarl==1.17.1 diff --git a/ptlflow/models/dpflow/res_stem.py b/ptlflow/models/dpflow/res_stem.py new file mode 100644 index 00000000..55df9f60 --- /dev/null +++ b/ptlflow/models/dpflow/res_stem.py @@ -0,0 +1,97 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# =============================================================================
# Copyright 2025 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch.nn as nn
from .conv import Conv2dBlock


class ResidualBlock(nn.Module):
    """Two 3x3 conv residual block with an optional 1x1 projection on the skip path.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_layer: callable building a norm module from `num_channels`.
        stride: stride of the first conv (and of the skip projection).
    """

    def __init__(self, in_planes, planes, norm_layer, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = Conv2dBlock(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = Conv2dBlock(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(num_channels=planes)
        self.norm2 = norm_layer(num_channels=planes)

        if stride == 1 and in_planes == planes:
            # Identity skip: no projection needed.
            self.downsample = None
        else:
            # BUGFIX: norm3 was previously created only when stride != 1, so a
            # stride-1 block with in_planes != planes crashed with
            # AttributeError when building the downsample path. Create it
            # whenever the skip path needs a projection.
            self.norm3 = norm_layer(num_channels=planes)
            self.downsample = nn.Sequential(
                Conv2dBlock(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class ResStem(nn.Module):
    """RAFT-style residual stem: 7x7/2 conv, two residual stages, 1x1 projection.

    Args:
        hidden_chs: three channel widths (stage1, stage2, output projection).
        norm_layer: callable building a norm module from `num_channels`.
    """

    def __init__(self, hidden_chs, norm_layer):
        super(ResStem, self).__init__()
        self.norm_fn = norm_layer

        self.norm1 = norm_layer(num_channels=hidden_chs[0])

        self.conv1 = Conv2dBlock(3, hidden_chs[0], kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = hidden_chs[0]
        self.layer1 = self._make_layer(hidden_chs[0], stride=1)
        self.layer2 = self._make_layer(hidden_chs[1], stride=2)

        self.conv2 = Conv2dBlock(hidden_chs[1], hidden_chs[2], kernel_size=1)

        # NOTE: explicit kaiming/constant init intentionally disabled; module
        # defaults are used instead.
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        #     elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
        #         if m.weight is not None:
        #             nn.init.constant_(m.weight, 1)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first one may change stride/width.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        # Track current width for the next stage.
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.conv2(x)

        return x
# =============================================================================
# Copyright 2025 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn

from .conv import Conv2dBlock
from .cgu import CGUStage
from .norm import LayerNorm2d


class FlowHead(nn.Module):
    """Two-conv head mapping a hidden state to a flow map.

    Outputs 2 channels, or 6 when info_pred is set (extra channels presumably
    carry uncertainty/info terms for the loss — confirm against the loss code).
    """

    def __init__(
        self, input_dim=128, hidden_dim=256, activation_function=None, info_pred=False
    ):
        super(FlowHead, self).__init__()
        self.conv1 = Conv2dBlock(input_dim, hidden_dim, 3, padding=1)
        out_ch = 6 if info_pred else 2
        self.conv2 = Conv2dBlock(hidden_dim, out_ch, 3, padding=1)

        # Default activation is ReLU; any activation-module class may be passed.
        act = nn.ReLU if activation_function is None else activation_function
        self.act = act(inplace=True)

    def forward(self, x):
        return self.conv2(self.act(self.conv1(x)))


class ConvGRU(nn.Module):
    """Convolutional GRU cell: h' = (1 - z) * h + z * q."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = Conv2dBlock(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = Conv2dBlock(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = Conv2dBlock(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        # Update gate.
        z = self.convz(hx)
        z = torch.sigmoid(z)

        # Reset gate.
        r = self.convr(hx)
        r = torch.sigmoid(r)

        # Candidate state from the reset-gated hidden state.
        q = self.convq(torch.cat([r * h, x], dim=1))
        q = torch.tanh(q)

        h = (1 - z) * h + z * q
        return h


class CGUGRU(nn.Module):
    """GRU cell whose z/r/q gates are CGUStage blocks instead of single convs.

    Same gating equations as ConvGRU; all CGU hyperparameters are forwarded
    identically to the three gate stages.
    """

    def __init__(
        self,
        hidden_dim=128,
        input_dim=192 + 128,
        activation_function=None,
        norm_layer=LayerNorm2d,
        depth=2,
        mlp_ratio=4,
        mlp_use_dw_conv=True,
        mlp_dw_kernel_size=7,
        mlp_in_kernel_size=1,
        mlp_out_kernel_size=1,
        layer_scale_init_value=1e-2,
    ):
        super(CGUGRU, self).__init__()

        # Update gate.
        self.convz = CGUStage(
            hidden_dim + input_dim,
            hidden_dim,
            stride=1,
            activation_function=activation_function,
            norm_layer=norm_layer,
            depth=depth,
            use_cross=False,
            mlp_ratio=mlp_ratio,
            mlp_use_dw_conv=mlp_use_dw_conv,
            mlp_dw_kernel_size=mlp_dw_kernel_size,
            mlp_in_kernel_size=mlp_in_kernel_size,
            mlp_out_kernel_size=mlp_out_kernel_size,
            layer_scale_init_value=layer_scale_init_value,
        )
        # Reset gate.
        self.convr = CGUStage(
            hidden_dim + input_dim,
            hidden_dim,
            stride=1,
            activation_function=activation_function,
            norm_layer=norm_layer,
            depth=depth,
            use_cross=False,
            mlp_ratio=mlp_ratio,
            mlp_use_dw_conv=mlp_use_dw_conv,
            mlp_dw_kernel_size=mlp_dw_kernel_size,
            mlp_in_kernel_size=mlp_in_kernel_size,
            mlp_out_kernel_size=mlp_out_kernel_size,
            layer_scale_init_value=layer_scale_init_value,
        )
        # Candidate state.
        self.convq = CGUStage(
            hidden_dim + input_dim,
            hidden_dim,
            stride=1,
            activation_function=activation_function,
            norm_layer=norm_layer,
            depth=depth,
            use_cross=False,
            mlp_ratio=mlp_ratio,
            mlp_use_dw_conv=mlp_use_dw_conv,
            mlp_dw_kernel_size=mlp_dw_kernel_size,
            mlp_in_kernel_size=mlp_in_kernel_size,
            mlp_out_kernel_size=mlp_out_kernel_size,
            layer_scale_init_value=layer_scale_init_value,
        )

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        # Update gate.
        z = self.convz(hx)
        z = torch.sigmoid(z)

        # Reset gate.
        r = self.convr(hx)
        r = torch.sigmoid(r)

        # Candidate state from the reset-gated hidden state.
        q = self.convq(torch.cat([r * h, x], dim=1))
        q = torch.tanh(q)

        h = (1 - z) * h + z * q
        return h


class ConvexMask(nn.Module):
    """Predicts convex-upsampling weights: pred_stride**2 * 9 channels
    (9 neighbor weights for each subpixel of the upsampling window)."""

    def __init__(self, net_chs, pred_stride, activation_function=None):
        super(ConvexMask, self).__init__()
        self.conv1 = Conv2dBlock(net_chs, net_chs * 2, 3, padding=1)
        self.conv2 = Conv2dBlock(net_chs * 2, pred_stride**2 * 9, 1, padding=0)

        act = nn.ReLU if activation_function is None else activation_function
        self.act = act(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        return x
super(MotionEncoder, self).__init__() + + c_hidden = dec_motenc_corr_hidden_chs + c_out = dec_motenc_corr_out_chs + f_hidden = dec_motenc_flow_hidden_chs + f_out = dec_motenc_flow_out_chs + + cor_planes = corr_levels * (2 * corr_range + 1) ** 2 + self.convc1 = Conv2dBlock(cor_planes, c_hidden, 1, padding=0) + self.convc2 = Conv2dBlock(c_hidden, c_out, 3, padding=1) + + self.convf1 = Conv2dBlock( + 2, + f_hidden, + dec_flow_kernel_size, + padding=dec_flow_kernel_size // 2, + ) + self.convf2 = Conv2dBlock(f_hidden, f_out, 3, padding=1) + + in_ch = f_out + c_out + out_ch = dec_motion_chs - 2 + self.conv = Conv2dBlock(in_ch, out_ch, 3, padding=1) + + act = nn.ReLU if activation_function is None else activation_function + self.act = act(inplace=True) + + def forward(self, flow, corr): + cor = self.act(self.convc1(corr)) + cor = self.act(self.convc2(cor)) + + outs = [cor] + + flo = self.act(self.convf1(flow)) + flo = self.act(self.convf2(flo)) + outs.append(flo) + + outs = torch.cat(outs, dim=1) + out = self.act(self.conv(outs)) + out_t = [out, flow] + return torch.cat(out_t, dim=1) + + +class UpdateBlock(nn.Module): + def __init__( + self, + dec_motenc_corr_hidden_chs: int, + dec_motenc_corr_out_chs: int, + dec_motenc_flow_hidden_chs: int, + dec_motenc_flow_out_chs: int, + corr_levels: int, + corr_range: int, + dec_flow_kernel_size: int, + dec_motion_chs: int, + activation_function: callable, + net_chs_fixed: int, + inp_chs_fixed: int, + dec_gru_norm_layer: callable, + dec_gru_depth: int, + dec_gru_iters: int, + dec_gru_mlp_ratio: float, + cgu_mlp_use_dw_conv: bool, + cgu_mlp_dw_kernel_size: int, + dec_gru_mlp_in_kernel_size: int, + dec_gru_mlp_out_kernel_size: int, + cgu_layer_scale_init_value: float, + dec_flow_head_chs: int, + loss: str, + use_upsample_mask: bool, + upmask_gradient_scale: float, + ): + super(UpdateBlock, self).__init__() + self.use_upsample_mask = use_upsample_mask + self.upmask_gradient_scale = upmask_gradient_scale + + self.encoder = MotionEncoder( 
+ dec_motenc_corr_hidden_chs=dec_motenc_corr_hidden_chs, + dec_motenc_corr_out_chs=dec_motenc_corr_out_chs, + dec_motenc_flow_hidden_chs=dec_motenc_flow_hidden_chs, + dec_motenc_flow_out_chs=dec_motenc_flow_out_chs, + corr_levels=corr_levels, + corr_range=corr_range, + dec_flow_kernel_size=dec_flow_kernel_size, + dec_motion_chs=dec_motion_chs, + activation_function=activation_function, + ) + + self.gru_list = nn.ModuleList( + [ + CGUGRU( + hidden_dim=net_chs_fixed, + input_dim=dec_motion_chs + inp_chs_fixed, + activation_function=activation_function, + norm_layer=dec_gru_norm_layer, + depth=dec_gru_depth, + mlp_ratio=dec_gru_mlp_ratio, + mlp_use_dw_conv=cgu_mlp_use_dw_conv, + mlp_dw_kernel_size=cgu_mlp_dw_kernel_size, + mlp_in_kernel_size=dec_gru_mlp_in_kernel_size, + mlp_out_kernel_size=dec_gru_mlp_out_kernel_size, + layer_scale_init_value=cgu_layer_scale_init_value, + ) + for _ in range(dec_gru_iters) + ] + ) + + self.flow_head = FlowHead( + net_chs_fixed, + hidden_dim=dec_flow_head_chs, + activation_function=activation_function, + info_pred=(loss == "laplace"), + ) + + if use_upsample_mask: + pred_stride = 8 + self.mask = ConvexMask( + net_chs_fixed, + pred_stride, + activation_function=activation_function, + ) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + + inp = torch.cat([inp, motion_features], dim=1) + + for gru in self.gru_list: + net = gru(net, inp) + + delta_flow = self.flow_head(net) + + mask = None + if self.use_upsample_mask: + mask = self.upmask_gradient_scale * self.mask(net) + + return delta_flow, net, mask diff --git a/ptlflow/models/dpflow/utils.py b/ptlflow/models/dpflow/utils.py new file mode 100644 index 00000000..5250e7cc --- /dev/null +++ b/ptlflow/models/dpflow/utils.py @@ -0,0 +1,77 @@ +# ============================================================================= +# Copyright 2025 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from functools import partial +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .norm import GroupNorm2d, LayerNorm2d, BatchNorm2d +from .local_timm.gelu import GELU + + +def compute_pyramid_levels(x): + img_diag = math.sqrt((x.shape[-2] ** 2) + (x.shape[-1] ** 2)) + input_factor = max( + 1, img_diag / 1100 + ) # 1100 ~= math.sqrt((960 ** 2) + (540 ** 2)), i.e., 1K resolution + pyr_levels = int(round(math.log2(input_factor))) + 3 + return pyr_levels + + +def bilinear_sampler(img, coords, mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def get_activation(name): + if name == "relu": + return nn.ReLU + elif name == "gelu": + return GELU + elif name == "silu": + return nn.SiLU + elif name == "mish": + return nn.Mish + elif name == "linear": + return nn.Identity + else: + return None + + +def get_norm(name, affine=False, num_groups=8): + if name == "group": + return partial(GroupNorm2d, affine=affine, num_groups=num_groups) + elif name == "layer": + return partial(LayerNorm2d, affine=affine) + elif name == "batch": + 
return BatchNorm2d + else: + return None diff --git a/ptlflow/models/rapidflow/configs/rapidflow-train1-chairs.yaml b/ptlflow/models/rapidflow/configs/rapidflow-train1-chairs.yaml index 48dabbd5..c4f166ab 100644 --- a/ptlflow/models/rapidflow/configs/rapidflow-train1-chairs.yaml +++ b/ptlflow/models/rapidflow/configs/rapidflow-train1-chairs.yaml @@ -1,5 +1,7 @@ # lightning.pytorch==2.4.0 seed_everything: true +lr: 0.0004 +wdecay: 0.0001 trainer: max_epochs: 10 accumulate_grad_batches: 1 @@ -30,8 +32,6 @@ model: simple_io: false gamma: 0.8 max_flow: 400 - lr: 0.0004 - wdecay: 0.0001 warm_start: false data: train_dataset: chairs diff --git a/ptlflow/models/rapidflow/configs/rapidflow-train2-things.yaml b/ptlflow/models/rapidflow/configs/rapidflow-train2-things.yaml index 5cee2ae5..4ae1f116 100644 --- a/ptlflow/models/rapidflow/configs/rapidflow-train2-things.yaml +++ b/ptlflow/models/rapidflow/configs/rapidflow-train2-things.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/chairs.ckpt # Change to the ckpt resulting from rapidflow-train1-chairs +lr: 0.000125 +wdecay: 0.0001 trainer: max_epochs: 10 accumulate_grad_batches: 1 @@ -31,8 +33,6 @@ model: simple_io: false gamma: 0.8 max_flow: 400 - lr: 0.000125 - wdecay: 0.0001 warm_start: false data: train_dataset: things diff --git a/ptlflow/models/rapidflow/configs/rapidflow-train3-sintel.yaml b/ptlflow/models/rapidflow/configs/rapidflow-train3-sintel.yaml index 7b17ce9e..5adf04e8 100644 --- a/ptlflow/models/rapidflow/configs/rapidflow-train3-sintel.yaml +++ b/ptlflow/models/rapidflow/configs/rapidflow-train3-sintel.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/things.ckpt # Change to the ckpt resulting from rapidflow-train2-things +lr: 0.000125 +wdecay: 0.00001 trainer: max_epochs: 4 accumulate_grad_batches: 1 @@ -31,8 +33,6 @@ model: simple_io: false gamma: 0.85 max_flow: 400 - lr: 0.000125 - wdecay: 0.00001 warm_start: false data: 
train_dataset: sintel_finetune diff --git a/ptlflow/models/rapidflow/configs/rapidflow-train4-kitti.yaml b/ptlflow/models/rapidflow/configs/rapidflow-train4-kitti.yaml index ef4fea04..415b5fc3 100644 --- a/ptlflow/models/rapidflow/configs/rapidflow-train4-kitti.yaml +++ b/ptlflow/models/rapidflow/configs/rapidflow-train4-kitti.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/sintel.ckpt # Change to the ckpt resulting from rapidflow-train3-sintel +lr: 0.000125 +wdecay: 0.00001 trainer: max_epochs: 300 check_val_every_n_epoch: 10 @@ -32,8 +34,6 @@ model: simple_io: false gamma: 0.85 max_flow: 400 - lr: 0.000125 - wdecay: 0.00001 warm_start: false data: train_dataset: kitti-2015 diff --git a/ptlflow/models/rpknet/configs/rpknet-train1-chairs.yaml b/ptlflow/models/rpknet/configs/rpknet-train1-chairs.yaml index 2bd82708..3c8f1e67 100644 --- a/ptlflow/models/rpknet/configs/rpknet-train1-chairs.yaml +++ b/ptlflow/models/rpknet/configs/rpknet-train1-chairs.yaml @@ -1,5 +1,7 @@ # lightning.pytorch==2.4.0 seed_everything: true +lr: 0.0004 +wdecay: 0.0001 trainer: max_epochs: 45 accumulate_grad_batches: 1 @@ -41,8 +43,6 @@ model: cache_pkconv_weights: false gamma: 0.8 max_flow: 400 - lr: 0.0004 - wdecay: 0.0001 warm_start: false data: train_dataset: chairs diff --git a/ptlflow/models/rpknet/configs/rpknet-train2-things.yaml b/ptlflow/models/rpknet/configs/rpknet-train2-things.yaml index 0507af33..2b9db221 100644 --- a/ptlflow/models/rpknet/configs/rpknet-train2-things.yaml +++ b/ptlflow/models/rpknet/configs/rpknet-train2-things.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/chairs.ckpt # Change to the ckpt resulting from rpknet-train1-chairs +lr: 0.000125 +wdecay: 0.0001 trainer: max_epochs: 80 accumulate_grad_batches: 1 @@ -42,8 +44,6 @@ model: cache_pkconv_weights: false gamma: 0.8 max_flow: 400 - lr: 0.000125 - wdecay: 0.0001 warm_start: false data: train_dataset: things diff --git 
a/ptlflow/models/rpknet/configs/rpknet-train3-sintel.yaml b/ptlflow/models/rpknet/configs/rpknet-train3-sintel.yaml index ce52631b..5e10bc14 100644 --- a/ptlflow/models/rpknet/configs/rpknet-train3-sintel.yaml +++ b/ptlflow/models/rpknet/configs/rpknet-train3-sintel.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/things.ckpt # Change to the ckpt resulting from rpknet-train2-things +lr: 0.000125 +wdecay: 0.00001 trainer: max_epochs: 5 accumulate_grad_batches: 1 @@ -42,8 +44,6 @@ model: cache_pkconv_weights: false gamma: 0.85 max_flow: 400 - lr: 0.000125 - wdecay: 0.00001 warm_start: false data: train_dataset: sintel_finetune diff --git a/ptlflow/models/rpknet/configs/rpknet-train4-kitti.yaml b/ptlflow/models/rpknet/configs/rpknet-train4-kitti.yaml index 7e3c9c55..18996386 100644 --- a/ptlflow/models/rpknet/configs/rpknet-train4-kitti.yaml +++ b/ptlflow/models/rpknet/configs/rpknet-train4-kitti.yaml @@ -1,6 +1,8 @@ # lightning.pytorch==2.4.0 seed_everything: true ckpt_path: /path/to/sintel.ckpt # Change to the ckpt resulting from rpknet-train3-sintel +lr: 0.000125 +wdecay: 0.00001 trainer: max_epochs: 150 check_val_every_n_epoch: 10 @@ -43,8 +45,6 @@ model: cache_pkconv_weights: false gamma: 0.85 max_flow: 400 - lr: 0.000125 - wdecay: 0.00001 warm_start: false data: train_dataset: kitti-2015 diff --git a/ptlflow/utils/flow_metrics.py b/ptlflow/utils/flow_metrics.py index 051e284f..2cf39c29 100644 --- a/ptlflow/utils/flow_metrics.py +++ b/ptlflow/utils/flow_metrics.py @@ -20,10 +20,12 @@ # limitations under the License. 
# ============================================================================= -from typing import Dict +from typing import Dict, Sequence -from torchmetrics import Metric +import numpy as np import torch +import torch.nn.functional as F +from torchmetrics import Metric class FlowMetrics(Metric): @@ -48,6 +50,7 @@ def __init__( average_mode: str = "epoch_mean", ema_decay: float = 0.99, f1_mode: str = "macro", + interpolate_pred_to_target_size: bool = False, ) -> None: """Initialize FlowMetrics. @@ -65,6 +68,8 @@ def __init__( How to calculate the f1-score. Accepts one of these options {binary, macro, weighted}. If binary, then the f1-score is calculated only for the positive pixels. If macro, then the f1-score is the average of positive and negative scores. If weighted, then the average is weighted according to the number of positive/negative samples. + interpolate_pred_to_target_size : bool, default False + If True, the prediction is bilinearly interpolated to match the target size, if their sizes are different. 
""" super().__init__(dist_sync_on_step=dist_sync_on_step) @@ -75,6 +80,7 @@ def __init__( self.ema_decay = ema_decay self.f1_mode = f1_mode self.ema_max_count = min(100, int(1.0 / (1.0 - ema_decay))) + self.interpolate_pred_to_target_size = interpolate_pred_to_target_size self.add_state("epe", default=torch.tensor(0).float(), dist_reduce_fx="sum") self.add_state( @@ -150,8 +156,30 @@ def update( prev_weight = self.ema_decay next_weight = 1.0 - self.ema_decay + metric_preds = {} + if self.interpolate_pred_to_target_size: + for k, v in preds.items(): + if isinstance(v, torch.Tensor): + v, orig_shape = self._to_bchw_shape(v) + target_size = targets["flows"].shape[-2:] + v = F.interpolate( + v, target_size, mode="bilinear", align_corners=True + ) + new_shape = list(orig_shape[:-2]) + list(target_size) + v = v.view(*new_shape) + + if "flow" in k: + scale_y = float(target_size[-2]) / orig_shape[-2] + scale_x = float(target_size[-1]) / orig_shape[-1] + v[..., 0, :, :] *= scale_x + v[..., 1, :, :] *= scale_y + + metric_preds[k] = v + else: + metric_preds = preds + batch_size = self._get_batch_size(targets["flows"]) - flow_pred = self._fix_shape(preds["flows"], batch_size) + flow_pred = self._fix_shape(metric_preds["flows"], batch_size) flow_target = self._fix_shape(targets["flows"], batch_size) valid_target = targets.get("valids") @@ -208,24 +236,24 @@ def update( ) self.include_occlusion = True - if preds.get("occs") is not None: - occlusion_pred = self._fix_shape(preds["occs"], batch_size) + if metric_preds.get("occs") is not None: + occlusion_pred = self._fix_shape(metric_preds["occs"], batch_size) occ_f1 = self._f1_score( occlusion_pred, occlusion_target, mode=self.f1_mode ) self.used_keys.extend([("occ_f1", "occ_f1", "valid_target")]) - if preds.get("mbs") is not None and targets.get("mbs") is not None: - mb_pred = self._fix_shape(preds["mbs"], batch_size) + if metric_preds.get("mbs") is not None and targets.get("mbs") is not None: + mb_pred = 
self._fix_shape(metric_preds["mbs"], batch_size) mb_target = self._fix_shape(targets["mbs"], batch_size) mb_f1 = self._f1_score(mb_pred, mb_target, mode=self.f1_mode) self.used_keys.extend([("mb_f1", "mb_f1", "valid_target")]) - if preds.get("confs") is not None: + if metric_preds.get("confs") is not None: conf_target = torch.exp( -torch.pow(flow_target - flow_pred, 2).sum(dim=1, keepdim=True) ) - conf_pred = self._fix_shape(preds["confs"], batch_size) + conf_pred = self._fix_shape(metric_preds["confs"], batch_size) conf_f1 = self._f1_score(conf_pred, conf_target, mode=self.f1_mode) self.used_keys.extend([("conf_f1", "conf_f1", "valid_target")]) @@ -380,6 +408,22 @@ def _fix_shape(self, tensor: torch.Tensor, batch_size: int) -> torch.Tensor: ) return tensor + def _to_bchw_shape(self, tensor) -> tuple[torch.Tensor, Sequence[int]]: + orig_shape = tensor.shape + if len(tensor.shape) == 2: + tensor = tensor[None, None] + elif len(tensor.shape) == 3: + tensor = tensor[None] + elif len(tensor.shape) > 4: + batch_size = int(np.prod(orig_shape[:-3])) + tensor = tensor.view( + batch_size, + tensor.shape[-3], + tensor.shape[-2], + tensor.shape[-1], + ) + return tensor, orig_shape + def _get_batch_size(self, flow_tensor: torch.Tensor) -> int: if len(flow_tensor.shape) < 4: return 1 diff --git a/tests/ptlflow/models/test_models.py b/tests/ptlflow/models/test_models.py index 54305aad..70d9bdfe 100644 --- a/tests/ptlflow/models/test_models.py +++ b/tests/ptlflow/models/test_models.py @@ -87,11 +87,10 @@ def test_forward() -> None: model = ptlflow.get_model(mname, args=args) model = model.eval() - s = make_divisible(256, model.output_stride) num_images = 2 if mname in ["videoflow_bof", "videoflow_mof"]: num_images = 3 - inputs = {"images": torch.rand(1, num_images, 3, s, s)} + inputs = {"images": torch.rand(1, num_images, 3, 256, 256)} if torch.cuda.is_available(): model = model.cuda() diff --git a/train.py b/train.py index fa57b5f1..67c076a8 100644 --- a/train.py +++ b/train.py 
@@ -29,6 +29,8 @@ def _init_parser(): parser = ArgumentParser(add_help=False) + parser.add_argument("--lr", type=float, default=None) + parser.add_argument("--wdecay", type=float, default=None) parser.add_argument("--ckpt_path", type=str, default=None) parser.add_argument("--project", type=str, default="ptlflow") parser.add_argument("--version", type=str, default=None) @@ -155,6 +157,8 @@ def cli_main(): cfg.trainer.logger = trainer_logger cfg.trainer.callbacks = callbacks + cfg.model.init_args.lr = cfg.lr + cfg.model.init_args.wdecay = cfg.wdecay cli = PTLFlowCLI( model_class=RegisteredModel, subclass_mode_model=True, diff --git a/validate.py b/validate.py index e2d1ad7f..6af52da6 100644 --- a/validate.py +++ b/validate.py @@ -236,6 +236,9 @@ def validate( if args.fp16: model = model.half() + if args.scale_factor is not None and args.scale_factor != 1.0: + model.metric_interpolate_pred_to_target_size = True + data_module.setup("validate") dataloaders = data_module.val_dataloader() dataloaders = {