Data Type Mismatch and Discrepancy in K and V Acquisition in Self-Attention #14

@camerayuhang

Description

A Data Type Mismatch Causes a Code Error

In the code for the wavelet transform and the inverse wavelet transform, why are the wavelet filters cast to torch.float16? The input x is still torch.float32, which causes a dtype-mismatch error when running the code. Did I encounter this issue because I used float32 inputs during debugging, while you used float16? If that's the case, I am sorry for mistakenly assuming the data type mismatch was a problem.

Code references:

self.w_ll = self.w_ll.to(dtype=torch.float16)

self.filters = self.filters.to(dtype=torch.float16)
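
If the mismatch is only a matter of where the cast happens, one workaround is to cast the filters to the input's dtype inside forward instead of hard-coding torch.float16. Below is a minimal, self-contained sketch to illustrate the idea: a toy Haar DWT, not the repository's implementation, where the buffer name filters only mirrors the snippet above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HaarDWT(nn.Module):
    # Toy 2D Haar wavelet transform, for illustration only.
    def __init__(self):
        super().__init__()
        ll = torch.tensor([[ 0.5,  0.5], [ 0.5,  0.5]])
        lh = torch.tensor([[-0.5, -0.5], [ 0.5,  0.5]])
        hl = torch.tensor([[-0.5,  0.5], [-0.5,  0.5]])
        hh = torch.tensor([[ 0.5, -0.5], [-0.5,  0.5]])
        # A buffer follows the module through .half()/.float()/.to(device).
        self.register_buffer("filters", torch.stack([ll, lh, hl, hh]).unsqueeze(1))

    def forward(self, x):
        B, C, H, W = x.shape
        # Cast the filters to the input's dtype at call time, so both
        # float32 debugging runs and float16 / autocast training work.
        filters = self.filters.to(dtype=x.dtype).repeat(C, 1, 1, 1)
        return F.conv2d(x, filters, stride=2, groups=C)

x = torch.randn(1, 8, 32, 32)     # float32 input
print(HaarDWT()(x).shape)         # torch.Size([1, 32, 16, 16])
```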

The Acquisition of K and V in Self-Attention Does Not Match the Paper

[Screenshot from the paper, 2024-10-01]

According to the paper, K and V are obtained from the feature map down-sampled by the wavelet transform, which is 1/2 the size of Q. However, in the code, that down-sampled feature map is further downsampled by a factor of 4 in the first stage and by a factor of 2 in the second stage, which does not match the paper. Could you explain the reason behind this discrepancy? Is it just to reduce the computational cost?

Code references:

self.kv_embed = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()

kv = self.kv_embed(x_dwt).reshape(B, C, -1).permute(0, 2, 1)
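
To make the size difference concrete, here is a small standalone sketch (not the repository's code; the wavelet transform is replaced by a stride-2 average pool, and kv_embed / sr_ratio only mirror the snippet above) that shows how the extra strided conv shrinks K/V beyond the 1/2 downsampling:

```python
import torch
import torch.nn as nn

B, C, H, W, sr_ratio = 1, 64, 56, 56, 4

x = torch.randn(B, C, H, W)                       # Q is built from x: H*W tokens
dwt_stub = nn.AvgPool2d(kernel_size=2, stride=2)  # stand-in for the wavelet transform
x_dwt = dwt_stub(x)                               # (B, C, H/2, W/2): 1/2 the spatial size of Q

kv_embed = nn.Conv2d(C, C, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()
kv = kv_embed(x_dwt).reshape(B, C, -1).permute(0, 2, 1)

print("Q tokens :", H * W)          # 3136
print("KV tokens:", kv.shape[1])    # 49, i.e. (H/2/sr_ratio) * (W/2/sr_ratio)
# With sr_ratio=4 the K/V map is 1/8 of Q per side, not the 1/2 described in the paper.
```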
