7 changes: 6 additions & 1 deletion .gitignore
@@ -111,11 +111,16 @@ output/
Untitled.ipynb
Testing notebook.ipynb


# MacOS
*.DS_Store

# Root dir exclusions
/*.csv
/*.yaml
/*.json
/*.jpg
/*.png
/*.zip
/*.tar.*
/*.tar.*

3 changes: 2 additions & 1 deletion tests/test_models.py
@@ -41,7 +41,7 @@
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
'eva_*', 'flexivit*', 'eva02*', 'pvig_*', 'samvit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -405,6 +405,7 @@ def _create_fx_model(model, train=False):
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
'pvig_*',
]


1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -23,6 +23,7 @@
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .gnn_layers import DyGraphConv2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
215 changes: 215 additions & 0 deletions timm/layers/gnn_layers.py
@@ -0,0 +1,215 @@
# Graph neural network (GNN) layers for Vision GNN (ViG) style models
# Reference: https://github.com/lightaime/deep_gcns_torch
import torch
import torch
from torch import nn
import torch.nn.functional as F
from .drop import DropPath


def pairwise_distance(x, y):
"""
    Compute the matrix of squared pairwise Euclidean distances between two point sets
"""
with torch.no_grad():
xy_inner = -2*torch.matmul(x, y.transpose(2, 1))
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
return x_square + xy_inner + y_square.transpose(2, 1)
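
# The sum above expands the squared-distance identity instead of forming explicit
# differences: ||x_i - y_j||^2 = ||x_i||^2 - 2 <x_i, y_j> + ||y_j||^2.
# Illustrative check (not part of the original change, assuming (B, N, C) inputs):
#   x = torch.randn(2, 5, 3)
#   assert torch.allclose(pairwise_distance(x, x), torch.cdist(x, x) ** 2, atol=1e-4)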


def dense_knn_matrix(x, y, k=16, relative_pos=None):
"""Get KNN based on the pairwise distance
"""
with torch.no_grad():
x = x.transpose(2, 1).squeeze(-1)
y = y.transpose(2, 1).squeeze(-1)
batch_size, n_points, n_dims = x.shape
dist = pairwise_distance(x.detach(), y.detach())
if relative_pos is not None:
dist += relative_pos
_, nn_idx = torch.topk(-dist, k=k)
center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
return torch.stack((nn_idx, center_idx), dim=0)
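
# Shape note (assuming the (B, C, N, 1) layout used by DyGraphConv2d below): x and y are
# flattened to (B, N, C) and (B, M, C); the returned stack is (2, B, N, k), where index 0
# holds the neighbor indices into y and index 1 the center indices 0..N-1.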


class DenseDilated(nn.Module):
"""
    Select dilated neighbors from a dense neighbor list; when stochastic, a random subset
    of k candidates is drawn (with probability epsilon) during training instead
"""
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilated, self).__init__()
self.dilation = dilation
self.stochastic = stochastic
self.epsilon = epsilon
self.k = k

def forward(self, edge_index):
if self.stochastic:
if torch.rand(1) < self.epsilon and self.training:
num = self.k * self.dilation
randnum = torch.randperm(num)[:self.k]
edge_index = edge_index[:, :, :, randnum]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
return edge_index
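
# Dilation example (illustrative): with k=3 and dilation=3 the graph builder below requests
# k * dilation = 9 candidates per point and this module keeps every 3rd one (ranks 0, 3, 6),
# analogous to dilated convolution applied to the neighbor ranking rather than the spatial grid.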


class DenseDilatedKnnGraph(nn.Module):
"""
    Build neighbor indices with a dilated knn: features are l2-normalized, k * dilation
    candidates are found, then dilation is applied to keep k of them
"""
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilatedKnnGraph, self).__init__()
self.dilation = dilation
self.k = k
self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

def forward(self, x, y=None, relative_pos=None):
x = F.normalize(x, p=2.0, dim=1)
if y is not None:
y = F.normalize(y, p=2.0, dim=1)
edge_index = dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
else:
edge_index = dense_knn_matrix(x, x, self.k * self.dilation, relative_pos)
return self._dilated(edge_index)


def batched_index_select(x, idx):
    # fetch neighbor features given a tensor of neighbor indices
batch_size, num_dims, num_vertices_reduced = x.shape[:3]
_, num_vertices, k = idx.shape
idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
idx = idx + idx_base
idx = idx.contiguous().view(-1)

x = x.transpose(2, 1)
feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
return feature
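
# Gather semantics (assuming the dense layout used in this file): x is (B, C, N_reduced, 1)
# and idx is (B, N, k); the result is (B, C, N, k), i.e. the features of the k selected
# neighbors for every center point.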


def norm_layer(norm, nc):
    # 2d normalization layer selected by name; None (the layer defaults) disables normalization
    if norm is None:
        return nn.Identity()
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MRConv2d(nn.Module):
"""
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
"""
def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
super(MRConv2d, self).__init__()
# self.nn = BasicConv([in_channels*2, out_channels], act_layer, norm, bias)
self.nn = nn.Sequential(
nn.Conv2d(in_channels*2, out_channels, 1, bias=bias, groups=4),
norm_layer(norm, out_channels),
act_layer(),
)

self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, x, edge_index, y=None):
x_i = batched_index_select(x, edge_index[1])
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
        # max-relative aggregation: channel-wise max of (x_j - x_i) over the k neighbors
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, s = x.shape
        # concatenate node features with the aggregated relative features along channels
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, s)
return self.nn(x)


class EdgeConv2d(nn.Module):
"""
Edge convolution layer (with activation, batch normalization) for dense data type
"""
def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
super(EdgeConv2d, self).__init__()
self.nn = nn.Sequential(
nn.Conv2d(in_channels*2, out_channels, 1, bias=bias, groups=4),
norm_layer(norm, out_channels),
act_layer(),
)

self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, x, edge_index, y=None):
x_i = batched_index_select(x, edge_index[1])
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
return max_value


class GraphConv2d(nn.Module):
"""
Static graph convolution layer
"""
def __init__(self, in_channels, out_channels, conv='mr', act_layer=nn.GELU, norm=None, bias=True):
super(GraphConv2d, self).__init__()
if conv == 'edge':
self.gconv = EdgeConv2d(in_channels, out_channels, act_layer, norm, bias)
elif conv == 'mr':
self.gconv = MRConv2d(in_channels, out_channels, act_layer, norm, bias)
else:
raise NotImplementedError('conv:{} is not supported'.format(conv))

def forward(self, x, edge_index, y=None):
return self.gconv(x, edge_index, y)
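
# Aggregation summary: 'mr' (max-relative) concatenates x_i with max_j(x_j - x_i) along
# channels before the 1x1 conv, while 'edge' (EdgeConv) applies the conv to
# [x_i, x_j - x_i] per edge and then max-pools over the k neighbors.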


class DyGraphConv2d(GraphConv2d):
"""
Dynamic graph convolution layer
"""
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU,
norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act_layer, norm, bias)
self.k = kernel_size
self.d = dilation
self.r = r
self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

def forward(self, x, relative_pos=None):
B, C, H, W = x.shape
y = None
if self.r > 1:
y = F.avg_pool2d(x, self.r, self.r)
y = y.reshape(B, C, -1, 1).contiguous()
x = x.reshape(B, C, -1, 1).contiguous()
edge_index = self.dilated_knn_graph(x, y, relative_pos)
x = super(DyGraphConv2d, self).forward(x, edge_index, y)
return x.reshape(B, -1, H, W).contiguous()
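
# Minimal usage sketch (illustrative only; the channel counts and hyper-parameters here are
# assumptions, not values from any particular ViG configuration):
#   conv = DyGraphConv2d(64, 128, kernel_size=9, dilation=1, conv='mr', norm='batch', r=2)
#   feats = torch.randn(2, 64, 14, 14)    # (B, C, H, W) grid of token features
#   out = conv(feats)                     # (2, 128, 14, 14); the graph is rebuilt each forward pass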
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -57,6 +57,7 @@
from .twins import *
from .vgg import *
from .visformer import *
from .vision_gnn import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
1 change: 1 addition & 0 deletions timm/models/layers/__init__.py
@@ -23,6 +23,7 @@
from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from timm.layers.gather_excite import GatherExcite
from timm.layers.global_context import GlobalContext
from timm.layers.gnn_layers import DyGraphConv2d
from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from timm.layers.inplace_abn import InplaceAbn
from timm.layers.linear import Linear