Description
A Data Type Mismatch Causes a Code Error
In the code for the wavelet transform and inverse wavelet transform, why are the wavelet filters cast to torch.float16? The input x is still torch.float32, so running the code produces a dtype-mismatch error. Did I encounter this issue because I used float32 inputs during debugging, while you used float16? If that is the case, I apologize for mistakenly assuming the data type mismatch was a bug.
Code references:
self.w_ll = self.w_ll.to(dtype=torch.float16)
self.filters = self.filters.to(dtype=torch.float16)
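One way to avoid the mismatch is to cast the filters to the input's dtype at forward time rather than hard-coding float16. The sketch below is illustrative only (the `HaarDWT` class and its filter are my own simplification, not the repository's implementation), but it shows the pattern:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HaarDWT(nn.Module):
    """Minimal 2D Haar low-pass transform; the filter dtype follows the input.

    Illustrative sketch only, not the repository's actual DWT implementation.
    """
    def __init__(self):
        super().__init__()
        ll = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
        # Register as a buffer so module.to()/half() move it with the model.
        self.register_buffer("w_ll", ll.unsqueeze(0).unsqueeze(0))

    def forward(self, x):  # x: (B, C, H, W)
        B, C, H, W = x.shape
        # Cast the filter to the input's dtype instead of a fixed float16,
        # so both float32 and float16 (AMP) inputs work without errors.
        w = self.w_ll.to(dtype=x.dtype).repeat(C, 1, 1, 1)
        return F.conv2d(x, w, stride=2, groups=C)

x = torch.randn(1, 3, 8, 8)   # float32 input
y = HaarDWT()(x)
assert y.dtype == x.dtype and y.shape == (1, 3, 4, 4)
```

Alternatively, wrapping the forward pass in `torch.cuda.amp.autocast()` would keep the authors' float16 filters consistent with a float16 input.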
The Acquisition of K and V in Self-Attention Does Not Match the Paper

According to the paper, K and V are obtained from a feature map down-sampled by the wavelet transform to 1/2 the size of Q. In the code, however, that down-sampled feature map is further reduced by a factor of 4 in the first stage and by a factor of 2 in the second stage, which does not match the paper. Could you explain the reason for this discrepancy? Is it just to reduce the computational cost?
Code references:
self.kv_embed = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()
kv = self.kv_embed(x_dwt).reshape(B, C, -1).permute(0, 2, 1)
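To make the discrepancy concrete, the total spatial reduction of K/V relative to Q compounds the wavelet halving with the `sr_ratio` stride of `kv_embed`. A quick sketch of the arithmetic (the per-stage `sr_ratio` values of 4 and 2 follow my reading of the code and may not match every config):

```python
def kv_scale(sr_ratio: int) -> int:
    """Total side-length down-sampling of the K/V feature map relative to Q.

    The wavelet transform halves H and W (factor 2), and kv_embed, a conv
    with stride=sr_ratio, adds a further factor of sr_ratio on top of that.
    """
    wavelet_factor = 2  # DWT halves each spatial dimension
    return wavelet_factor * sr_ratio

# Stage 1 (sr_ratio=4): K/V is 1/8 the side length of Q, not the paper's 1/2.
assert kv_scale(4) == 8
# Stage 2 (sr_ratio=2): K/V is 1/4 the side length of Q.
assert kv_scale(2) == 4
# Only with sr_ratio=1 (nn.Identity) does the paper's 1/2 relation hold.
assert kv_scale(1) == 2
```

So if the paper's figure shows K/V at 1/2 the size of Q, the extra `sr_ratio` stride in the code is an additional reduction beyond what the paper describes.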