Skip to content

Commit 8218627

Browse files
Dealing with encoder outputs with dimension > 3 when using the reshaper neck
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 92bd85a commit 8218627

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

terratorch/models/necks.py

+14
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,30 @@ def __init__(self, channel_list: list[int], remove_cls_token=True, effective_tim
141141
self.remove_cls_token = remove_cls_token
142142
self.effective_time_dim = effective_time_dim
143143

144+
def collapse_dims(self, x):
    """Collapse every dimension between the first and the last into one.

    The result always has exactly three dimensions: the first and last
    dimensions of ``x`` are kept as they are, and all dimensions in
    between are merged into a single one whose size is their product.
    An input that already has three dimensions keeps its shape; an
    input with only two dimensions gains a middle dimension of size 1.

    NOTE(review): the first dimension is presumably the batch and the
    last one the embedding, as the original local names suggest --
    confirm against the encoders that feed this neck.

    Args:
        x: Tensor-like object exposing ``shape`` and ``reshape``.

    Returns:
        ``x`` reshaped to ``(x.shape[0], prod(x.shape[1:-1]), x.shape[-1])``.
    """
    # Local import: the module's top-level imports are not visible here.
    import math

    batch, *middle, embed = x.shape
    # math.prod yields a Python int (1 for an empty sequence), whereas
    # np.prod yields a float for an empty shape, which reshape rejects.
    collapsed_dim = math.prod(middle)

    return x.reshape(batch, collapsed_dim, embed)
155+
144156
def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
145157
out = []
146158
for x in features:
147159
if self.remove_cls_token:
148160
x_no_token = x[:, 1:, :]
149161
else:
150162
x_no_token = x
163+
x_no_token = self.collapse_dims(x_no_token)
151164
number_of_tokens = x_no_token.shape[1]
152165
tokens_per_timestep = number_of_tokens // self.effective_time_dim
153166
h = int(np.sqrt(tokens_per_timestep))
167+
154168
encoded = rearrange(
155169
x_no_token,
156170
"batch (t h w) e -> batch (t e) h w",

0 commit comments

Comments
 (0)