2 changes: 2 additions & 0 deletions pystiche_papers/sanakoyeu_et_al_2018/_data.py
@@ -27,6 +27,7 @@
     SequentialNumIterationsBatchSampler,
 )
 
+from ..utils import OptionalGrayscaleToFakegrayscale
 from ._augmentation import (
     AugmentationBase2d,
     _adapted_uniform,
@@ -243,6 +244,7 @@ def image_transform(impl_params: bool = True, edge_size: int = 768) -> nn.Sequen
     transforms_: List[nn.Module] = [
         ClampSize() if impl_params else OptionalUpsample(edge_size),
     ]
+    transforms_.append(OptionalGrayscaleToFakegrayscale())
     # https://github.com/pmeier/adaptive-style-transfer/blob/07a3b3fcb2eeed2bf9a22a9de59c0aea7de44181/model.py#L286-L287
     # https://github.com/pmeier/adaptive-style-transfer/blob/07a3b3fcb2eeed2bf9a22a9de59c0aea7de44181/model.py#L291-L292
     # https://github.com/pmeier/adaptive-style-transfer/blob/07a3b3fcb2eeed2bf9a22a9de59c0aea7de44181/model.py#L271-L276
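The appended transform guards the content pipeline against single-channel images. Its definition is not part of this diff; a minimal sketch of the plausible semantics, repeating a grayscale channel so that three-channel modules downstream keep working, could look like the following. The class name and behavior are inferred from the identifier, not taken from the package:

```python
import torch
from torch import nn


class GrayscaleToFakegrayscale(nn.Module):
    """Hypothetical stand-in for OptionalGrayscaleToFakegrayscale."""

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # A "fake grayscale" image repeats the single luminance channel
        # three times so that RGB-only modules accept it.
        if image.size()[-3] == 1:
            return image.repeat_interleave(3, dim=-3)
        return image
```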
6 changes: 4 additions & 2 deletions pystiche_papers/sanakoyeu_et_al_2018/_loss.py
@@ -51,7 +51,7 @@ def forward(
             loss += self.prediction_loss(input_photo).aggregate(0)
             accuracies.append(self.prediction_loss.get_accuracy())
 
-        self.accuracy = torch.mean(torch.cat(accuracies))
+        self.accuracy = torch.mean(torch.stack(accuracies))
 
         return cast(torch.Tensor, loss)
 
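The `torch.cat` to `torch.stack` fix matters if `get_accuracy()` returns zero-dimensional tensors, a plausible reading since each operator reports a single scalar: `torch.cat` rejects zero-dimensional inputs, while `torch.stack` builds the 1-D tensor the subsequent `torch.mean` expects.

```python
import torch

accuracies = [torch.tensor(0.6), torch.tensor(0.8)]  # zero-dim scalars

# torch.cat(accuracies)  # RuntimeError: zero-dimensional tensors cannot be concatenated
torch.mean(torch.stack(accuracies))  # tensor(0.7000)
```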
@@ -176,6 +176,7 @@ def style_aware_content_loss(
 
 def transformer_loss(
     encoder: SequentialEncoder,
+    prediction_loss: Optional[MultiLayerPredictionOperator] = None,
     impl_params: bool = True,
     style_aware_content_kwargs: Optional[Dict[str, Any]] = None,
     transformed_image_kwargs: Optional[Dict[str, Any]] = None,
@@ -184,6 +185,7 @@ def transformer_loss(
 
     Args:
         encoder: :class:`~pystiche.enc.SequentialEncoder`.
+        prediction_loss: Trainable :class:`MultiLayerPredictionOperator`.
         impl_params: If ``True``, uses the parameters used in the reference
             implementation of the original authors rather than what is described in
             the paper.
@@ -214,5 +216,5 @@ def transformer_loss(
         )
     )
 
-    style_loss = style_loss_(impl_params)
+    style_loss = cast(ops.OperatorContainer, prediction_loss)
     return loss.PerceptualLoss(content_loss, style_loss)
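With this change `transformer_loss` no longer builds its own style operator: the discriminator that `DiscriminatorLoss` trains is passed in and doubles as the style loss, so both criteria share one set of weights. Note that the new parameter defaults to `None`, which the unconditional `cast` silently accepts; callers must actually pass the operator. A sketch of the resulting wiring (names follow the diff, and the `assert` relies only on the `style_loss` attribute this diff itself uses):

```python
# Sketch, not package API: names follow the diff.
prediction_operator = prediction_loss(impl_params=True)

discriminator_criterion = DiscriminatorLoss(prediction_operator)
transformer_criterion = transformer_loss(
    transformer.encoder, prediction_operator, impl_params=True
)

# Both criteria reference the *same* trainable module, so discriminator
# steps and transformer steps see each other's weight updates.
assert transformer_criterion.style_loss is prediction_operator
```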
48 changes: 41 additions & 7 deletions pystiche_papers/sanakoyeu_et_al_2018/_nst.py
@@ -1,3 +1,4 @@
+import time
 from typing import Callable, Optional, Tuple, Union, cast
 
 import torch
@@ -6,7 +7,7 @@
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
-from pystiche import loss, misc
+from pystiche import LossDict, loss, misc, optim
 from pystiche.image.transforms import functional as F
 
 from ._data import content_dataset, image_loader, style_dataset
@@ -85,6 +86,10 @@ def gan_optim_loop(
     style_transform = _maybe_extract_transform(style_image_loader)
 
     device = misc.get_device()
+
+    logger = optim.OptimLogger()
+    log_fn = optim.default_transformer_optim_log_fn(logger, len(content_image_loader))
+    quiet = False
     style_image_loader = iter(style_image_loader)
 
     if isinstance(content_transform, nn.Module):
@@ -106,26 +111,46 @@ def gan_optim_loop(
         "discriminator_success", init_val=target_win_rate
     )
 
+    def logging(
+        loss: LossDict, batch_size: int, loading_time: float, processing_time: float
+    ) -> None:
+        image_loading_velocity = batch_size / max(loading_time, 1e-6)
+        image_processing_velocity = batch_size / max(processing_time, 1e-6)
+        # See https://github.com/pmeier/pystiche/pull/264#discussion_r430205029
+        log_fn(batch, loss, image_loading_velocity, image_processing_velocity)
+
     def train_discriminator_one_step(
         output_image: torch.Tensor,
         style_image: torch.Tensor,
         input_image: Optional[torch.Tensor] = None,
     ) -> None:
         def closure() -> float:
+            processing_time_start = time.time()
+
             cast(Optimizer, discriminator_optimizer).zero_grad()
             loss = discriminator_criterion(output_image, style_image, input_image)
             loss.backward()
+
+            processing_time = time.time() - processing_time_start
+
+            logging(loss, output_image.size()[0], loading_time, processing_time)
             return cast(float, loss.item())
 
         cast(Optimizer, discriminator_optimizer).step(closure)
         discriminator_success.update(discriminator_criterion.accuracy)
 
     def train_transformer_one_step(output_image: torch.Tensor) -> None:
         def closure() -> float:
+            processing_time_start = time.time()
+
             cast(Optimizer, transformer_optimizer).zero_grad()
             cast(MultiLayerPredictionOperator, transformer_criterion.style_loss).real()
             loss = transformer_criterion(output_image)
             loss.backward()
+
+            processing_time = time.time() - processing_time_start
+
+            logging(loss, output_image.size()[0], loading_time, processing_time)
             return cast(float, loss.item())
 
         cast(Optimizer, transformer_optimizer).step(closure)
@@ -134,19 +159,24 @@ def closure() -> float:
         ).get_accuracy()
         discriminator_success.update(1.0 - accuracy)
 
-    for content_image in content_image_loader:
+    loading_time_start = time.time()
+    for batch, content_image in enumerate(content_image_loader, 1):
+        content_image = content_image.squeeze(1)
         input_image = content_image.to(device)
         if content_transform is not None:
             input_image = content_transform(input_image)
         input_image = preprocessor(input_image)
 
+        loading_time = time.time() - loading_time_start
+
         output_image = transformer(input_image)
 
-        if discriminator_success.local_avg < target_win_rate:
+        if discriminator_success.global_avg < target_win_rate:
             style_image = next(style_image_loader)
-            style_image = style_image.to(device)
+            style_image = style_image.to(device).squeeze(1)
             if style_transform is not None:
                 style_image = style_transform(style_image)
 
+            style_image = preprocessor(style_image)
 
             train_discriminator_one_step(
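The `local_avg` to `global_avg` switch changes which history decides whether the discriminator still needs training: the win rate over the whole run rather than over a recent window. A plain-Python illustration of the difference (the window size and numbers are made up; the meter API itself is pystiche's):

```python
win_rates = [0.95, 0.95, 0.95, 0.5]  # discriminator accuracy per step

local_avg = sum(win_rates[-2:]) / 2           # recent window only: 0.725
global_avg = sum(win_rates) / len(win_rates)  # entire history: 0.8375

target_win_rate = 0.8
# local_avg would retrain the discriminator after a single weak batch;
# global_avg keeps training the transformer instead.
assert local_avg < target_win_rate <= global_avg
```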
@@ -158,6 +188,8 @@ def closure() -> float:
         transformer_criterion_update_fn(input_image, transformer_criterion)
         train_transformer_one_step(output_image)
 
+        loading_time_start = time.time()
+
     return transformer
 
 
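Taken together, the loop now times two disjoint phases per batch: everything from the end of the previous step up to the forward pass counts as loading, the work inside the optimizer closure counts as processing, and `logging` (which closes over `batch` and `loading_time` from the loop) converts both into images-per-second velocities. A self-contained sketch of the same pattern:

```python
import time

import torch

loader = [torch.rand(4, 3, 32, 32) for _ in range(3)]  # dummy batches

loading_time_start = time.time()
for batch, images in enumerate(loader, 1):
    loading_time = time.time() - loading_time_start  # fetch + transform phase

    processing_time_start = time.time()
    _ = images.mean()  # stand-in for forward/backward/step
    processing_time = time.time() - processing_time_start

    batch_size = images.size()[0]
    # max(..., 1e-6) guards against division by zero on very fast steps.
    loading_velocity = batch_size / max(loading_time, 1e-6)
    processing_velocity = batch_size / max(processing_time, 1e-6)

    loading_time_start = time.time()  # restart the clock for the next batch
```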
@@ -298,15 +330,17 @@ def training(
     transformer = transformer.to(device)
 
     prediction_operator = prediction_loss(impl_params=impl_params)
+    # TODO: Change this in MultiLayerEncoder
+    prediction_operator.train()
+    prediction_operator.requires_grad_(True)
 
     discriminator_criterion = DiscriminatorLoss(prediction_operator)
-    discriminator_criterion = discriminator_criterion.eval()
     discriminator_criterion = discriminator_criterion.to(device)
 
     transformer_criterion = transformer_loss(
-        transformer.encoder, impl_params=impl_params
+        transformer.encoder, prediction_operator, impl_params=impl_params
     )
-    transformer_criterion = transformer_criterion.eval()
+    transformer_criterion = transformer_criterion.train()
     transformer_criterion = transformer_criterion.to(device)
 
     get_optimizer = optimizer
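Since the discriminator is now optimized jointly, the criteria must leave evaluation mode: `.train()` switches layers such as batch norm to training behavior, and `requires_grad_(True)` re-enables gradient tracking for parameters a multi-layer encoder may have frozen. The two toggles are independent, as this standalone PyTorch check shows:

```python
import torch
from torch import nn

discriminator = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

discriminator.eval()
discriminator.requires_grad_(False)  # the frozen state this diff moves away from

discriminator.train()                # affects only mode-dependent layers
discriminator.requires_grad_(True)   # affects only autograd bookkeeping

assert discriminator.training
assert all(p.requires_grad for p in discriminator.parameters())
```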
138 changes: 138 additions & 0 deletions replication/sanakoyeu_et_al_2018/main.py
@@ -0,0 +1,138 @@
+import os
+from argparse import Namespace
+from os import path
+
+import torch
+
+import pystiche_papers.sanakoyeu_et_al_2018 as paper
+from pystiche.image import write_image
+from pystiche_papers import utils
+
+
+def training(args):
+    contents = (
+        "garden",
+        "bridge_river",
+        "glacier_human",
+        "mountain",
+        "horses",
+        "stone_facade",
+        "waterway",
+        "garden_parc",
+    )
+
+    styles = (
+        "berthe-morisot",
+        "edvard-munch",
+        "el-greco",
+        "ernst-ludwig-kirchner",
+        "jackson-pollock",
+        "monet_water-lilies-1914",
+        "nicholas-roerich",
+        "pablo-picasso",
+        "paul-cezanne",
+        "samuel-peploe",
+        "vincent-van-gogh_road-with-cypresses-1890",
+        "wassily-kandinsky",
+    )
+
+    images = paper.images()
+    images.download(args.image_source_dir)
+
+    # content_dataset = paper.content_dataset(
+    #     path.join(args.dataset_dir, "content"), impl_params=args.impl_params
+    # )
+
+    content_dataset = paper.content_dataset(
+        "/home/julianbueltemeier/datasets/places365/data_large_standard",
+        impl_params=args.impl_params,
+    )
+
+    content_image_loader = paper.image_loader(
+        content_dataset,
+        impl_params=args.impl_params,
+        pin_memory=str(args.device).startswith("cuda"),
+    )
+
+    for style in styles:
+        style_dataset = paper.style_dataset(
+            path.join(args.dataset_dir, "style"), style, impl_params=args.impl_params
+        )
+        style_image_loader = paper.image_loader(
+            style_dataset,
+            impl_params=args.impl_params,
+            pin_memory=str(args.device).startswith("cuda"),
+        )
+
+        transformer = paper.training(
+            content_image_loader, style_image_loader, impl_params=args.impl_params
+        )
+
+        model_name = f"sanakoyeu_et_al_2018__{style}"
+        if args.impl_params:
+            model_name += "__impl_params"
+        utils.save_state_dict(transformer, model_name, root=args.model_dir)
+
+        for content in contents:
+            content_image = images[content].read(device=args.device)
+            output_image = paper.stylization(
+                content_image, transformer, impl_params=args.impl_params,
+            )
+
+            output_name = f"{style}_{content}"
+            if args.impl_params:
+                output_name += "__impl_params"
+            output_file = path.join(args.image_results_dir, f"{output_name}.jpg")
+            write_image(output_image, output_file)
+
+
+def parse_input():
+    # TODO: write CLI
+    image_source_dir = None
+    dataset_dir = None
+    image_results_dir = None
+    model_dir = None
+    device = None
+    impl_params = True
+
+    def process_dir(dir):
+        dir = path.abspath(path.expanduser(dir))
+        os.makedirs(dir, exist_ok=True)
+        return dir
+
+    here = path.dirname(__file__)
+
+    if image_source_dir is None:
+        image_source_dir = path.join(here, "data", "images", "source")
+    image_source_dir = process_dir(image_source_dir)
+
+    if dataset_dir is None:
+        dataset_dir = path.join(here, "data", "images", "dataset")
+    dataset_dir = process_dir(dataset_dir)
+
+    if image_results_dir is None:
+        image_results_dir = path.join(here, "data", "images", "results")
+    image_results_dir = process_dir(image_results_dir)
+
+    if model_dir is None:
+        model_dir = path.join(here, "data", "models")
+    model_dir = process_dir(model_dir)
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    return Namespace(
+        image_source_dir=image_source_dir,
+        dataset_dir=dataset_dir,
+        image_results_dir=image_results_dir,
+        model_dir=model_dir,
+        device=device,
+        impl_params=impl_params,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_input()
+    training(args)
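The `# TODO: write CLI` leaves all inputs hard-coded, including the personal places365 path above, which shadows the commented-out `dataset_dir`-based variant. A hypothetical argparse front-end that would feed the same `Namespace`; every flag name here is an assumption, not part of the PR:

```python
# Hypothetical CLI sketch for the TODO in parse_input(); flag names are assumptions.
from argparse import ArgumentParser


def make_parser() -> ArgumentParser:
    parser = ArgumentParser(description="Replicate Sanakoyeu et al. (2018)")
    parser.add_argument("--image-source-dir")
    parser.add_argument("--dataset-dir")
    parser.add_argument("--image-results-dir")
    parser.add_argument("--model-dir")
    parser.add_argument("--device")
    parser.add_argument("--no-impl-params", dest="impl_params", action="store_false")
    return parser
```

The `None` defaults of the omitted flags would then flow through the existing `if ... is None` fallbacks unchanged.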