ZJLAB-AMMI/PSWA

PSWA

This repository contains a PyTorch implementation of methods described in the paper "Stochastic Weight Averaging Revisited" by Hao Guo, Jiyong Jin, and Bin Liu (https://arxiv.org/abs/2201.00519).

Dependencies

PyTorch 1.9.0

Usage

Image Classification

The code in this sub-folder implements periodic Stochastic Weight Averaging (PSWA), conventional SGD, Double Stochastic Weight Averaging (DSWA), and Triple Stochastic Weight Averaging (TSWA) on the CIFAR-10 and CIFAR-100 datasets.

PSWA Training

You can train a DNN with PSWA using the following command:

python3 pswa.py \
        --dir=<DIR> \
        --dataset=<DATASET> \
        --data_path=<PATH> \
        --model=<MODEL> \
        --epochs=<EPOCHS> \
        --lr_init=<LR> \
        --wd=<WD> \
        --pswa \
        --pswa_start=<PSWA_START> \
        --P=<P> 

Parameters:
DIR — path to training directory where checkpoints will be stored
DATASET — dataset name (default: CIFAR10)
PATH — path to the data directory
MODEL — DNN model name (VGG16, PreResNet164, or WideResNet28x10)
EPOCHS — number of training epochs
LR — initial learning rate
WD — weight decay
PSWA_START — epoch after which PSWA starts averaging models (default: 40)
P — model recording period (default: 20)
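
To make the roles of PSWA_START and P more concrete, here is a minimal, framework-free sketch of periodic weight averaging. It is not the code in pswa.py. It assumes that averaging begins at epoch index PSWA_START (0-indexed), that one weight snapshot is recorded per epoch, and that after every P recorded snapshots the running average replaces the current weights before SGD continues. The function name `periodic_average` and the `sgd_epoch` callback are hypothetical, and weights are plain lists of floats rather than PyTorch tensors.

```python
def periodic_average(weights, sgd_epoch, epochs, pswa_start, period):
    """Sketch of periodic weight averaging (assumed semantics, see lead-in).

    weights    -- list of floats standing in for the model parameters
    sgd_epoch  -- hypothetical callback: takes weights, returns weights after one epoch
    Returns (final_weights, last_running_average).
    """
    avg, n = None, 0
    for epoch in range(epochs):
        weights = sgd_epoch(weights)
        if epoch < pswa_start:
            continue  # plain SGD until averaging starts
        # Fold the epoch-end weights into the running mean.
        if n == 0:
            avg = list(weights)
        else:
            avg = [(a * n + w) / (n + 1) for a, w in zip(avg, weights)]
        n += 1
        if n == period:
            # End of a period: continue training from the averaged weights.
            weights = list(avg)
            n = 0
    return weights, avg


if __name__ == "__main__":
    # Toy "training" that adds 1 to every weight each epoch.
    final, avg = periodic_average([0.0], lambda w: [x + 1 for x in w],
                                  epochs=4, pswa_start=0, period=2)
    print(final, avg)
```

With real models the same running-mean update would be applied to each parameter tensor instead of each float.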

Example

VGG16:

CIFAR10
python3 pswa.py --dir=<DIR> --data_path=<PATH> --dataset=CIFAR10 --model=VGG16 --epochs=160 --lr_init=0.05 \
                --wd=5e-4 --pswa --pswa_start=40 --P=20
CIFAR100
python3 pswa.py --dir=<DIR> --data_path=<PATH> --dataset=CIFAR100 --model=VGG16 --epochs=160 --lr_init=0.05 \
                --wd=5e-4 --pswa --pswa_start=40 --P=20
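
When running the same configuration on several datasets, it can be convenient to build the command line programmatically. The sketch below only assembles the flags shown above. The helper name `build_pswa_command` and the `runs/...` and `data` paths are hypothetical placeholders, not part of this repository.

```python
import shlex


def build_pswa_command(out_dir, data_path, dataset="CIFAR10", model="VGG16",
                       epochs=160, lr_init=0.05, wd=5e-4,
                       pswa_start=40, period=20):
    """Return the argv list for one pswa.py run using the flags from this README."""
    return [
        "python3", "pswa.py",
        f"--dir={out_dir}",
        f"--data_path={data_path}",
        f"--dataset={dataset}",
        f"--model={model}",
        f"--epochs={epochs}",
        f"--lr_init={lr_init}",
        f"--wd={wd}",
        "--pswa",
        f"--pswa_start={pswa_start}",
        f"--P={period}",
    ]


if __name__ == "__main__":
    for dataset in ("CIFAR10", "CIFAR100"):
        cmd = build_pswa_command(f"runs/{dataset}", "data", dataset=dataset)
        # Print the command; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```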

Backbone SGD Training

You can run conventional SGD training using the following command:

python3 backbone-SGD.py \
        --dir=<DIR> \
        --dataset=<DATASET> \
        --data_path=<PATH> \
        --model=<MODEL> \
        --epochs=<EPOCHS> \
        --momentum=<MOMENTUM> \
        --lr_init=<LR> \
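        --wd=<WD>

Parameters are the same as for PSWA training; MOMENTUM is the SGD momentum coefficient (0.9 and 0 in the examples below).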

SGD with momentum and weight decay

VGG16 CIFAR10
python3 backbone-SGD.py --dir=<DIR> --data_path=<PATH> --dataset=CIFAR10 --model=VGG16 --epochs=160 --lr_init=0.05 \
                        --momentum=0.9 --wd=5e-4

SGD without momentum and weight decay

VGG16 CIFAR10
python3 backbone-SGD.py --dir=<DIR> --data_path=<PATH> --dataset=CIFAR10 --model=VGG16 --epochs=160 --lr_init=0.05 \
                        --momentum=0 --wd=0
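
For reference, the difference between the two settings can be seen in a single scalar update step. The sketch below follows the standard momentum-SGD rule with L2 weight decay (weight decay added to the gradient, no dampening, no Nesterov), which is an assumption about what backbone-SGD.py configures. The function name `sgd_step` is hypothetical.

```python
def sgd_step(w, grad, velocity, lr, momentum=0.0, wd=0.0):
    """One SGD update on a scalar weight.

    Weight decay is applied as an L2 term added to the gradient,
    then the momentum buffer is updated and the weight is moved.
    Returns (new_weight, new_velocity).
    """
    g = grad + wd * w                 # gradient plus L2 weight decay
    velocity = momentum * velocity + g
    return w - lr * velocity, velocity


if __name__ == "__main__":
    # Without momentum and weight decay: plain gradient step.
    print(sgd_step(1.0, 0.5, 0.0, lr=0.05))
    # With momentum=0.9 and wd=5e-4, as in the first example above.
    print(sgd_step(1.0, 0.5, 0.0, lr=0.05, momentum=0.9, wd=5e-4))
```

With `momentum=0` and `wd=0` the update reduces to `w - lr * grad`.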

DSWA Training

VGG16 CIFAR10
python3 dswa.py --dir=<DIR> --data_path=<PATH> --resume=<MODEL PATH> --dataset=CIFAR10 --model=VGG16 --epochs=40 \
                --wd=5e-4 --swa --dswa --swa_start=0 --dswa_start=20

TSWA Training

VGG16 CIFAR10
python3 tswa.py --dir=<DIR> --data_path=<PATH> --resume=<MODEL PATH> --dataset=CIFAR10 --model=VGG16 --epochs=60 \
                --wd=5e-4 --swa --dswa --tswa --swa_start=0 --dswa_start=20 --tswa_start=40
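
The start flags in the DSWA and TSWA commands split the run into consecutive averaging stages. The helper below is only an illustration of that split. It assumes 0-indexed epochs and inclusive start boundaries, which this README does not state, and the function name `averaging_stage` is hypothetical.

```python
def averaging_stage(epoch, swa_start=0, dswa_start=20, tswa_start=40):
    """Return which averaging stage an epoch falls into (assumed boundaries).

    Defaults mirror the TSWA example: --swa_start=0 --dswa_start=20 --tswa_start=40.
    Returns None for epochs before swa_start.
    """
    if epoch >= tswa_start:
        return "tswa"
    if epoch >= dswa_start:
        return "dswa"
    if epoch >= swa_start:
        return "swa"
    return None


if __name__ == "__main__":
    # With --epochs=60 this gives three stages of 20 epochs each.
    for epoch in (0, 19, 20, 39, 40, 59):
        print(epoch, averaging_stage(epoch))
```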

Graph Classification

python3 gin_swa.py --dir=<DIR> --data_path=<PATH> --dataset=MUTAG --epochs=300 --lr_init=0.01 --swa_start=270

Text Classification

python3 text-classification.py --dir=<DIR> --task_name=mrpc --model_name_or_path=<PATH> --epochs=50 \
                               --learning_rate=1e-4 --weight_decay=1e-2 --swa_start=45

Reference

Provided model implementations were adapted from
