Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lightning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0"]
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0", "2.5.0"]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ jobs:
fail-fast: false
matrix:
pytorch: [
'torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu',
]

steps:
Expand Down
15 changes: 6 additions & 9 deletions model_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@
from ptlflow.utils.lightning.ptlflow_cli import PTLFlowCLI
from ptlflow.utils.registry import RegisteredModel
from ptlflow.utils.timer import Timer
from ptlflow.utils.utils import (
count_parameters,
make_divisible,
)
from ptlflow.utils.utils import count_parameters

NUM_COMMON_COLUMNS = 6
TABLE_KEYS_LEGENDS = {
Expand Down Expand Up @@ -302,8 +299,8 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame:
1,
2,
3,
make_divisible(input_size[0], model.output_stride),
make_divisible(input_size[1], model.output_stride),
input_size[0],
input_size[1],
)
}

Expand Down Expand Up @@ -372,7 +369,7 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame:
)
except Exception as e: # noqa: B902
logger.warning(
"Skipping model %s with datatype %s due to exception %s",
"Skipping model {} with datatype {} due to exception {}",
mname,
dtype_str,
e,
Expand Down Expand Up @@ -440,8 +437,8 @@ def estimate_inference_time(
args.batch_size,
2,
3,
make_divisible(input_size[0], model.output_stride),
make_divisible(input_size[1], model.output_stride),
input_size[0],
input_size[1],
)
}
if torch.cuda.is_available():
Expand Down
3 changes: 1 addition & 2 deletions plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import argparse
import logging
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Union

import numpy as np
import pandas as pd
import plotly.express as px

Expand Down
5 changes: 4 additions & 1 deletion ptlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,17 @@ def load_checkpoint(ckpt_path: str, model_ref: BaseModel) -> Dict[str, Any]:
device = "cuda" if torch.cuda.is_available() else "cpu"

if Path(ckpt_path).exists():
ckpt = torch.load(ckpt_path, map_location=torch.device(device))
ckpt = torch.load(
ckpt_path, map_location=torch.device(device), weights_only=True
)
else:
model_dir = Path(hub.get_dir()) / "checkpoints"
ckpt = hub.load_state_dict_from_url(
ckpt_path,
model_dir=model_dir,
map_location=torch.device(device),
check_hash=True,
weights_only=True,
)
return ckpt

Expand Down
44 changes: 28 additions & 16 deletions ptlflow/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from loguru import logger
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from ptlflow.utils import flow_utils

Expand Down Expand Up @@ -1662,7 +1663,6 @@ def __init__( # noqa: C901
reverse_only: bool = False,
subsample: bool = False,
is_image_4k: bool = False,
image_4k_split_dir_suffix: str = "_4k",
) -> None:
"""Initialize SintelDataset.

Expand Down Expand Up @@ -1705,11 +1705,6 @@ def __init__( # noqa: C901
If False, and is_image_4k is True, then the groundtruth is returned in its original 4D-shaped 4K resolution, but the flow values are doubled.
is_image_4k : bool, default False
If True, assumes the input images will be provided in 4K resolution, instead of the original 2K.
image_4k_split_dir_suffix : str, default "_4k"
Only used when is_image_4k == True. It indicates the suffix to add to the split folder name where the 4k images are located.
For example, by default, the 4K images need to be located inside folders called "train_4k" and/or "test_4k".
The structure of these folders should be the same as the original "train" and "test".
The "*_4k" folders only need to contain the image directories, the groundtruth will still be loaded from the original locations.
"""
if isinstance(side_names, str):
side_names = [side_names]
Expand All @@ -1731,7 +1726,6 @@ def __init__( # noqa: C901
self.sequence_position = sequence_position
self.subsample = subsample
self.is_image_4k = is_image_4k
self.image_4k_split_dir_suffix = image_4k_split_dir_suffix

if self.is_image_4k:
assert not self.subsample
Expand All @@ -1758,17 +1752,9 @@ def __init__( # noqa: C901
for side in side_names:
for direcs in directions:
rev = direcs[0] == "BW"
img_split_dir_name = (
f"{split_dir}{self.image_4k_split_dir_suffix}"
if self.is_image_4k
else split_dir
)
image_paths = sorted(
(
Path(self.root_dir)
/ img_split_dir_name
/ seq_name
/ f"frame_{side}"
Path(self.root_dir) / split_dir / seq_name / f"frame_{side}"
).glob("*.png"),
reverse=rev,
)
Expand Down Expand Up @@ -1883,12 +1869,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901
if self.transform is not None:
inputs = self.transform(inputs)
elif self.is_image_4k:
inputs["images"] = [
cv.resize(img, None, fx=2, fy=2, interpolation=cv.INTER_CUBIC)
for img in inputs["images"]
]
if self.transform is not None:
inputs = self.transform(inputs)
if "flows" in inputs:
inputs["flows"] = 2 * inputs["flows"]
if self.get_backward:
inputs["flows_b"] = 2 * inputs["flows_b"]

process_keys = [("flows", "valids")]
if self.get_backward:
process_keys.append(("flows_b", "valids_b"))

for flow_key, valid_key in process_keys:
flow = inputs[flow_key]
flow_stack = rearrange(
flow, "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2
)
flow_stack4 = flow_stack.repeat(1, 4, 1, 1, 1)
flow_stack4 = rearrange(
flow_stack4, "b (m n) c h w -> b m n c h w", m=4
)
diff = flow_stack[:, :, None] - flow_stack4
diff = rearrange(diff, "b m n c h w -> b (m n) c h w")
diff = torch.sqrt(torch.pow(diff, 2).sum(2))
max_diff, _ = diff.max(1)
max_diff = F.interpolate(
max_diff[:, None], scale_factor=2, mode="nearest"
)
inputs[valid_key] = (max_diff < 1.0).float()
else:
if self.transform is not None:
inputs = self.transform(inputs)
Expand Down
8 changes: 8 additions & 0 deletions ptlflow/data/flow_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
tartanair_root_dir: Optional[str] = None,
spring_root_dir: Optional[str] = None,
kubric_root_dir: Optional[str] = None,
middlebury_st_root_dir: Optional[str] = None,
viper_root_dir: Optional[str] = None,
dataset_config_path: str = "./datasets.yaml",
):
super().__init__()
Expand All @@ -89,6 +91,8 @@ def __init__(
self.tartanair_root_dir = tartanair_root_dir
self.spring_root_dir = spring_root_dir
self.kubric_root_dir = kubric_root_dir
self.middlebury_st_root_dir = middlebury_st_root_dir
self.viper_root_dir = viper_root_dir
self.dataset_config_path = dataset_config_path

self.predict_dataset_parsed = None
Expand Down Expand Up @@ -935,6 +939,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
sequence_position = "first"
reverse_only = False
subsample = False
is_image_4k = False
side_names = []
fbocc_transform = False
for v in args:
Expand All @@ -952,6 +957,8 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
sequence_position = v.split("_")[1]
elif v == "sub":
subsample = True
elif v == "4k":
is_image_4k = True
elif v == "left":
side_names.append("left")
elif v == "right":
Expand Down Expand Up @@ -1012,6 +1019,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
sequence_position=sequence_position,
reverse_only=reverse_only,
subsample=subsample,
is_image_4k=is_image_4k,
)
return dataset

Expand Down
1 change: 1 addition & 0 deletions ptlflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .csflow import *
from .dicl import *
from .dip import *
from .dpflow import *
from .fastflownet import *
from .flow1d import *
from .flowformer import *
Expand Down
33 changes: 25 additions & 8 deletions ptlflow/models/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,25 @@ def __init__(
lr: Optional[float] = None,
wdecay: Optional[float] = None,
warm_start: bool = False,
metric_interpolate_pred_to_target_size: bool = False,
) -> None:
"""Initialize BaseModel.

Parameters
----------
args : Namespace
A namespace with the required arguments. Typically, this can be gotten from add_model_specific_args().
loss_fn : Callable
A function to be used to compute the loss for the training. The input of this function must match the output of the
forward() method. The output of this function must be a tensor with a single value.
output_stride : int
How many times the output of the network is smaller than the input.
loss_fn : Optional[Callable]
A function to be used to compute the loss for the training. The input of this function must match the output of the
forward() method. The output of this function must be a tensor with a single value.
lr : Optional[float]
The learning rate to be used for training the model. If not provided, it will be set as 1e-4.
wdecay : Optional[float]
The weight decay to be used for training the model. If not provided, it will be set as 1e-4.
warm_start : bool, default False
If True, use warm start to initialize the flow prediction. The warm_start strategy was presented by the RAFT method and forward interpolates the prediction from the last frame.
metric_interpolate_pred_to_target_size : bool, default False
If True, the prediction is bilinearly interpolated to match the target size during metric calculation, if their sizes are different.
"""
super(BaseModel, self).__init__()

Expand All @@ -89,13 +96,19 @@ def __init__(
self.lr = lr
self.wdecay = wdecay
self.warm_start = warm_start
self.metric_interpolate_pred_to_target_size = (
metric_interpolate_pred_to_target_size
)

self.train_size = None
self.train_avg_length = None

self.extra_params = None

self.train_metrics = FlowMetrics(prefix="train/")
self.train_metrics = FlowMetrics(
prefix="train/",
interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size,
)
self.val_metrics = nn.ModuleList()
self.val_dataset_names = []

Expand Down Expand Up @@ -132,6 +145,7 @@ def add_extra_param(self, name, value):
def preprocess_images(
self,
images: torch.Tensor,
stride: Optional[int] = None,
bgr_add: Union[float, Tuple[float, float, float], np.ndarray, torch.Tensor] = 0,
bgr_mult: Union[
float, Tuple[float, float, float], np.ndarray, torch.Tensor
Expand Down Expand Up @@ -201,7 +215,7 @@ def preprocess_images(
if bgr_to_rgb:
images = torch.flip(images, [-3])

stride = self.output_stride
stride = self.output_stride if stride is None else stride
if target_size is not None:
stride = None

Expand Down Expand Up @@ -371,7 +385,10 @@ def validation_step(
"""
if len(self.val_metrics) <= dataloader_idx:
self.val_metrics.append(
FlowMetrics(prefix="val/").to(device=batch["flows"].device)
FlowMetrics(
prefix="val/",
interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size,
).to(device=batch["flows"].device)
)
self.val_dataset_names.append(None)

Expand Down
Loading