-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
1820 lines (1540 loc) · 74.6 KB
/
train.py
File metadata and controls
1820 lines (1540 loc) · 74.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
3D segmentation model based on DINOv3 ViT-S/16 (per-slice 2D encoder + 3D aggregator/decoder).
Changes in this version:
- Train/Val lists: accepts --train_list and --val_list (fallback to --list_json if needed).
- Always load DINOv3 ViT-S/16 from Hugging Face (no timm fallback).
- Explicitly DROP 5 special tokens (CLS=1 + Register=4) from last_hidden_state.
- Hybrid loss: CE + Dice with class weighting, foreground-slice filtering, and
optional prior-bias initialization for the 1x1 head.
- Log CE, Dice, Topo, Total per train/eval.
- Save color (GT | Pred) image every 20 epochs to `vis_outputs/`.
Architecture:
Frozen Adapter ([0,1] min-max + ImageNet norm, 2ch (Image + Z-coord) -> 3ch, NO learnable params)
-> Frozen DINOv3 ViT-S/16 (2D encoder, outputs patch embeddings)
-> 3D Pyramidal Adapter (ConvNeXt-style, anisotropic DWConv)
-> Shared Parallel FFA Aggregator (Slice Self-Attn + Global Spatial Attn + FFN)
-> UNETR-Lite Decoder (1/16 -> 1/4 with skip connections)
-> Refine3DHead (DW (1,3,3) + GroupNorm + SiLU + SE + 1x1, 3D depth-gated)
-> HRHead2D (DW 3x3 + GroupNorm + SiLU + SE + 1x1, per-slice chunked)
Example:
python train.py \
--train_list /path/to/train.json \
--val_list /path/to/val.json \
--num_classes 16 \
--epochs 300 \
--batch_size 2 \
--accumulation_steps 1 \
--img_size 336 \
--lr 5e-3 \
--drop_empty \
--min_fg_frac 0.0 \
--bg_weight 0.05 \
--use_mfb \
--init_prior_bias \
--aug \
--use_3d \
--use_3d_unetr \
--depth_min_fg_frac 0.0005
"""
from __future__ import annotations
import argparse, json, math, os
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import nibabel as nib
from transformers import AutoModel
import imageio, random
from colorsys import hsv_to_rgb
from contextlib import nullcontext
import torch.utils.checkpoint as cp
import math
from datetime import datetime
# ImageNet per-channel RGB normalisation statistics, shaped [1, 3, 1, 1] so they
# broadcast against [N, 3, H, W] image batches.
IMAGENET_MEAN = torch.as_tensor((0.485, 0.456, 0.406), dtype=torch.float32).reshape(1, 3, 1, 1)
IMAGENET_STD = torch.as_tensor((0.229, 0.224, 0.225), dtype=torch.float32).reshape(1, 3, 1, 1)
def set_seed(seed: int = 42):
    """Seed every RNG the training script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG, the
    PyTorch CPU generator and all CUDA device generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
def parse_img_size(s: str) -> Tuple[int, int]:
    """Parse an ``--img_size`` CLI value into an ``(height, width)`` pair.

    A single integer means a square size; two integers may be separated by
    ``x`` / ``X`` or ``,`` (surrounding whitespace is ignored)::

        '224'     -> (224, 224)
        '512x256' -> (512, 256)
        '512,256' -> (512, 256)

    Raises:
        ValueError: if the value does not consist of one or two integers, or
            if any dimension is not positive. The message quotes the
            argument exactly as the user passed it.
    """
    normalized = s.strip().lower().replace('x', ',')
    parts = [p for p in normalized.split(',') if p.strip()]
    if len(parts) not in (1, 2):
        raise ValueError(f"Bad --img_size: {s!r}")
    try:
        dims = [int(p) for p in parts]
    except ValueError as err:
        # Re-raise with the original argument instead of int()'s generic message.
        raise ValueError(f"Bad --img_size: {s!r}") from err
    if any(d <= 0 for d in dims):
        raise ValueError(f"Bad --img_size: {s!r} (dimensions must be positive)")
    h, w = (dims[0], dims[0]) if len(dims) == 1 else (dims[0], dims[1])
    return h, w
class DropPath(nn.Module):
    """Stochastic Depth: randomly drop an entire residual branch per sample.

    In training mode each sample of the batch is zeroed with probability
    ``drop_prob``; survivors are rescaled by ``1 / (1 - drop_prob)`` so the
    expected output equals the input. In eval mode, or when ``drop_prob`` is
    0, the input is returned unchanged.

    Ref: https://arxiv.org/abs/1603.09382
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        drop_prob = float(drop_prob)
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if (not self.training) or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims
        # (e.g. [B, 1, 1, 1, 1] for a [B, C, D, H, W] input). Sampled in fp32 so
        # low-precision inputs (bf16/fp16) do not distort the keep probability.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, device=x.device, dtype=torch.float32).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Skip the rescale when everything is dropped: x / 0 * 0 would be NaN.
            mask.div_(keep_prob)
        return x * mask.to(x.dtype)

    def extra_repr(self) -> str:
        return f"drop_prob={self.drop_prob}"
def _rbool(p: float) -> bool:
return bool(torch.rand(1) < p)
def _runif(a:float, b:float) -> float:
return float(torch.empty(1).uniform_(a, b))
def make_palette(num_classes: int) -> np.ndarray:
    """Build a bright, high-contrast ``[num_classes, 3]`` uint8 RGB palette.

    Index 0 (background) is black. Classes 1..K-1 get hues evenly spaced over
    [0, 1) in HSV with saturation 0.85 and value 0.95. At least one row is
    always returned, even for ``num_classes <= 0``.
    """
    n = max(1, int(num_classes))
    denom = max(1, n - 1)
    colors = [(0, 0, 0)]
    for idx in range(1, n):
        rgb = hsv_to_rgb((idx - 1) / denom, 0.85, 0.95)
        colors.append(tuple(int(channel * 255) for channel in rgb))
    return np.asarray(colors, dtype=np.uint8)
def colorize(label_2d: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map an integer label image to colours by palette lookup.

    Labels are cast to int64 and clamped into ``[0, len(palette) - 1]``, so
    out-of-range ids never index outside the palette. The result has shape
    ``label_2d.shape + palette.shape[1:]`` and the palette's dtype.
    """
    max_index = len(palette) - 1
    indices = np.clip(label_2d.astype(np.int64), 0, max_index)
    return np.take(palette, indices, axis=0)
def collate_pad_3d(batch):
    """Zero-pad a list of variable-depth volumes to a common depth.

    Args:
        batch: sequence of ``(x, y)`` pairs, ``x`` shaped ``[D_i, C, H, W]`` and
            ``y`` shaped ``[D_i, H, W]``; ``C``, ``H``, ``W`` come from the first
            sample.

    Returns:
        ``(X, Y, M)``:
          * ``X``: ``[B, D_max, C, H, W]``, dtype/device of the first ``x``;
          * ``Y``: ``[B, D_max, H, W]``, dtype of the first ``y``;
          * ``M``: bool ``[B, D_max, H, W]``, True on real slices, False on padding.
    """
    depth = max(sample_x.shape[0] for sample_x, _ in batch)
    first_x, first_y = batch[0][0], batch[0][1]
    n = len(batch)
    channels, height, width = first_x.shape[1:]
    X = first_x.new_zeros((n, depth, channels, height, width))
    Y = torch.zeros((n, depth, height, width), dtype=first_y.dtype)
    M = torch.zeros((n, depth, height, width), dtype=torch.bool)
    for idx, (sample_x, sample_y) in enumerate(batch):
        d = sample_x.shape[0]
        X[idx, :d] = sample_x
        Y[idx, :d] = sample_y
        M[idx, :d] = True
    return X, Y, M
def make_collate_pad_3d(depth_min_fg_frac: float | None = None):
    """Build a padding collate_fn that can also mask out low-foreground slices.

    The returned callable pads a batch with ``collate_pad_3d``. When
    ``depth_min_fg_frac`` is a positive number, every slice whose fraction of
    label pixels ``> 0`` (over ``H * W``) is below that threshold is removed
    from the validity mask ``M``. ``X`` and ``Y`` are never modified.
    """
    threshold = None if depth_min_fg_frac is None else float(depth_min_fg_frac)

    def _collate(batch):
        X, Y, M = collate_pad_3d(batch)
        if threshold is None or threshold <= 0.0:
            return X, Y, M
        height, width = X.shape[-2:]
        # Foreground fraction per (sample, slice): [B, D_max]. Padded slices
        # have Y == 0 and are already False in M, so AND-ing them is a no-op.
        fg_frac = (Y > 0).flatten(2).float().sum(dim=2) / float(height * width)
        keep = fg_frac >= threshold
        M &= keep[:, :, None, None]
        return X, Y, M

    return _collate
class FrozenAdapter(nn.Module):
    """Parameter-free adapter turning a 2-channel slice into a 3-channel ViT input.

    Input ``[B, 2, H, W]``:
      * channel 0: image intensities, min-max normalised per sample to [0, 1];
      * channel 1: Z-coordinate map (from the dataset), passed through unchanged.
    Output ``[B, 3, H, W]``: ``(R, G, B) = (image, image, z)``, then ImageNet
    mean/std normalisation.
    """

    def __init__(self):
        super().__init__()
        # Clone so every instance owns its buffers. Registering the module-level
        # tensors directly would alias them, and any in-place buffer update
        # (e.g. load_state_dict's copy_) would silently mutate the globals.
        self.register_buffer('mean', IMAGENET_MEAN.clone())
        self.register_buffer('std', IMAGENET_STD.clone())

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4 or x.shape[1] != 2:
            raise ValueError(f"Expected 2-channel input (Image + Z), got shape {tuple(x.shape)}")
        img = x[:, 0:1, :, :]    # [B, 1, H, W]
        z_map = x[:, 1:2, :, :]  # [B, 1, H, W]
        x_min = img.amin(dim=(2, 3), keepdim=True)
        x_max = img.amax(dim=(2, 3), keepdim=True)
        # The epsilon keeps constant slices finite (they map to 0).
        img01 = (img - x_min) / (x_max - x_min + 1e-6)
        x3 = torch.cat([img01, img01, z_map], dim=1)  # [B, 3, H, W] as (R=img, G=img, B=z)
        return (x3 - self.mean) / self.std
# 3D Pyramidal Adapter (ConvNeXt)
class AdapterPyramid3DConvNeXt(nn.Module):
    """Two-stage 3D ConvNeXt pyramid computed directly on the input volume.

    input:  x = [B, D, in_ch, H, W]  (SegModel3D_UNETRLite uses in_ch=2: Image + Z)
    output: F2_3d = [B, c2, D, ~H/2, ~W/2], F3_3d = [B, c3, D, ~H/4, ~W/4]

    Each stage has two ConvNeXtAniso3DBlocks. The first block of a stage
    downsamples in-plane with stride (1, 2, 2); depth is never strided. The
    depthwise convs use the anisotropic kernel (3, 7, 7).
    """

    def __init__(self, c2: int = 48, c3: int = 64, in_ch: int = 1, drop_path: float = 0.05):
        super().__init__()
        self.in_ch = int(in_ch)
        # Stage 2 (1/2)
        self.s2_down = ConvNeXtAniso3DBlock(in_ch, c2, kernel=(3, 7, 7), stride=(1, 2, 2),
                                            drop_path=drop_path, use_depth_branch=True)
        self.s2_blk = ConvNeXtAniso3DBlock(c2, c2, kernel=(3, 7, 7), stride=(1, 1, 1),
                                           drop_path=drop_path, use_depth_branch=True)
        # Stage 3 (1/4)
        self.s3_down = ConvNeXtAniso3DBlock(c2, c3, kernel=(3, 7, 7), stride=(1, 2, 2),
                                            drop_path=drop_path, use_depth_branch=True)
        self.s3_blk = ConvNeXtAniso3DBlock(c3, c3, kernel=(3, 7, 7), stride=(1, 1, 1),
                                           drop_path=drop_path, use_depth_branch=True)

    def forward(self, x_bd2hw: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: [B, D, C, H, W] -> [B, C, D, H, W] (channels-first for Conv3d)
        B, D, C, H, W = x_bd2hw.shape
        # Check against the configured channel count rather than a hard-coded 2,
        # so the module also works with its default in_ch.
        if C != self.in_ch:
            raise ValueError(f"Expected {self.in_ch} input channels, got {C}")
        x = x_bd2hw.permute(0, 2, 1, 3, 4).contiguous()
        f2 = self.s2_blk(self.s2_down(x))   # [B, c2, D, H/2, W/2]
        f3 = self.s3_blk(self.s3_down(f2))  # [B, c3, D, H/4, W/4]
        return f2, f3
class Refine3DHead(nn.Module):
    """Refinement head over concatenated coarse logits and skip features.

    input:  [B, in_ch, D, H', W']  (SegModel3D_UNETRLite feeds K logits + 64 F2 channels at 1/2 res)
    output: [B, num_classes, D, H', W']

    A shared (1,3,3) stem feeds two branches:
      * an in-plane (1,3,3) branch;
      * a shallow anisotropic 3D branch ((3,1,1) then (1,3,3)).
    The 3D branch is scaled per voxel by a sigmoid gate predicted from the raw
    input before the branches are summed. Then optional SE and a 1x1x1
    classifier. ``mid`` must be divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, in_ch: int, num_classes: int, mid: int = 128, use_se: bool = True):
        super().__init__()
        in_plane = dict(kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
        # Shared stem
        self.conv_in = nn.Conv3d(in_ch, mid, **in_plane)
        self.gn_in = nn.GroupNorm(8, mid)
        self.act = nn.SiLU(inplace=False)
        # In-plane (2D) branch
        self.conv2d = nn.Conv3d(mid, mid, **in_plane)
        self.gn2d = nn.GroupNorm(8, mid)
        # Shallow depth-mixing (3D) branch
        self.conv3d_z = nn.Conv3d(mid, mid, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
        self.conv3d_hw = nn.Conv3d(mid, mid, **in_plane)
        self.gn3d = nn.GroupNorm(8, mid)
        # Per-voxel gate on the 3D branch, predicted from the raw input
        self.depth_gate = nn.Conv3d(in_ch, 1, kernel_size=1)
        self.se = SEBlock3D(mid) if use_se else nn.Identity()
        self.proj = nn.Conv3d(mid, num_classes, kernel_size=1)

    def forward(self, x_cat: torch.Tensor) -> torch.Tensor:
        gate = self.depth_gate(x_cat).sigmoid()  # [B, 1, D, H', W']
        stem = self.act(self.gn_in(self.conv_in(x_cat)))
        planar = self.act(self.gn2d(self.conv2d(stem)))
        volumetric = self.act(self.gn3d(self.conv3d_hw(self.conv3d_z(stem))))
        fused = self.se(planar + gate * volumetric)
        return self.proj(fused)  # [B, K, D, H', W']
class FrozenDINOv3ViTS16(nn.Module):
    """Frozen DINOv3 ViT-S/16 backbone (Hugging Face) returning patch-token grids.

    Special tokens (1 CLS + the register tokens) are dropped from each hidden
    state. The next ``Gh * Gw`` patch tokens are reshaped to
    ``[B, hidden_dim, Gh, Gw]``, where ``Gh = H // patch`` and ``Gw = W // patch``.
    Backbone parameters never require grad. The ViT is also pinned to eval
    mode even when a parent module calls ``.train()``.
    """

    def __init__(self, hf_id: str = "facebook/dinov3-vits16-pretrain-lvd1689m", default_layers: Tuple[int, ...] = (2, 5, 8, 11)):
        super().__init__()
        # NOTE(review): trust_remote_code=True executes code shipped with the model
        # repository; only point hf_id at trusted checkpoints.
        self.vit = AutoModel.from_pretrained(hf_id, trust_remote_code=True)
        for p in self.vit.parameters():
            p.requires_grad = False
        self.vit.eval()
        cfg = self.vit.config
        self.hidden_dim = int(getattr(cfg, 'hidden_size', 384))
        self.patch = int(getattr(cfg, 'patch_size', 16))
        # CLS(1) + register tokens. Read the register count from the config when
        # present; fall back to the 4 registers of DINOv3 ViT-S/16.
        self.num_special_tokens = 1 + int(getattr(cfg, 'num_register_tokens', 4))
        self.default_layers = tuple(int(i) for i in default_layers)

    def train(self, mode: bool = True):
        """Set the wrapper's mode but keep the frozen ViT in eval mode.

        Without this override, a parent's ``.train()`` would recursively switch
        the backbone to training mode despite it being frozen.
        """
        super().train(mode)
        self.vit.eval()
        return self

    def _tokens_to_grid(self, tokens: torch.Tensor, Gh: int, Gw: int) -> torch.Tensor:
        """Drop special tokens and reshape ``[B, seq, C]`` patch tokens to ``[B, C, Gh, Gw]``."""
        start = self.num_special_tokens
        patch_tokens = tokens[:, start:start + Gh * Gw, :]
        return patch_tokens.transpose(1, 2).contiguous().view(tokens.shape[0], self.hidden_dim, Gh, Gw)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the final-layer patch grid ``[B, hidden_dim, H // patch, W // patch]``."""
        _, _, H, W = x.shape
        outputs = self.vit(pixel_values=x, return_dict=True)
        return self._tokens_to_grid(outputs.last_hidden_state, H // self.patch, W // self.patch)

    @torch.no_grad()
    def forward_multi(self, x: torch.Tensor, layers: Tuple[int, ...] | None = None) -> List[torch.Tensor]:
        """Return patch-token grids from the requested Transformer blocks.

        Args:
            x: ``[B, 3, H, W]`` with H and W divisible by the patch size.
            layers: block indices into ``hidden_states`` (``default_layers`` when
                None). Negative values count from the last block (``-1`` = last).
                Results are clamped to the valid range ``[1, L_total]``.

        Returns:
            One ``[B, hidden_dim, H // patch, W // patch]`` tensor per entry of ``layers``.
        """
        _, _, H, W = x.shape
        assert H % self.patch == 0 and W % self.patch == 0
        Gh, Gw = H // self.patch, W // self.patch
        outputs = self.vit(pixel_values=x, output_hidden_states=True, return_dict=True)
        hiddens = outputs.hidden_states  # [0] = embeddings, [1..L_total] = output of each block
        L_total = len(hiddens) - 1
        layers = tuple(self.default_layers if layers is None else layers)
        grids: List[torch.Tensor] = []
        for li in layers:
            li = int(li)
            idx = L_total + 1 + li if li < 0 else li
            idx = max(1, min(idx, L_total))
            grids.append(self._tokens_to_grid(hiddens[idx], Gh, Gw))
        return grids
class SEBlock3D(nn.Module):
    """Squeeze-and-Excitation channel gating for ``[B, C, D, H, W]`` tensors.

    Global average pool -> 1x1x1 conv to ``max(8, channels // reduction)``
    -> SiLU -> 1x1x1 conv back to ``channels`` -> sigmoid. The input is then
    scaled channel-wise by the result.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        squeezed = max(8, channels // reduction)
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Conv3d(channels, squeezed, 1),
            nn.SiLU(inplace=True),
            nn.Conv3d(squeezed, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_weights = self.fc(self.avg(x))
        return x * channel_weights
class SEBlock2D(nn.Module):
    """Squeeze-and-Excitation channel gating for ``[B, C, H, W]`` tensors.

    Global average pool -> 1x1 conv to ``max(8, channels // reduction)``
    -> SiLU -> 1x1 conv back to ``channels`` -> sigmoid. The input is then
    scaled channel-wise by the result.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        squeezed = max(8, channels // reduction)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, squeezed, 1),
            nn.SiLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_weights = self.fc(self.avg(x))
        return x * channel_weights
class HRHead2D(nn.Module):
    """Full-resolution 2D refinement head applied to a batch of slices.

    input:  [N, in_ch, H, W]  (SegModel3D_UNETRLite passes K logits + 16 reduced F2 channels)
    output: [N, num_classes, H, W]

    Pipeline: 3x3 conv -> depthwise 3x3 -> pointwise 1x1, each followed by
    GroupNorm + SiLU, then optional SE and a 1x1 classifier. ``mid`` must be
    divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, in_ch: int, num_classes: int, mid: int = 64, use_se: bool = True):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, mid, kernel_size=3, padding=1, bias=False)
        self.gn_in = nn.GroupNorm(8, mid)
        self.act = nn.SiLU(inplace=False)
        self.dw = nn.Conv2d(mid, mid, kernel_size=3, padding=1, groups=mid, bias=False)
        self.gn_dw = nn.GroupNorm(8, mid)
        self.pw = nn.Conv2d(mid, mid, kernel_size=1, bias=False)
        self.gn_pw = nn.GroupNorm(8, mid)
        self.se = SEBlock2D(mid) if use_se else nn.Identity()
        self.cls = nn.Conv2d(mid, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stages = ((self.conv_in, self.gn_in), (self.dw, self.gn_dw), (self.pw, self.gn_pw))
        feat = x
        for conv, norm in stages:
            feat = self.act(norm(conv(feat)))
        return self.cls(self.se(feat))
class UpBlock3DSkip(nn.Module):
    """Decoder stage: 2x in-plane upsampling fused with a skip feature.

    ``x`` is trilinearly upsampled by (1, 2, 2), keeping depth, and projected to
    ``c_out``. ``skip`` is resized to the same (D, H, W) and projected to
    ``c_out``. The two are concatenated and passed through two (1,3,3) conv +
    GroupNorm + SiLU layers, then optional SE. ``c_out`` must be divisible by 8.
    """

    def __init__(self, c_in: int, c_skip: int, c_out: int, use_se: bool = True):
        super().__init__()
        self.proj_in = nn.Conv3d(c_in, c_out, kernel_size=1, bias=False)
        self.proj_skip = nn.Conv3d(c_skip, c_out, kernel_size=1, bias=False)
        self.conv1 = nn.Conv3d(2 * c_out, c_out, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.gn1 = nn.GroupNorm(8, c_out)
        self.conv2 = nn.Conv3d(c_out, c_out, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.gn2 = nn.GroupNorm(8, c_out)
        self.act = nn.SiLU(inplace=False)
        self.se = SEBlock3D(c_out) if use_se else nn.Identity()

    def forward(self, x, skip):
        upsampled = F.interpolate(x, scale_factor=(1, 2, 2), mode="trilinear", align_corners=False)
        up = self.proj_in(upsampled)
        resized_skip = F.interpolate(skip, size=up.shape[-3:], mode="trilinear", align_corners=False)
        fused = torch.cat([up, self.proj_skip(resized_skip)], dim=1)
        fused = self.act(self.gn1(self.conv1(fused)))
        fused = self.act(self.gn2(self.conv2(fused)))
        return self.se(fused)
class ConvNeXtAniso3DBlock(nn.Module):
    """ConvNeXt-style 3D block with an anisotropic depthwise convolution.

        y = proj(x) + DropPath(gamma * PW2(GELU(PW1(LN(DW(x) [+ DW_z(x)])))))

    - ``dw``: depthwise Conv3d with ``kernel`` (default (3, 7, 7)) and ``stride``.
    - ``dw_z`` (if ``use_depth_branch``): extra depthwise (3, 1, 1) conv with the
      same stride, summed with ``dw`` for cheap depth mixing.
    - LayerNorm is applied channels-last. The MLP is two 1x1x1 convs with
      expansion ``mlp_ratio``.
    - stride (1, 2, 2) downsamples in-plane. When channels or stride change, the
      residual goes through a (strided) 1x1x1 conv; otherwise it is the identity.
    - ``gamma`` is a per-channel LayerScale initialised to ``layer_scale_init``.

    ``kernel`` and ``stride`` may be given as ints or as 3-element tuples/lists.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel: Tuple[int, int, int] = (3, 7, 7),
        stride: Tuple[int, int, int] = (1, 1, 1),
        mlp_ratio: float = 4.0,
        drop_path: float = 0.05,
        layer_scale_init: float = 1e-6,
        use_depth_branch: bool = True,
    ):
        super().__init__()

        def _triple(v):
            return tuple(int(i) for i in v) if isinstance(v, (tuple, list)) else (int(v),) * 3

        # Normalise to 3-tuples. In particular, the ``stride != (1, 1, 1)`` check
        # below would always be True for a list such as [1, 1, 1].
        kernel = _triple(kernel)
        stride = _triple(stride)
        pad = (kernel[0] // 2, kernel[1] // 2, kernel[2] // 2)
        # Depthwise spatial conv
        self.dw = nn.Conv3d(in_ch, in_ch, kernel, stride=stride, padding=pad, groups=in_ch, bias=True)
        self.use_depth_branch = bool(use_depth_branch)
        if self.use_depth_branch:
            kz = (3, 1, 1)
            self.dw_z = nn.Conv3d(in_ch, in_ch, kz, stride=stride,
                                  padding=(kz[0] // 2, 0, 0), groups=in_ch, bias=True)
        # Channels-last LayerNorm
        self.norm = nn.LayerNorm(in_ch, eps=1e-6)
        # Pointwise MLP (1x1x1 conv -> GELU -> 1x1x1 conv)
        hidden = int(in_ch * mlp_ratio)
        self.pw1 = nn.Conv3d(in_ch, hidden, kernel_size=1, bias=True)
        self.act = nn.GELU()
        self.pw2 = nn.Conv3d(hidden, out_ch, kernel_size=1, bias=True)
        # Residual path
        self.need_proj = (in_ch != out_ch) or (stride != (1, 1, 1))
        self.proj = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=stride, bias=True) if self.need_proj else nn.Identity()
        # [C, 1, 1, 1] broadcasts over the trailing [C, D, H, W] dims of [B, C, D, H, W]
        self.gamma = nn.Parameter(torch.ones((out_ch, 1, 1, 1)) * layer_scale_init)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Depthwise (+ depth-mixing branch)
        y = self.dw(x)
        if self.use_depth_branch:
            y = y + self.dw_z(x)
        # Channels-last LayerNorm
        y = y.permute(0, 2, 3, 4, 1)  # [B, D, H, W, C]
        y = self.norm(y)
        y = y.permute(0, 4, 1, 2, 3).contiguous()  # [B, C, D, H, W]
        # Pointwise MLP
        y = self.pw2(self.act(self.pw1(y)))
        # LayerScale + DropPath + residual
        return self.drop_path(self.gamma * y) + self.proj(x)
class SegDecoder3D_UNETRLite(nn.Module):
    """UNETR-Lite 3D decoder that upsamples a bottleneck in-plane by ``up_factor_hw``.

    - ``bottom``: ``[B, c_in, D, Gh, Gw]``; ``skip_list``: tensors
      ``[B, c_skip, D, h_i, w_i]``, each resized inside its stage to that
      stage's resolution.
    - ``log2(up_factor_hw)`` UpBlock3DSkip stages are built, each doubling H and
      W with depth preserved (2 stages for 4, 3 stages for 8).
    - Skips are consumed deepest-first: the last ``num_stages`` entries are
      used, and the last entry feeds the first (lowest-resolution) stage. If
      fewer skips than stages are given, the last skip is repeated.
    - Output: logits ``[B, num_classes, D, Gh * up_factor_hw, Gw * up_factor_hw]``.

    Raises:
        ValueError: if ``up_factor_hw`` is not a power of two >= 2, or (in
            ``forward``) if ``skip_list`` is empty.
    """

    def __init__(self, c_in: int, c_skip: int, num_classes: int,
                 up_factor_hw: int = 4, base_channels: int = 128, use_se: bool = True):
        super().__init__()
        factor = int(up_factor_hw)
        # Each stage upsamples by exactly 2, so only powers of two are reachable.
        # Anything else would silently produce the wrong output size.
        if factor < 2 or factor & (factor - 1):
            raise ValueError(f"up_factor_hw must be a power of two >= 2, got {up_factor_hw}")
        self.num_stages = factor.bit_length() - 1
        self.stem = nn.Sequential(
            nn.Conv3d(c_in, base_channels, kernel_size=1, bias=False),
            nn.GroupNorm(8, base_channels),
            nn.SiLU(inplace=False),
        )
        self.blocks = nn.ModuleList(
            [UpBlock3DSkip(base_channels, c_skip, base_channels, use_se=use_se)
             for _ in range(self.num_stages)]
        )
        self.refine = nn.Sequential(
            nn.Conv3d(base_channels, base_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
            nn.GroupNorm(8, base_channels),
            nn.SiLU(inplace=False),
        )
        self.classifier = nn.Conv3d(base_channels, num_classes, kernel_size=1)

    def forward(self, bottom, skip_list: List[torch.Tensor]):
        skips = list(skip_list)
        if not skips:
            raise ValueError("SegDecoder3D_UNETRLite.forward needs at least one skip feature")
        if len(skips) >= self.num_stages:
            skips = skips[-self.num_stages:]  # keep the deepest ones
        else:
            # Too few skips: repeat the last one
            skips += [skips[-1]] * (self.num_stages - len(skips))
        h = self.stem(bottom)
        # Low resolution -> high resolution
        for blk, sk in zip(self.blocks, reversed(skips)):
            h = blk(h, sk)
        return self.classifier(self.refine(h))  # [B, K, D, H, W]
# SegModel3D_UNETRLite (3D ConvNeXt Adapter + 3DGate + Refine3D)
class SegModel3D_UNETRLite(nn.Module):
    """Hybrid frozen-2D-ViT / trainable-3D-CNN segmentation model.

    x: [B, D, 2, H, W]  (channel 0 = image, channel 1 = Z-coordinate map)
      * AdapterPyramid3DConvNeXt on the raw volume -> F2 @ 1/2, F3 @ 1/4
      * Frozen DINOv3 ViT applied per slice -> multi-layer patch grids @ 1/patch
      -> Shared parallel FFA aggregator over the depth-stacked grids
      -> Gated fusion of skip ``skips[-2]`` (layer 5 with the default layers)
         with the projected F3 features at 1/4
      -> UNETR-Lite decoder: bottom = last layer, skips = [fused, skips[-1]]
      -> upsample to 1/2 and Refine3DHead with projected F2 features
      -> upsample to 1/1 and HRHead2D per slice, chunked over depth
    Returns logits [B, num_classes, D, H, W] with the default decoder_up_factor=4
    (for H, W divisible by the ViT patch size).
    """

    def __init__(self, num_classes: int,
                 vit_layers: Tuple[int, ...] = (2, 5, 8, 11),
                 decoder_base_channels: int = 128, decoder_up_factor: int = 4,
                 vit_chunk_slices: int = 8, vit_amp: bool = True, use_se: bool = True):
        super().__init__()
        # Frozen per-slice adapter for the ViT encoder (min-max + ImageNet norm, 2ch -> 3ch)
        self.adapter_slice = FrozenAdapter()
        self.adapter_slice.requires_grad_(False)
        self.encoder = FrozenDINOv3ViTS16(default_layers=vit_layers)
        self.encoder.requires_grad_(False)
        C = int(self.encoder.hidden_dim)
        self.patch = int(self.encoder.patch)
        self.selected_layers = tuple(vit_layers)
        # 3D pyramidal adapter (ConvNeXt) on the raw 2-channel volume
        self.adapter3d = AdapterPyramid3DConvNeXt(c2=48, c3=64, in_ch=2, drop_path=0.05)
        self.f2_proj_ref = nn.Conv3d(48, 64, kernel_size=1, bias=False)  # F2 -> 64 (refine head)
        self.f3_proj_C = nn.Conv3d(64, C, kernel_size=1, bias=False)     # F3 -> C (ViT hidden size)
        self.f3_gate3d = nn.Conv3d(C + C, 1, kernel_size=1)              # 3D gate sigmoid([A5_up, F3pC])
        # Shared parallel axial aggregator over all selected ViT layers
        self.shared_agg = ParallelAggregatorSharedFFA(
            c=C, num_layers=len(vit_layers),
            n_blocks=2,
            heads=6,
            attn_dim=C // 2,
            kv_down=2,
            dropout=0.0,
            drop_path=0.05,
            use_rope=True,
            use_pos_slice=True,
        )
        # Decoder (UNETR-Lite, 1/16 -> 1/4 with the default up factor)
        self.dec = SegDecoder3D_UNETRLite(
            c_in=C, c_skip=C, num_classes=num_classes,
            up_factor_hw=decoder_up_factor, base_channels=decoder_base_channels, use_se=use_se,
        )
        # Refine head at 1/2: K logits + 64 projected F2 channels
        self.head = Refine3DHead(in_ch=(num_classes + 64), num_classes=num_classes,
                                 mid=decoder_base_channels, use_se=True)
        # Full-resolution per-slice head: K logits + 16 reduced F2 channels
        self.hr2d = HRHead2D(in_ch=(num_classes + 16), num_classes=num_classes,
                             mid=decoder_base_channels, use_se=True)
        self.f2_hr_reduce3d = nn.Conv3d(64, 16, kernel_size=1, bias=False)  # F2p -> 16 for the HR head
        self.hr2d_chunk = 16                           # depth slices per HRHead2D call
        self.vit_chunk_slices = int(vit_chunk_slices)  # slices per ViT forward
        self.vit_amp = bool(vit_amp)

    def trainable_parameters(self):
        """Return the parameters of every trainable submodule as a list.

        Includes everything except the frozen slice adapter and ViT encoder.
        ``f2_hr_reduce3d`` is used by the HR head in ``forward``. Leaving it out
        would mean an optimiser built from this list never updates it.
        """
        modules = (
            self.shared_agg, self.dec, self.adapter3d,
            self.f3_proj_C, self.f3_gate3d,
            self.f2_proj_ref, self.head,
            self.hr2d, self.f2_hr_reduce3d,
        )
        return [p for m in modules for p in m.parameters()]

    def _stack_depth(self, feats_2d: List[torch.Tensor], B: int, D: int) -> List[torch.Tensor]:
        """Reshape per-slice grids ``[B*D, C, Gh, Gw]`` into ``[B, C, D, Gh, Gw]`` volumes."""
        outs = []
        for f in feats_2d:
            _, C, Gh, Gw = f.shape
            outs.append(f.view(B, -1, C, Gh, Gw).permute(0, 2, 1, 3, 4).contiguous())
        return outs

    def _maybe_checkpoint(self, fn, *args):
        """Run ``fn(*args)`` with gradient checkpointing in training mode, directly otherwise."""
        if self.training:
            return cp.checkpoint(fn, *args, use_reentrant=False)
        return fn(*args)

    def _encode_slices(self, x_bdchw: torch.Tensor) -> List[torch.Tensor]:
        """Encode every slice with the frozen adapter + ViT in chunks.

        Returns one ``[B, C, D, Gh, Gw]`` tensor per selected ViT layer, in bf16
        when ViT AMP is active on CUDA and fp32 otherwise.
        """
        B, D, _, H, W = x_bdchw.shape
        # reshape (not view): the input is not guaranteed to be contiguous
        x_slices = x_bdchw.reshape(B * D, 2, H, W)
        use_amp = self.vit_amp and x_slices.is_cuda
        feat_dtype = torch.bfloat16 if use_amp else torch.float32
        feats_per_layer: List[List[torch.Tensor]] = []
        with torch.no_grad():
            for s in range(0, x_slices.size(0), self.vit_chunk_slices):
                xi3 = self.adapter_slice(x_slices[s:s + self.vit_chunk_slices])  # [chunk, 3, H, W]
                ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if use_amp else nullcontext()
                with ctx:
                    grids = self.encoder.forward_multi(xi3, layers=self.selected_layers)
                if not feats_per_layer:
                    feats_per_layer = [[] for _ in grids]
                for bucket, gi in zip(feats_per_layer, grids):
                    bucket.append(gi.to(feat_dtype))
        feats_2d = [torch.cat(vs, dim=0) for vs in feats_per_layer]
        return self._stack_depth(feats_2d, B=B, D=D)

    def _hr_refine(self, logits_ref: torch.Tensor, F2p: torch.Tensor) -> torch.Tensor:
        """Upsample refined logits to full resolution and apply HRHead2D per slice.

        Slices are processed in chunks of ``hr2d_chunk`` along depth to bound memory.
        """
        # Detached: this branch does not back-propagate into f2_proj_ref / adapter3d.
        F2r_3d = self.f2_hr_reduce3d(F2p.detach())
        logits_full = F.interpolate(logits_ref, scale_factor=(1, 2, 2), mode="trilinear", align_corners=False)
        F2r_full = F.interpolate(F2r_3d, scale_factor=(1, 2, 2), mode="trilinear", align_corners=False)
        B, K, D, H, W = logits_full.shape
        Cr = F2r_full.shape[1]
        hr_logits = logits_full.new_empty((B, K, D, H, W))
        for s in range(0, D, self.hr2d_chunk):
            e = min(D, s + self.hr2d_chunk)
            n = e - s
            hr_in = torch.cat([logits_full[:, :, s:e], F2r_full[:, :, s:e]], dim=1)
            hr_in_2d = hr_in.permute(0, 2, 1, 3, 4).reshape(B * n, K + Cr, H, W)
            out2d = self._maybe_checkpoint(self.hr2d, hr_in_2d)
            hr_logits[:, :, s:e] = out2d.view(B, n, K, H, W).permute(0, 2, 1, 3, 4)
        return hr_logits

    def forward(self, x_bdchw: torch.Tensor) -> torch.Tensor:
        # x: [B, D, 2, H, W]
        if x_bdchw.ndim != 5 or x_bdchw.shape[2] != 2:
            raise ValueError(f"Input must have 2 channels (Image + Z) as [B,D,2,H,W], got {tuple(x_bdchw.shape)}")
        # Frozen per-slice ViT features -> shared aggregator
        feats_3d = self._encode_slices(x_bdchw)
        agg_outs = self._maybe_checkpoint(self.shared_agg, feats_3d)
        del feats_3d
        *skips, bottom = agg_outs
        if len(skips) < 2:
            raise RuntimeError("Need at least two skip features for up_factor=4")
        A5, A8 = skips[-2], skips[-1]  # [B, C, D, H/16, W/16]
        A11 = bottom                   # [B, C, D, H/16, W/16]
        # Trainable 3D adapter on the raw volume
        F2_3d, F3_3d = self._maybe_checkpoint(self.adapter3d, x_bdchw)
        F2p = self.f2_proj_ref(F2_3d)
        F3pC = self.f3_proj_C(F3_3d)
        # Gated fusion at 1/4: add gated CNN features to the upsampled ViT skip
        A5_up = F.interpolate(A5, size=F3pC.shape[-3:], mode="trilinear", align_corners=False)
        gate3d = torch.sigmoid(self.f3_gate3d(torch.cat([A5_up, F3pC], dim=1)))
        skip2 = A5_up + gate3d * F3pC
        # Decoder -> 1/4, then 1/2 refinement
        logits_quarter = self._maybe_checkpoint(self.dec, A11, [skip2, A8])
        logits_half = F.interpolate(logits_quarter, scale_factor=(1, 2, 2), mode="trilinear", align_corners=False)
        logits_ref = self._maybe_checkpoint(self.head, torch.cat([logits_half, F2p], dim=1))
        # Full-resolution per-slice head
        return self._hr_refine(logits_ref, F2p)
# Layer-ID FiLM (per-layer scale/bias; weight-tying adapter)
class LayerFiLM(nn.Module):
    """Feature-wise affine modulation selected by a layer index.

    Holds one ``(gamma, beta)`` row of size ``c`` per layer. For layer ``i`` it
    returns ``x * (1 + gamma[i]) + beta[i]``, broadcast over ``[B, C, D, H, W]``.
    Both tables start at zero, so the module is the identity at initialisation.
    """

    def __init__(self, c: int, num_layers: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(num_layers, c))  # multiplicative, residual around 1
        self.beta = nn.Parameter(torch.zeros(num_layers, c))   # additive

    def forward(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        scale = self.gamma[layer_idx].add(1.0)[None, :, None, None, None]
        shift = self.beta[layer_idx][None, :, None, None, None]
        return x * scale + shift
class RotaryPositionalEmbedding1D(nn.Module):
    """Rotary position tables for a single (depth) axis.

    ``dim`` is the per-head feature size and must be even. Channel pair ``i``
    uses the inverse frequency ``base ** (-2i / dim)``. Positions are
    ``0..D-1`` multiplied by ``scale``. ``inv_freq`` is a non-persistent
    buffer, so it is not saved in the state dict.
    """

    def __init__(self, dim: int, base: float = 10000.0, scale: float = 1.0):
        super().__init__()
        assert dim % 2 == 0, f"RoPE dim must be even, got {dim}"
        self.dim = dim
        self.base = float(base)
        self.scale = float(scale)
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

    def get_cos_sin(self, D: int, device, dtype):
        """Return ``(cos, sin)``, each ``[1, 1, D, dim/2]``, computed in fp32 and cast to ``dtype``."""
        positions = torch.arange(D, device=device, dtype=torch.float32) * self.scale
        freqs = self.inv_freq.to(device=device, dtype=torch.float32)
        angles = positions[:, None] * freqs[None, :]  # [D, dim/2]
        cos_table = angles.cos()[None, None]
        sin_table = angles.sin()[None, None]
        return cos_table.to(dtype=dtype), sin_table.to(dtype=dtype)
def apply_rope_1d(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs of ``x`` by the angles encoded in cos/sin.

    Channels ``(2i, 2i+1)`` form a pair ``(a, b)`` that is mapped to
    ``(a*cos - b*sin, a*sin + b*cos)`` and written back to the same positions.

    x:   [BHW, heads, D, dim]  (last dim even)
    cos: [1, 1, D, dim/2]
    sin: [1, 1, D, dim/2]
    """
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    a, b = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((a * cos - b * sin, a * sin + b * cos), dim=-1)  # [..., dim/2, 2]
    return rotated.flatten(-2)  # [..., dim]
# FFA core (Slice Attn + Global Spatial Attn + FFN, with pos_slice)
class SliceSelfAttention1D(nn.Module):
    """Multi-head self-attention along the depth axis only.

    input/output: [B, C, D, H, W]. Every (b, h, w) position is an independent
    length-D sequence; H and W are folded into the batch. Channels are
    LayerNorm-ed, projected to ``attn_dim`` (default ``c``) split into
    ``heads``, optionally rotated with 1D RoPE over depth, attended, and
    projected back to ``c``. Dropout is applied to the attention weights and
    to the output.
    """

    def __init__(self, c: int, heads: int = 6, attn_dim: int | None = None, dropout: float = 0.0, use_rope: bool = True):
        super().__init__()
        self.c = c
        self.heads = int(heads)
        self.attn_dim = int(attn_dim or c)
        assert self.attn_dim % self.heads == 0
        self.head_dim = self.attn_dim // self.heads
        self.use_rope = bool(use_rope)
        self.ln = nn.LayerNorm(c)
        self.q = nn.Linear(c, self.attn_dim, bias=False)
        self.k = nn.Linear(c, self.attn_dim, bias=False)
        self.v = nn.Linear(c, self.attn_dim, bias=False)
        self.proj = nn.Linear(self.attn_dim, c, bias=False)
        self.drop = nn.Dropout(dropout)
        self.rope = RotaryPositionalEmbedding1D(self.head_dim) if self.use_rope else None

    def forward(self, x_bcdhw: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = x_bcdhw.shape
        n_seq = B * H * W
        # [B, C, D, H, W] -> [B, H, W, D, C] -> LN over C -> [BHW, D, C]
        tokens = self.ln(x_bcdhw.permute(0, 3, 4, 2, 1).contiguous()).view(n_seq, D, C)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [BHW, D, attn_dim] -> [BHW, heads, D, head_dim]
            return t.view(n_seq, D, self.heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(layer(tokens)) for layer in (self.q, self.k, self.v))
        if self.use_rope:
            cos, sin = self.rope.get_cos_sin(D, device=q.device, dtype=q.dtype)
            q, k = apply_rope_1d(q, cos, sin), apply_rope_1d(k, cos, sin)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [BHW, h, D, D]
        weights = self.drop(scores.softmax(dim=-1))
        mixed = torch.matmul(weights, v)  # [BHW, h, D, d]
        mixed = mixed.transpose(1, 2).contiguous().view(n_seq, D, self.attn_dim)
        out = self.drop(self.proj(mixed))  # [BHW, D, C]
        return out.view(B, H, W, D, C).permute(0, 4, 3, 1, 2).contiguous()  # [B, C, D, H, W]
class GlobalSpatialAttention2D(nn.Module):
    """Global multi-head spatial attention inside every depth slice.

    input/output: [B, C, D, H, W]. Each of the B*D slices attends
    independently. Queries come from all H*W positions. Keys and values come
    from an average-pooled map (stride ``kv_down``, or full resolution when
    ``kv_down <= 1``), which keeps the cost at ``(H*W) x (H*W / kv_down**2)``.
    Channels are LayerNorm-ed first. Dropout is applied to the attention
    weights and to the output.
    """

    def __init__(self, c: int, heads: int = 6, attn_dim: int | None = None,
                 kv_down: int = 2, dropout: float = 0.0):
        super().__init__()
        self.c = c
        self.heads = int(heads)
        self.attn_dim = int(attn_dim or c)
        assert self.attn_dim % self.heads == 0
        self.head_dim = self.attn_dim // self.heads
        self.kv_down = int(kv_down)
        self.ln = nn.LayerNorm(c)
        self.q_conv = nn.Conv2d(c, self.attn_dim, kernel_size=1, bias=False)
        self.k_conv = nn.Conv2d(c, self.attn_dim, kernel_size=1, bias=False)
        self.v_conv = nn.Conv2d(c, self.attn_dim, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(self.attn_dim, c, kernel_size=1, bias=False)
        self.drop = nn.Dropout(dropout)
        self.pool = nn.AvgPool2d(kernel_size=self.kv_down, stride=self.kv_down) if self.kv_down > 1 else nn.Identity()

    def forward(self, x_bcdhw: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = x_bcdhw.shape
        BD = B * D
        # LayerNorm over channels (channels-last), then fold depth into batch:
        # [B, C, D, H, W] -> [B, D, H, W, C] -> [B, D, C, H, W] -> [B*D, C, H, W]
        normed = self.ln(x_bcdhw.permute(0, 2, 3, 4, 1).contiguous())
        x2d = normed.permute(0, 1, 4, 2, 3).contiguous().view(BD, C, H, W)
        q = self.q_conv(x2d)     # [BD, attn_dim, H, W]
        kv_src = self.pool(x2d)  # [BD, C, H', W']
        k = self.k_conv(kv_src)  # [BD, attn_dim, H', W']
        v = self.v_conv(kv_src)

        def heads_first(t: torch.Tensor) -> torch.Tensor:
            # [BD, attn_dim, h, w] -> [BD, heads, h*w, head_dim]
            n_tok = t.shape[-2] * t.shape[-1]
            return t.view(BD, self.heads, self.head_dim, n_tok).transpose(2, 3).contiguous()

        qh, kh, vh = heads_first(q), heads_first(k), heads_first(v)
        scores = (qh @ kh.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [BD, h, Nq, Nk]
        weights = self.drop(scores.softmax(dim=-1))
        context = (weights @ vh).transpose(2, 3).contiguous().view(BD, self.attn_dim, H, W)
        out = self.drop(self.proj(context))
        return out.view(B, D, C, H, W).permute(0, 2, 1, 3, 4).contiguous()  # [B, C, D, H, W]
class FFN3D_Pointwise(nn.Module):
"""
pointwise-MLP (1x1x1) (Conv→GELU→Drop→Conv)
"""
def __init__(self, c: int, hidden: int | None = None, dropout: float = 0.0):
super().__init__()
h = int(hidden or (4 * c))
self.fc1 = nn.Conv3d(c, h, kernel_size=1, bias=True)
self.act = nn.GELU()
self.drop = nn.Dropout(dropout)
self.fc2 = nn.Conv3d(h, c, kernel_size=1, bias=True)
def forward(self, x):
return self.fc2(self.drop(self.act(self.fc1(x))))
class FFABlock(nn.Module):
    """
    One FFA block made of three residual branches applied in sequence:
    slice self-attention along depth (optional RoPE), then global spatial
    attention within each slice, then a pointwise FFN preceded by LayerNorm.
    Each branch output passes through its own DropPath (or Identity when
    ``drop_path`` is 0) before being added back to the input.

    Input/output: [B, C, D, H, W]
    """

    def __init__(self, c: int, heads: int = 6, attn_dim: int | None = None,
                 kv_down: int = 2, dropout: float = 0.0, drop_path: float = 0.0,
                 use_rope: bool = True):
        super().__init__()
        self.slice_attn = SliceSelfAttention1D(c, heads=heads, attn_dim=attn_dim, dropout=dropout, use_rope=use_rope)
        self.global_attn = GlobalSpatialAttention2D(c, heads=heads, attn_dim=attn_dim, kv_down=kv_down, dropout=dropout)
        self.ffn = FFN3D_Pointwise(c, hidden=None, dropout=dropout)

        def _residual_drop() -> nn.Module:
            # A separate instance per branch.
            return DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.dp1, self.dp2, self.dp3 = _residual_drop(), _residual_drop(), _residual_drop()
        self.ln3 = nn.LayerNorm(c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B,C,D,H,W]
        x = x + self.dp1(self.slice_attn(x))
        x = x + self.dp2(self.global_attn(x))
        # The FFN applies LayerNorm in channels-last layout ([B,D,H,W,C]),
        # then returns to [B,C,D,H,W] for the 1x1x1 convolutions.
        normed = self.ln3(x.permute(0, 2, 3, 4, 1).contiguous())
        ffn_in = normed.permute(0, 4, 1, 2, 3).contiguous()
        return x + self.dp3(self.ffn(ffn_in))
class Aggregator3D_FFA(nn.Module):
"""
FFA shared core.
Input/Output: [B,C,D,H,W] (1/16 res)
"""
def __init__(self, c: int,
n_blocks: int = 2,
heads: int = 6,
attn_dim: int | None = None,
kv_down: int = 2,
dropout: float = 0.0,
drop_path: float = 0.05,
use_rope: bool = True,
use_pos_slice: bool = True,
max_depth: int = 512):
super().__init__()
self.use_pos_slice = bool(use_pos_slice)
if self.use_pos_slice:
self.pos_slice = nn.Embedding(max_depth, c)
self.pos_gain = nn.Parameter(torch.ones(1))
else:
self.pos_slice = None
self.pre = nn.Sequential(
nn.Conv3d(c, c, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(8, c),
nn.SiLU(inplace=False),
nn.Conv3d(c, c, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(8, c),
nn.SiLU(inplace=False),
)
blocks = []
for i in range(int(n_blocks)):
blocks.append(
FFABlock(
c=c, heads=heads, attn_dim=(attn_dim or c // 2),
kv_down=kv_down, dropout=dropout,
drop_path=drop_path if n_blocks > 1 else 0.0,
use_rope=use_rope
)
)
self.blocks = nn.ModuleList(blocks)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B,C,D,H,W]
if self.use_pos_slice and self.pos_slice is not None:
B, C, D, H, W = x.shape
idx = torch.arange(D, device=x.device)
emb = self.pos_slice(idx).transpose(0, 1).view(1, C, D, 1, 1) # [1,C,D,1,1]
x = x + self.pos_gain * emb.to(dtype=x.dtype)
x = self.pre(x)
for blk in self.blocks:
x = cp.checkpoint(blk, x, use_reentrant=False) if self.training else blk(x)
return x
class ParallelAggregatorSharedFFA(nn.Module):
    """
    Applies one shared Aggregator3D_FFA to 3D features from several ViT layers.

    Each entry of the input list (shape [B, C, D, Gh, Gw]) is modulated by a
    per-layer input FiLM, run through the shared core, then modulated by a
    per-layer output FiLM. The layer's position in the list is the index
    passed to both FiLM modules.
    """

    def __init__(self, c: int, num_layers: int,
                 n_blocks: int = 2, heads: int = 6, attn_dim: int | None = None,
                 kv_down: int = 2, dropout: float = 0.0, drop_path: float = 0.05,
                 use_rope: bool = True, use_pos_slice: bool = True):
        super().__init__()
        self.core = Aggregator3D_FFA(
            c=c, n_blocks=n_blocks, heads=heads, attn_dim=attn_dim,
            kv_down=kv_down, dropout=dropout, drop_path=drop_path,
            use_rope=use_rope, use_pos_slice=use_pos_slice,
        )
        self.film_in = LayerFiLM(c, num_layers)
        self.film_out = LayerFiLM(c, num_layers)

    def forward(self, feats_3d_list: List[torch.Tensor]) -> List[torch.Tensor]:
        def _run_layer(layer_idx: int, feat: torch.Tensor) -> torch.Tensor:
            shared = self.core(self.film_in(feat, layer_idx))
            return self.film_out(shared, layer_idx)

        return [_run_layer(layer_idx, feat) for layer_idx, feat in enumerate(feats_3d_list)]
class DiceLossMultiClass(nn.Module):
"""
Soft Dice over present classes (background excluded by default).