30 changes: 30 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,30 @@
name: CI

on:
push:
branches: ["main", "master"]
pull_request:
branches: ["main", "master"]

jobs:
tests:
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .
pip install pytest

- name: Run tests
run: pytest
13 changes: 12 additions & 1 deletion README.md
@@ -32,6 +32,18 @@ cd vbll
pip install -e .
```

## Testing

Install the test dependencies and run the suite with `pytest`:

```bash
python -m pip install --upgrade pip
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .
pip install pytest
pytest
```

## Usage and Tutorials
Documentation is available [here](https://vbll.readthedocs.io/en/latest/).

@@ -55,4 +67,3 @@ If you find VBLL useful in your research, please consider citing our [paper](htt
}
```


714 changes: 714 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -10,6 +10,9 @@ python = "^3.9"
torch = ">= 1.6.0"
numpy = ">= 1.21.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
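
Since `pytest` now lives in a Poetry dev-dependency group, the suite can also be installed and run through Poetry rather than plain pip. A minimal sketch, assuming Poetry 1.2 or newer (dependency groups are not available in older versions):

```bash
# Install the package together with the dev group declared in pyproject.toml
poetry install --with dev

# Run the test suite inside the Poetry-managed environment
poetry run pytest
```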
21 changes: 21 additions & 0 deletions pytest.ini
@@ -0,0 +1,21 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
-v
--tb=short
--strict-markers
--disable-warnings
--color=yes
--durations=10
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
gpu: marks tests that require GPU
integration: marks integration tests
unit: marks unit tests
filterwarnings =
ignore::UserWarning
ignore::DeprecationWarning
ignore::FutureWarning
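
The markers registered above can be used to select or skip whole groups of tests from the command line, for example:

```bash
# Skip tests marked as slow
pytest -m "not slow"

# Run only unit tests that do not need a GPU
pytest -m "unit and not gpu"
```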
49 changes: 49 additions & 0 deletions tests/README.md
@@ -0,0 +1,49 @@
# VBLL Test Suite

This directory contains comprehensive tests for the VBLL (Variational Bayesian Last Layers) library.

## Test Structure

### Core Test Files

- **`test_distributions.py`** - Tests for distribution classes (`Normal`, `DenseNormal`, `DenseNormalPrec`, `LowRankNormal`)
- **`test_classification.py`** - Tests for classification layers (`DiscClassification`, `tDiscClassification`, `HetClassification`)
- **`test_regression.py`** - Tests for regression layers (`Regression`, `tRegression`, `HetRegression`)
- **`test_integration.py`** - End-to-end integration tests

### Configuration Files

- **`conftest.py`** - Pytest fixtures and configuration
- **`__init__.py`** - Test package initialization

## Running Tests

### Basic Test Execution

```bash
# Run all tests
pytest

# Run with verbose output
pytest -v

# Run specific test file
pytest tests/test_distributions.py

# Run specific test class
pytest tests/test_classification.py::TestDiscClassification

# Run specific test function
pytest tests/test_distributions.py::TestNormal::test_initialization
```

## Adding New Tests

When adding new tests (a minimal sketch follows this list):

1. **Follow naming conventions** - Use `test_` prefix for functions
2. **Use appropriate markers** - Mark tests as `@pytest.mark.slow`, `@pytest.mark.gpu`, etc.
3. **Add docstrings** - Explain what the test validates
4. **Use fixtures** - Leverage existing fixtures for common setup
5. **Test edge cases** - Include boundary conditions and error cases
6. **Verify numerical stability** - Check for NaN, inf, and extreme values
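
For illustration, a minimal sketch of a new test module (say, `tests/test_example.py`, a hypothetical file) following these conventions, reusing the `sample_data_small` and `device` fixtures from `conftest.py` and the markers registered in `pytest.ini`:

```python
import pytest
import torch


@pytest.mark.unit
def test_sample_data_is_finite(sample_data_small):
    """Sanity-check the shared fixture: inputs and targets contain no NaN or inf."""
    x = sample_data_small['x']
    y = sample_data_small['y_regression']
    assert torch.isfinite(x).all()
    assert torch.isfinite(y).all()
    assert x.shape == (sample_data_small['batch_size'],
                       sample_data_small['input_features'])


@pytest.mark.slow
def test_large_batch_is_finite(device):
    """Marked slow so it can be deselected with `pytest -m "not slow"`."""
    x = torch.randn(10_000, 50, device=device)
    assert torch.isfinite(x).all()
```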
2 changes: 2 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,2 @@
# Test package for VBLL

142 changes: 142 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,142 @@
"""
Pytest configuration and shared fixtures for VBLL tests.
"""
import pytest
import torch
import numpy as np
from typing import Tuple, Dict, Any


@pytest.fixture
def device():
"""Get the device for testing (CPU or GPU if available)."""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def seed():
"""Set random seed for reproducible tests."""
torch.manual_seed(42)
np.random.seed(42)
return 42


@pytest.fixture
def small_batch():
"""Small batch size for quick tests."""
return 32


@pytest.fixture
def medium_batch():
"""Medium batch size for more comprehensive tests."""
return 128


@pytest.fixture
def small_features():
"""Small feature dimension for quick tests."""
return 10


@pytest.fixture
def medium_features():
"""Medium feature dimension for more comprehensive tests."""
return 50


@pytest.fixture
def small_outputs():
"""Small output dimension for quick tests."""
return 5


@pytest.fixture
def medium_outputs():
"""Medium output dimension for more comprehensive tests."""
return 20


@pytest.fixture
def sample_data_small(small_batch, small_features, small_outputs, device):
"""Generate small sample data for testing."""
x = torch.randn(small_batch, small_features, device=device)
y_class = torch.randint(0, small_outputs, (small_batch,), device=device)
y_reg = torch.randn(small_batch, small_outputs, device=device)
return {
'x': x,
'y_classification': y_class,
'y_regression': y_reg,
'batch_size': small_batch,
'input_features': small_features,
'output_features': small_outputs
}


@pytest.fixture
def sample_data_medium(medium_batch, medium_features, medium_outputs, device):
"""Generate medium sample data for testing."""
x = torch.randn(medium_batch, medium_features, device=device)
y_class = torch.randint(0, medium_outputs, (medium_batch,), device=device)
y_reg = torch.randn(medium_batch, medium_outputs, device=device)
return {
'x': x,
'y_classification': y_class,
'y_regression': y_reg,
'batch_size': medium_batch,
'input_features': medium_features,
'output_features': medium_outputs
}


@pytest.fixture
def classification_params():
"""Common parameters for classification layers."""
return {
'regularization_weight': 0.1,
'prior_scale': 1.0,
'wishart_scale': 1.0,
'dof': 2.0,
'return_ood': False
}

@pytest.fixture
def het_classification_params():
"""Common parameters for hetclassification layers."""
return {
'regularization_weight': 0.1,
'prior_scale': 1.0,
'noise_prior_scale': 2.0,
'return_ood': False
}


@pytest.fixture
def regression_params():
"""Common parameters for regression layers."""
return {
'regularization_weight': 0.1,
'prior_scale': 1.0,
'wishart_scale': 1e-2,
'dof': 1.0
}

@pytest.fixture
def het_regression_params():
"""Common parameters for heteroscedastic regression layers."""
return {
'regularization_weight': 0.1,
'prior_scale': 1.0,
'noise_prior_scale': 1e-2,
}

@pytest.fixture
def parameterizations():
"""Available covariance parameterizations."""
return ['diagonal', 'dense', 'lowrank']


@pytest.fixture
def softmax_bounds():
"""Available softmax bounds for classification."""
return ['jensen', 'semimontecarlo', 'reduced_kn']