-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
extends smp_model_factory class #56
Changes from 3 commits
fb68d27
52699a9
93c523e
e544105
7edbbf8
97b4688
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
from terratorch.models.smp_model_factory import SMPModelFactory | ||
from terratorch.models.timm_model_factory import TimmModelFactory | ||
|
||
from terratorch.models.smp_model_factory import get_smp_decoder | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont think we should expose this directly. Is there a good reason for importing this here? |
||
|
||
__all__ = ( | ||
"PrithviModelFactory", | ||
"ClayModelFactory", | ||
|
@@ -15,4 +17,5 @@ | |
"TimmModelFactory", | ||
"AuxiliaryHead", | ||
"AuxiliaryHeadWithDecoderWithoutInstantiatedHead", | ||
"get_smp_decoder", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above |
||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
) | ||
from terratorch.models.pixel_wise_model import PixelWiseModel | ||
from terratorch.models.scalar_output_model import ScalarOutputModel | ||
from terratorch.models.smp_model_factory import get_smp_decoder | ||
|
||
PIXEL_WISE_TASKS = ["segmentation", "regression"] | ||
SCALAR_TASKS = ["classification"] | ||
|
@@ -26,6 +27,7 @@ | |
class DecoderNotFoundError(Exception): | ||
pass | ||
|
||
|
||
@register_factory | ||
class PrithviModelFactory(ModelFactory): | ||
def build_model( | ||
|
@@ -34,7 +36,8 @@ def build_model( | |
backbone: str | nn.Module, | ||
decoder: str | nn.Module, | ||
bands: list[HLSBands | int], | ||
in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change | ||
in_channels: int | ||
| None = None, # this should be removed, can be derived from bands. But it is a breaking change | ||
num_classes: int | None = None, | ||
pretrained: bool = True, # noqa: FBT001, FBT002 | ||
num_frames: int = 1, | ||
|
@@ -95,7 +98,13 @@ def build_model( | |
msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}" | ||
raise NotImplementedError(msg) | ||
|
||
# These params are used in case we need a SMP decoder | ||
# but should not be used for timm encoder | ||
backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_") | ||
smp_kwargs, kwargs = _extract_prefix_keys(kwargs, "smp_") | ||
aux_kwargs, kwargs = _extract_prefix_keys(kwargs, "aux_") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make use of the args passed through |
||
output_stride = backbone_kwargs.pop("output_stride", None) | ||
out_channels = backbone_kwargs.pop("out_channels", None) | ||
|
||
backbone: nn.Module = timm.create_model( | ||
backbone, | ||
|
@@ -106,13 +115,26 @@ def build_model( | |
features_only=True, | ||
**backbone_kwargs, | ||
) | ||
# allow decoder to be a module passed directly | ||
decoder_cls = _get_decoder(decoder) | ||
|
||
decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_") | ||
|
||
args = kwargs.copy() | ||
# TODO: remove this | ||
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) | ||
if decoder.startswith("smp_"): | ||
decoder: nn.Module = get_smp_decoder( | ||
decoder, | ||
backbone_kwargs, | ||
smp_kwargs, | ||
aux_kwargs, | ||
args, | ||
out_channels, | ||
in_channels, | ||
num_classes, | ||
output_stride, | ||
) | ||
else: | ||
# allow decoder to be a module passed directly | ||
decoder_cls = _get_decoder(decoder) | ||
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) | ||
# decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) | ||
|
||
head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unsorted imports (ruff can sort imports)