Segmentation Algorithms

This repository provides a minimal framework to train and evaluate various image segmentation models on datasets in the VOC2012 format. Models include:

  • UNet
  • DeepLabV3+
  • SegNeXt (uses a simple fallback if the segnext package is unavailable)
  • Swin Transformer
  • SegFormer (falls back to a lightweight implementation when the transformers package is missing)
  • PVT (uses a basic model if the pvt package is unavailable)

The training pipeline uses Hugging Face Accelerate for easy single or multi-GPU training and Weights & Biases for optional experiment tracking.

Installation

Install the required packages:

pip install torch torchvision accelerate tqdm wandb segmentation-models-pytorch timm transformers pyyaml

Additional packages (segnext and pvt) are needed to use the full SegNeXt or PVT models; without them the fallback implementations listed above are used.

Configuration

Sample configuration files for each model are stored in segmentation/configs/. Use these as starting points and modify as needed. Each configuration includes an output path specifying where checkpoints for that model will be saved.
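A configuration might look like the sketch below. The key names here are illustrative assumptions, not the repository's actual schema; consult the sample files in segmentation/configs/ for the real keys.

```yaml
# Hypothetical UNet config -- key names are assumptions, not the repo's schema.
model: unet            # which architecture to build
pretrained: true       # whether to load pretrained encoder weights
num_classes: 21        # VOC2012: 20 object classes + background
epochs: 50
batch_size: 8
lr: 0.0001
output: runs/unet      # checkpoints and training logs are written here
```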

Training

python -m segmentation.train --data-dir /path/to/VOC2012 \
    --config segmentation/configs/unet.yaml --wandb

Checkpoints and logs are written to a directory specified by the configuration's output path. Inside this directory the script creates train_logs and checkpoints subfolders. Metrics for every epoch are appended to train_logs/metrics.csv, while the best model (by mIoU) is stored as checkpoints/best.pt. Use --save-every N to control how often regular checkpoints are written.

The configuration files define the model variant, whether to use pretrained weights, and other hyperparameters. Command-line arguments override values from the config file.

Inference

python -m segmentation.inference --config segmentation/configs/unet.yaml \
    --checkpoint checkpoints/model_49.pt \
    --image input.jpg --output pred.png

Metrics

During validation the following metrics are computed:

  • mIoU – mean intersection over union
  • Accuracy – pixel accuracy
  • F1-score – F1-score averaged over all classes (macro average)
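As a rough illustration, all three metrics can be derived from a single confusion matrix. The sketch below is a minimal reference implementation under assumed conventions (no ignore-index masking, macro averaging); the repository's own code may differ in details such as handling of the VOC void label (255).

```python
import numpy as np

def segmentation_metrics(pred, target, num_classes):
    """Compute mIoU, pixel accuracy, and macro F1 from integer label maps.

    Simplified sketch: assumes every pixel is a valid class and does not
    mask an ignore index such as VOC's 255 void label.
    """
    # Confusion matrix: rows = ground truth class, columns = predicted class.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for t, p in zip(target.ravel(), pred.ravel()):
        cm[t, p] += 1

    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp   # predicted as class c, but actually another class
    fn = cm.sum(axis=1) - tp   # actually class c, but predicted as another class

    # Per-class IoU and F1; np.maximum avoids division by zero for absent classes.
    iou = tp / np.maximum(tp + fp + fn, 1)
    f1 = 2 * tp / np.maximum(2 * tp + fp + fn, 1)
    accuracy = tp.sum() / cm.sum()
    return iou.mean(), accuracy, f1.mean()

# Tiny 2x2 example with 2 classes: one pixel is misclassified.
pred = np.array([[0, 0], [1, 1]])
target = np.array([[0, 1], [1, 1]])
miou, acc, f1 = segmentation_metrics(pred, target, num_classes=2)
```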

Dataset

Training and validation loaders expect the standard VOC2012 directory structure. Download the dataset from the official site.
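For reference, the segmentation-relevant parts of the standard VOC2012 layout look like this:

```
VOC2012/
├── JPEGImages/            # input .jpg images
├── SegmentationClass/     # per-pixel class label .png masks
└── ImageSets/
    └── Segmentation/
        ├── train.txt      # image IDs used for training
        └── val.txt        # image IDs used for validation
```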