Add SwiftFormer, SHViT, StarNet, FasterNet and GhostNetV3 #2499
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@brianhou0208 there aren't any weights set up to test for swiftformer in the pretrained cfgs. I tested all the others and they appear to be working as expected.
Hi @rwightman, since the SwiftFormer weights are stored on Google Drive and cannot be linked directly in the pretrained cfgs, I downloaded the weights locally for testing and all tests passed. Below are the Top-1 and Top-5 accuracy results for reference.
Short test code
from typing import Any, Dict
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
from timm.utils.metrics import AverageMeter, accuracy
from timm.utils.model import reparameterize_model
from timm.models.swiftformer import checkpoint_filter_fn
device = torch.device('mps')
torch.mps.empty_cache()
def get_model_acc(model: torch.nn.Module):
    # Prefer the test-time settings from the pretrained cfg when they are present.
    cfg: Dict[str, Any] = model.default_cfg
    _, height, width = cfg.get('test_input_size', cfg['input_size'])
    crop_pct = cfg.get('test_crop_pct', cfg['crop_pct'])
    imgsz = height if height == width else (height, width)
    # Resize to input size / crop_pct before the center crop (also handles non-square inputs).
    resize = int(imgsz / crop_pct) if isinstance(imgsz, int) else tuple(int(s / crop_pct) for s in imgsz)
    val_dataset = datasets.ImageFolder(
        './imagenet/val',
        transforms.Compose([
            transforms.Resize(resize, interpolation=transforms.InterpolationMode(cfg['interpolation'])),
            transforms.CenterCrop(imgsz),
            transforms.ToTensor(),
            transforms.Normalize(cfg['mean'], cfg['std']),
        ])
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, pin_memory=False, prefetch_factor=4, num_workers=4,
        persistent_workers=True,  # pin_memory_device='mps'
    )
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    # Fuse any reparameterizable branches into their inference-time form before evaluating.
    model = reparameterize_model(model)
    model.to(device)
    torch.mps.synchronize()
    with torch.inference_mode():
        for images, target in tqdm(val_loader):
            images = images.to(device)
            target = target.to(device)
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Weight each batch by its size so the final average is per-sample.
            top1.update(acc1, images.size(0))
            top5.update(acc5, images.size(0))
    torch.mps.synchronize()
    return {"ACC@1": round(top1.avg.item(), 4), "ACC@5": round(top5.avg.item(), 4)}
model_weight_list = {
    "swiftformer_xs": "./SwiftFormer_XS.pth",
    "swiftformer_s": "./SwiftFormer_S.pth",
    "swiftformer_l1": "./SwiftFormer_L1.pth",
    "swiftformer_l3": "./SwiftFormer_L3.pth",
}
if __name__ == "__main__":
    model_list = timm.list_models('swiftformer*', pretrained=False)
    print(model_list)
    for name in model_list:
        torch.mps.empty_cache()
        model = timm.create_model(name, pretrained=False).eval()
        weight = torch.load(model_weight_list[name], map_location='cpu', weights_only=True)
        weight = checkpoint_filter_fn(weight, model)
        model.load_state_dict(weight)
        result = get_model_acc(model)
        print(name, result)
Output
['swiftformer_l1', 'swiftformer_l3', 'swiftformer_s', 'swiftformer_xs']
swiftformer_l1 {'ACC@1': 80.902, 'ACC@5': 95.378}
swiftformer_l3 {'ACC@1': 83.0, 'ACC@5': 96.238}
swiftformer_s {'ACC@1': 78.466, 'ACC@5': 93.972}
swiftformer_xs {'ACC@1': 75.586, 'ACC@5': 92.326}
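For anyone reproducing these numbers, the only non-obvious preprocessing step in the script is how the resize and crop sizes are derived from the pretrained cfg. The sketch below isolates that step in plain Python so it can be checked without a dataset or a GPU. `resolve_eval_sizes` is a hypothetical helper name, not a timm function, and the cfg values in the usage example are illustrative rather than the actual SwiftFormer settings. Unlike the script, it also treats a `None` test-time override as absent, which is an assumption on my part.

```python
from typing import Any, Dict, Tuple, Union

Size = Union[int, Tuple[int, int]]


def resolve_eval_sizes(cfg: Dict[str, Any]) -> Tuple[Size, Size]:
    """Return (resize_size, crop_size) for a timm-style pretrained cfg dict.

    Hypothetical helper: it only reads the keys the test script reads.
    """
    # Test-time overrides win when present and not None.
    _, height, width = cfg.get("test_input_size") or cfg["input_size"]
    crop_pct = cfg.get("test_crop_pct") or cfg["crop_pct"]
    if height == width:
        # Square input: a single int resizes the shorter image side.
        return int(height / crop_pct), height
    # Non-square input: scale each dimension separately.
    return (int(height / crop_pct), int(width / crop_pct)), (height, width)


if __name__ == "__main__":
    # Illustrative cfg values only.
    cfg = {"input_size": (3, 224, 224), "crop_pct": 0.9}
    print(resolve_eval_sizes(cfg))
```

The first element is what would be passed to `transforms.Resize` and the second to `transforms.CenterCrop`.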
…ckpoint filter fns and minor renames
@brianhou0208 woops, missed those drive links... okay, all sorted. Weights are on the hub, waiting for final checks.
@brianhou0208 thanks, all merged
New Model
Model Request
resolve #2450
Result
Param / MACs / Throughput
NPU Latency
iOS latency is reported for an iPhone 14 Pro Max (iOS 18.5), measured with the benchmark tool from Xcode 16.3.
Android latency is reported for a Samsung Galaxy S24 (Android 14), measured with the benchmark tool from Qualcomm® AI Hub Models.
Measure Android Latency