diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..2218eb7
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,30 @@
+name: CI
+
+on:
+  push:
+    branches: ["main", "master"]
+  pull_request:
+    branches: ["main", "master"]
+
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e .
+          pip install pytest
+
+      - name: Run tests
+        run: pytest
diff --git a/README.md b/README.md
index 0eaf296..0e663ba 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,18 @@ cd vbll
 pip install -e .
 ```
 
+## Testing
+
+Install the test dependencies and run the suite with `pytest`:
+
+```bash
+python -m pip install --upgrade pip
+pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+pip install -e .
+pip install pytest
+pytest
+```
+
 ## Usage and Tutorials
 
 Documentation is available [here](https://vbll.readthedocs.io/en/latest/).
@@ -55,4 +67,3 @@ If you find VBLL useful in your research, please consider citing our [paper](htt
 }
 ```
 
-
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 0000000..1394619
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,714 @@
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+files = [
+    {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+    {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.3.0"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
+    {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
+]
+
+[package.dependencies]
+typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.13\""}
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "filelock"
+version = "3.19.1"
+description = "A platform independent file lock."
+optional = false +python-versions = ">=3.9" +files = [ + {file = "filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d"}, + {file = "filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58"}, +] + +[[package]] +name = "fsspec" +version = "2025.9.0" +description = "File-system specification" +optional = false +python-versions = ">=3.9" +files = [ + {file = "fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7"}, + {file = "fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dev = ["pre-commit", "ruff (>=0.5)"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] +tqdm = ["tqdm"] + +[[package]] +name = "iniconfig" +version = "2.1.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, + {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +description = "A very fast and expressive template engine." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, + {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + +[[package]] +name = "markupsafe" +version = "3.0.3" +description = "Safely add untrusted strings to HTML/XML markup." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "markupsafe-3.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559"}, + {file = "markupsafe-3.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419"}, + {file = "markupsafe-3.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695"}, + {file = "markupsafe-3.0.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591"}, + {file = "markupsafe-3.0.3-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c"}, + {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f"}, + {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6"}, + {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1"}, + {file = "markupsafe-3.0.3-cp310-cp310-win32.whl", hash = "sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa"}, + {file = "markupsafe-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8"}, + {file = "markupsafe-3.0.3-cp310-cp310-win_arm64.whl", hash = "sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1"}, + {file = "markupsafe-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad"}, + {file = "markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a"}, + {file = "markupsafe-3.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50"}, + {file = "markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf"}, + {file = "markupsafe-3.0.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f"}, + {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a"}, + {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115"}, + {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a"}, + {file = "markupsafe-3.0.3-cp311-cp311-win32.whl", hash = "sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19"}, + {file = "markupsafe-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01"}, + {file = "markupsafe-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c"}, + {file = 
"markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e"}, + {file = "markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce"}, + {file = "markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d"}, + {file = "markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d"}, + {file = "markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a"}, + {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b"}, + {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f"}, + {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b"}, + {file = "markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d"}, + {file = "markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c"}, + {file = "markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f"}, + {file = "markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795"}, + {file = "markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219"}, + {file = "markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6"}, + {file = "markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676"}, + {file = "markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9"}, + {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1"}, + {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc"}, + {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12"}, + {file = "markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed"}, + {file = "markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5"}, + {file = "markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485"}, + {file = "markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = 
"sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73"}, + {file = "markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37"}, + {file = "markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19"}, + {file = "markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025"}, + {file = "markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6"}, + {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f"}, + {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb"}, + {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009"}, + {file = "markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354"}, + {file = "markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218"}, + {file = "markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287"}, + {file = "markupsafe-3.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe"}, + {file = "markupsafe-3.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026"}, + {file = "markupsafe-3.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737"}, + {file = "markupsafe-3.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97"}, + {file = "markupsafe-3.0.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d"}, + {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda"}, + {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf"}, + {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe"}, + {file = "markupsafe-3.0.3-cp314-cp314-win32.whl", hash = "sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9"}, + {file = "markupsafe-3.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581"}, + {file = "markupsafe-3.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4"}, + {file = "markupsafe-3.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab"}, + {file = 
"markupsafe-3.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175"}, + {file = "markupsafe-3.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634"}, + {file = "markupsafe-3.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50"}, + {file = "markupsafe-3.0.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e"}, + {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5"}, + {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523"}, + {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc"}, + {file = "markupsafe-3.0.3-cp314-cp314t-win32.whl", hash = "sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d"}, + {file = "markupsafe-3.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9"}, + {file = "markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa"}, + {file = "markupsafe-3.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15d939a21d546304880945ca1ecb8a039db6b4dc49b2c5a400387cdae6a62e26"}, + {file = "markupsafe-3.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f71a396b3bf33ecaa1626c255855702aca4d3d9fea5e051b41ac59a9c1c41edc"}, + {file = "markupsafe-3.0.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f4b68347f8c5eab4a13419215bdfd7f8c9b19f2b25520968adfad23eb0ce60c"}, + {file = "markupsafe-3.0.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8fc20152abba6b83724d7ff268c249fa196d8259ff481f3b1476383f8f24e42"}, + {file = "markupsafe-3.0.3-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:949b8d66bc381ee8b007cd945914c721d9aba8e27f71959d750a46f7c282b20b"}, + {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:3537e01efc9d4dccdf77221fb1cb3b8e1a38d5428920e0657ce299b20324d758"}, + {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:591ae9f2a647529ca990bc681daebdd52c8791ff06c2bfa05b65163e28102ef2"}, + {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a320721ab5a1aba0a233739394eb907f8c8da5c98c9181d1161e77a0c8e36f2d"}, + {file = "markupsafe-3.0.3-cp39-cp39-win32.whl", hash = "sha256:df2449253ef108a379b8b5d6b43f4b1a8e81a061d6537becd5582fba5f9196d7"}, + {file = "markupsafe-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:7c3fb7d25180895632e5d3148dbdc29ea38ccb7fd210aa27acbd1201a1902c6e"}, + {file = "markupsafe-3.0.3-cp39-cp39-win_arm64.whl", hash = "sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8"}, + {file = "markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698"}, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = 
false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "networkx" +version = "3.2.1" +description = "Python package for creating and manipulating graphs and networks" +optional = false +python-versions = ">=3.9" +files = [ + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, +] + +[package.extras] +default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + +[[package]] +name = "numpy" +version = "2.0.2" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-2.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b"}, + {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd"}, + {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318"}, + {file = "numpy-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8"}, + {file = "numpy-2.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326"}, + {file = "numpy-2.0.2-cp310-cp310-win32.whl", hash = "sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97"}, + {file = "numpy-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = 
"sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a"}, + {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669"}, + {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951"}, + {file = "numpy-2.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9"}, + {file = "numpy-2.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15"}, + {file = "numpy-2.0.2-cp311-cp311-win32.whl", hash = "sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4"}, + {file = "numpy-2.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c"}, + {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692"}, + {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a"}, + {file = "numpy-2.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c"}, + {file = "numpy-2.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded"}, + {file = "numpy-2.0.2-cp312-cp312-win32.whl", hash = "sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5"}, + {file = "numpy-2.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729"}, + {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1"}, + {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd"}, + {file = "numpy-2.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d"}, + {file = "numpy-2.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = 
"sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d"}, + {file = "numpy-2.0.2-cp39-cp39-win32.whl", hash = "sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa"}, + {file = "numpy-2.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385"}, + {file = "numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78"}, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.6.4.1" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb"}, + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"}, + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.6.80" +description = "CUDA profiling tools runtime libs." 
+optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.6.77" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13"}, + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"}, + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.6.77" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.5.1.17" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def"}, + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"}, + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.0.4" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.11.1.6" +description = "cuFile GPUDirect libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"}, + {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.7.77" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.1.2" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.4.2" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = 
"nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.3" +description = "NVIDIA cuSPARSELt" +optional = false +python-versions = "*" +files = [ + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1"}, + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"}, + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"}, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.26.2" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"}, + {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.6.85" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.6.77" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"}, +] + +[[package]] +name = "packaging" +version = "25.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"}, + {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + +[[package]] +name = "pytest" +version = "7.4.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, + {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "setuptools" +version = "80.9.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.9" +files = [ + {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, + {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] +core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", 
"pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] + +[[package]] +name = "sympy" +version = "1.14.0" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"}, + {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] + +[[package]] +name = "tomli" +version = "2.2.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", 
hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, +] + +[[package]] +name = "torch" +version = "2.7.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = false +python-versions = ">=3.9.0" +files = [ + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a103b5d782af5bd119b81dbcc7ffc6fa09904c423ff8db397a1e6ea8fd71508f"}, + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:fe955951bdf32d182ee8ead6c3186ad54781492bf03d547d31771a01b3d6fb7d"}, + {file = "torch-2.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:885453d6fba67d9991132143bf7fa06b79b24352f4506fd4d10b309f53454162"}, + {file = "torch-2.7.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d72acfdb86cee2a32c0ce0101606f3758f0d8bb5f8f31e7920dc2809e963aa7c"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:236f501f2e383f1cb861337bdf057712182f910f10aeaf509065d54d339e49b2"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:06eea61f859436622e78dd0cdd51dbc8f8c6d76917a9cf0555a333f9eac31ec1"}, + {file = "torch-2.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:8273145a2e0a3c6f9fd2ac36762d6ee89c26d430e612b95a99885df083b04e52"}, + {file = "torch-2.7.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aea4fc1bf433d12843eb2c6b2204861f43d8364597697074c8d38ae2507f8730"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:27ea1e518df4c9de73af7e8a720770f3628e7f667280bce2be7a16292697e3fa"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c33360cfc2edd976c2633b3b66c769bdcbbf0e0b6550606d188431c81e7dd1fc"}, + {file = "torch-2.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:d8bf6e1856ddd1807e79dc57e54d3335f2b62e6f316ed13ed3ecfe1fc1df3d8b"}, + {file = "torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:787687087412c4bd68d315e39bc1223f08aae1d16a9e9771d95eabbb04ae98fb"}, + {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:03563603d931e70722dce0e11999d53aa80a375a3d78e6b39b9f6805ea0a8d28"}, + {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d632f5417b6980f61404a125b999ca6ebd0b8b4bbdbb5fbbba44374ab619a412"}, + {file = "torch-2.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:23660443e13995ee93e3d844786701ea4ca69f337027b05182f5ba053ce43b38"}, + {file = "torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0da4f4dba9f65d0d203794e619fe7ca3247a55ffdcbd17ae8fb83c8b2dc9b585"}, + {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e08d7e6f21a617fe38eeb46dd2213ded43f27c072e9165dc27300c9ef9570934"}, + {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:30207f672328a42df4f2174b8f426f354b2baa0b7cca3a0adb3d6ab5daf00dc8"}, + {file = "torch-2.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:79042feca1c634aaf6603fe6feea8c6b30dfa140a6bbc0b973e2260c7e79a22e"}, + {file = "torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:988b0cbc4333618a1056d2ebad9eb10089637b659eb645434d0809d8d937b946"}, + {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:e0d81e9a12764b6f3879a866607c8ae93113cbcad57ce01ebde63eb48a576369"}, + {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8394833c44484547ed4a47162318337b88c97acdb3273d85ea06e03ffff44998"}, + {file = "torch-2.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:df41989d9300e6e3c19ec9f56f856187a6ef060c3662fe54f4b6baf1fc90bd19"}, + {file = "torch-2.7.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a737b5edd1c44a5c1ece2e9f3d00df9d1b3fb9541138bee56d83d38293fb6c9d"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu12 = {version = "12.6.4.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.6.80", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "9.5.1.17", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.3.0.4", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufile-cu12 = {version = "1.11.1.6", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.7.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.7.1.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.5.4.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparselt-cu12 = 
{version = "0.6.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.26.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvjitlink-cu12 = {version = "12.6.85", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +sympy = ">=1.13.3" +triton = {version = "3.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = ">=4.10.0" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.13.0)"] + +[[package]] +name = "triton" +version = "3.3.1" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e"}, + {file = "triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b"}, + {file = "triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9999e83aba21e1a78c1f36f21bce621b77bcaa530277a50484a7cb4a822f6e43"}, + {file = "triton-3.3.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b89d846b5a4198317fec27a5d3a609ea96b6d557ff44b56c23176546023c4240"}, + {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, + {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"}, +] + +[package.dependencies] +setuptools = ">=40.8.0" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +description = "Backported and Experimental Type Hints for Python 3.9+" +optional = false +python-versions = ">=3.9" +files = [ + {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, + {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, +] + +[metadata] +lock-version = "2.0" +python-versions = "^3.9" +content-hash = "0ffcd7df83911a51a58be2e339e877eaead9941a81b5358bc98238bf62f536b6" diff --git a/pyproject.toml b/pyproject.toml index 313cad4..7b4a67d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,9 @@ python = "^3.9" torch = ">= 1.6.0" numpy = ">= 1.21.0" +[tool.poetry.group.dev.dependencies] +pytest = "^7.0.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8c20334 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,21 @@ +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes + --durations=10 +markers = + slow: marks tests 
as slow (deselect with '-m "not slow"') + gpu: marks tests that require GPU + integration: marks integration tests + unit: marks unit tests +filterwarnings = + ignore::UserWarning + ignore::DeprecationWarning + ignore::FutureWarning \ No newline at end of file diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..8a19043 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,49 @@ +# VBLL Test Suite + +This directory contains comprehensive tests for the VBLL (Variational Bayesian Last Layers) library. + +## Test Structure + +### Core Test Files + +- **`test_distributions.py`** - Tests for distribution classes (`Normal`, `DenseNormal`, `DenseNormalPrec`, `LowRankNormal`) +- **`test_classification.py`** - Tests for classification layers (`DiscClassification`, `tDiscClassification`, `HetClassification`) +- **`test_regression.py`** - Tests for regression layers (`Regression`, `tRegression`, `HetRegression`) +- **`test_integration.py`** - End-to-end integration tests + +### Configuration Files + +- **`conftest.py`** - Pytest fixtures and configuration +- **`__init__.py`** - Test package initialization + +## Running Tests + +### Basic Test Execution + +```bash +# Run all tests +pytest + +# Run with verbose output +pytest -v + +# Run specific test file +pytest tests/test_distributions.py + +# Run specific test class +pytest tests/test_classification.py::TestDiscClassification + +# Run specific test function +pytest tests/test_distributions.py::TestNormal::test_initialization +``` + +## Adding New Tests + +When adding new tests: + +1. **Follow naming conventions** - Use `test_` prefix for functions +2. **Use appropriate markers** - Mark tests as `@pytest.mark.slow`, `@pytest.mark.gpu`, etc. +3. **Add docstrings** - Explain what the test validates +4. **Use fixtures** - Leverage existing fixtures for common setup +5. **Test edge cases** - Include boundary conditions and error cases +6. **Verify numerical stability** - Check for NaN, inf, and extreme values \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..7998a5a --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Test package for VBLL + diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ba24de0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,142 @@ +""" +Pytest configuration and shared fixtures for VBLL tests. 
+""" +import pytest +import torch +import numpy as np +from typing import Tuple, Dict, Any + + +@pytest.fixture +def device(): + """Get the device for testing (CPU or GPU if available).""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.fixture +def seed(): + """Set random seed for reproducible tests.""" + torch.manual_seed(42) + np.random.seed(42) + return 42 + + +@pytest.fixture +def small_batch(): + """Small batch size for quick tests.""" + return 32 + + +@pytest.fixture +def medium_batch(): + """Medium batch size for more comprehensive tests.""" + return 128 + + +@pytest.fixture +def small_features(): + """Small feature dimension for quick tests.""" + return 10 + + +@pytest.fixture +def medium_features(): + """Medium feature dimension for more comprehensive tests.""" + return 50 + + +@pytest.fixture +def small_outputs(): + """Small output dimension for quick tests.""" + return 5 + + +@pytest.fixture +def medium_outputs(): + """Medium output dimension for more comprehensive tests.""" + return 20 + + +@pytest.fixture +def sample_data_small(small_batch, small_features, small_outputs, device): + """Generate small sample data for testing.""" + x = torch.randn(small_batch, small_features, device=device) + y_class = torch.randint(0, small_outputs, (small_batch,), device=device) + y_reg = torch.randn(small_batch, small_outputs, device=device) + return { + 'x': x, + 'y_classification': y_class, + 'y_regression': y_reg, + 'batch_size': small_batch, + 'input_features': small_features, + 'output_features': small_outputs + } + + +@pytest.fixture +def sample_data_medium(medium_batch, medium_features, medium_outputs, device): + """Generate medium sample data for testing.""" + x = torch.randn(medium_batch, medium_features, device=device) + y_class = torch.randint(0, medium_outputs, (medium_batch,), device=device) + y_reg = torch.randn(medium_batch, medium_outputs, device=device) + return { + 'x': x, + 'y_classification': y_class, + 'y_regression': y_reg, + 'batch_size': medium_batch, + 'input_features': medium_features, + 'output_features': medium_outputs + } + + +@pytest.fixture +def classification_params(): + """Common parameters for classification layers.""" + return { + 'regularization_weight': 0.1, + 'prior_scale': 1.0, + 'wishart_scale': 1.0, + 'dof': 2.0, + 'return_ood': False + } + +@pytest.fixture +def het_classification_params(): + """Common parameters for hetclassification layers.""" + return { + 'regularization_weight': 0.1, + 'prior_scale': 1.0, + 'noise_prior_scale': 2.0, + 'return_ood': False + } + + +@pytest.fixture +def regression_params(): + """Common parameters for regression layers.""" + return { + 'regularization_weight': 0.1, + 'prior_scale': 1.0, + 'wishart_scale': 1e-2, + 'dof': 1.0 + } + +@pytest.fixture +def het_regression_params(): + """Common parameters for heteroscedastic regression layers.""" + return { + 'regularization_weight': 0.1, + 'prior_scale': 1.0, + 'noise_prior_scale': 1e-2, + } + +@pytest.fixture +def parameterizations(): + """Available covariance parameterizations.""" + return ['diagonal', 'dense', 'lowrank'] + + +@pytest.fixture +def softmax_bounds(): + """Available softmax bounds for classification.""" + return ['jensen', 'semimontecarlo', 'reduced_kn'] diff --git a/tests/test_classification.py b/tests/test_classification.py new file mode 100644 index 0000000..7ad9a71 --- /dev/null +++ b/tests/test_classification.py @@ -0,0 +1,321 @@ +""" +Tests for VBLL classification layers. 
+""" +import pytest +import torch +import torch.nn as nn +import numpy as np +from vbll.layers.classification import ( + DiscClassification, tDiscClassification, HetClassification, GenClassification, + gaussian_kl, gamma_kl, expected_gaussian_kl +) + + +class TestDiscClassification: + """Test Discriminative Classification layer.""" + + def test_initialization(self, sample_data_small, classification_params): + """Test DiscClassification initialization.""" + data = sample_data_small + layer = DiscClassification( + in_features=data['input_features'], + out_features=data['output_features'], + **classification_params + ) + + assert layer.W_mean.shape == (data['output_features'], data['input_features']) + assert layer.regularization_weight == classification_params['regularization_weight'] + assert layer.return_ood == classification_params['return_ood'] + + def test_forward_pass(self, sample_data_small, classification_params): + """Test forward pass of DiscClassification.""" + data = sample_data_small + layer = DiscClassification( + in_features=data['input_features'], + out_features=data['output_features'], + **classification_params + ) + + x = data['x'] + result = layer(x) + + # Check return type + assert hasattr(result, 'predictive') + assert hasattr(result, 'train_loss_fn') + assert hasattr(result, 'val_loss_fn') + + # Check predictive distribution + pred_dist = result.predictive + assert isinstance(pred_dist, torch.distributions.Categorical) + assert pred_dist.probs.shape == (data['batch_size'], data['output_features']) + + # Check that probabilities sum to 1 + assert torch.allclose(pred_dist.probs.sum(-1), torch.ones(data['batch_size'])) + + def test_loss_functions(self, sample_data_small, classification_params): + """Test train and validation loss functions.""" + data = sample_data_small + layer = DiscClassification( + in_features=data['input_features'], + out_features=data['output_features'], + **classification_params + ) + + x = data['x'] + y = data['y_classification'] + result = layer(x) + + # Test train loss + train_loss = result.train_loss_fn(y) + assert isinstance(train_loss, torch.Tensor) + assert train_loss.requires_grad + + # Test validation loss + val_loss = result.val_loss_fn(y) + assert isinstance(val_loss, torch.Tensor) + assert val_loss.requires_grad + + def test_ood_scores(self, sample_data_small, classification_params): + """Test OOD score computation.""" + params_with_ood = classification_params.copy() + params_with_ood['return_ood'] = True + + layer = DiscClassification( + in_features=sample_data_small['input_features'], + out_features=sample_data_small['output_features'], + **params_with_ood + ) + + x = sample_data_small['x'] + result = layer(x) + + assert result.ood_scores is not None + ood_scores = result.ood_scores + assert ood_scores.shape == (sample_data_small['batch_size'],) + assert torch.all(ood_scores >= 0) and torch.all(ood_scores <= 1) + + def test_different_parameterizations(self, sample_data_small, classification_params, parameterizations): + """Test different covariance parameterizations.""" + data = sample_data_small + + for param in parameterizations: + if param == 'lowrank': + # Skip lowrank for small features to avoid issues + continue + + layer = DiscClassification( + in_features=data['input_features'], + out_features=data['output_features'], + parameterization=param, + **classification_params + ) + + x = data['x'] + result = layer(x) + + # Should work without errors + assert result.predictive.probs.shape == (data['batch_size'], data['output_features']) 
+
+    def test_softmax_bounds(self, sample_data_small, classification_params):
+        """Test different softmax bounds."""
+        data = sample_data_small
+
+        # Test Jensen bound (default)
+        layer = DiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            softmax_bound='jensen',
+            **classification_params
+        )
+
+        x = data['x']
+        y = data['y_classification']
+        result = layer(x)
+
+        # Should work without errors
+        train_loss = result.train_loss_fn(y)
+        assert isinstance(train_loss, torch.Tensor)
+
+    def test_gradient_flow(self, sample_data_small, classification_params):
+        """Test that gradients flow properly."""
+        data = sample_data_small
+        layer = DiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        x = data['x']
+        y = data['y_classification']
+        result = layer(x)
+
+        train_loss = result.train_loss_fn(y)
+        train_loss.backward()
+
+        # Check that gradients exist for parameters
+        for param in layer.parameters():
+            if param.requires_grad:
+                assert param.grad is not None
+                assert not torch.all(param.grad == 0)
+
+
+class TesttDiscClassification:
+    """Test t-Discriminative Classification layer."""
+
+    def test_initialization(self, sample_data_small, classification_params):
+        """Test tDiscClassification initialization."""
+        data = sample_data_small
+        layer = tDiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        assert layer.W_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.regularization_weight == classification_params['regularization_weight']
+
+    def test_forward_pass(self, sample_data_small, classification_params):
+        """Test forward pass of tDiscClassification."""
+        data = sample_data_small
+        layer = tDiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        x = data['x']
+        result = layer(x)
+
+        # Check return type
+        assert hasattr(result, 'predictive')
+        assert hasattr(result, 'train_loss_fn')
+        assert hasattr(result, 'val_loss_fn')
+
+        # Check predictive distribution
+        pred_dist = result.predictive
+        assert isinstance(pred_dist, torch.distributions.Categorical)
+        assert pred_dist.probs.shape == (data['batch_size'], data['output_features'])
+
+    def test_noise_distribution(self, sample_data_small, classification_params):
+        """Test noise distribution properties."""
+        data = sample_data_small
+        layer = tDiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        noise_dist = layer.noise
+        assert isinstance(noise_dist, torch.distributions.Gamma)
+
+        # Test noise prior
+        noise_prior = layer.noise_prior
+        assert isinstance(noise_prior, torch.distributions.Gamma)
+
+    def test_softmax_bounds(self, sample_data_small, classification_params):
+        """Test different softmax bounds for tDiscClassification."""
+        data = sample_data_small
+
+        # Test reduced_kn bound
+        layer = tDiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            softmax_bound='reduced_kn',
+            **classification_params
+        )
+
+        x = data['x']
+        y = data['y_classification']
+        result = layer(x)
+
+        train_loss = result.train_loss_fn(y)
+        assert isinstance(train_loss, torch.Tensor)
+
+    def test_alpha_parameter(self, sample_data_small, classification_params):
+        """Test alpha parameter for reduced_kn bound."""
+        data = sample_data_small
+        layer = tDiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            softmax_bound='reduced_kn',
+            kn_alpha=0.5,
+            **classification_params
+        )
+
+        assert hasattr(layer, 'alpha')
+        assert layer.alpha.item() == 0.5
+
+
+class TestHetClassification:
+    """Test Heteroscedastic Classification layer."""
+
+    def test_initialization(self, sample_data_small, het_classification_params):
+        """Test HetClassification initialization."""
+        data = sample_data_small
+        layer = HetClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_classification_params
+        )
+
+        assert layer.W_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.M_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.regularization_weight == het_classification_params['regularization_weight']
+
+    def test_forward_pass(self, sample_data_small, het_classification_params):
+        """Test forward pass of HetClassification."""
+        data = sample_data_small
+        layer = HetClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_classification_params
+        )
+
+        x = data['x']
+        result = layer(x)
+
+        # Check return type
+        assert hasattr(result, 'predictive')
+        assert hasattr(result, 'train_loss_fn')
+        assert hasattr(result, 'val_loss_fn')
+
+        # Check predictive distribution
+        pred_dist = result.predictive
+        assert isinstance(pred_dist, torch.distributions.Categorical)
+        assert pred_dist.probs.shape == (data['batch_size'], data['output_features'])
+
+    def test_consistent_variance(self, sample_data_small, het_classification_params):
+        """Test consistent variance sampling."""
+        data = sample_data_small
+        layer = HetClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_classification_params
+        )
+
+        x = data['x']
+
+        # Test with consistent_variance=True
+        result_consistent = layer(x, consistent_variance=True)
+        pred_consistent = result_consistent.predictive
+
+        # Test with consistent_variance=False
+        result_inconsistent = layer(x, consistent_variance=False)
+        pred_inconsistent = result_inconsistent.predictive
+
+        # Both should produce valid probability distributions
+        assert pred_consistent.probs.shape == (data['batch_size'], data['output_features'])
+        assert pred_inconsistent.probs.shape == (data['batch_size'], data['output_features'])
+
+    def test_M_distribution(self, sample_data_small, het_classification_params):
+        """Test M distribution for noise modeling."""
+        data = sample_data_small
+        layer = HetClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_classification_params
+        )
+
+        M_dist = layer.M
+        assert hasattr(M_dist, 'mean')
+        assert hasattr(M_dist, 'scale_tril')
\ No newline at end of file
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
new file mode 100644
index 0000000..252c523
--- /dev/null
+++ b/tests/test_distributions.py
@@ -0,0 +1,58 @@
+"""
+Tests for VBLL distribution classes.
+"""
+import pytest
+import torch
+import numpy as np
+from vbll.utils.distributions import Normal, DenseNormal, DenseNormalPrec, LowRankNormal, get_parameterization
+
+
+class TestNormal:
+    """Test the Normal (diagonal covariance) distribution."""
+
+    def test_initialization(self, device):
+        """Test Normal distribution initialization."""
+        mean = torch.randn(5, 3, device=device)
+        scale = torch.ones(5, 3, device=device)
+        dist = Normal(mean, scale)
+
+        assert torch.allclose(dist.mean, mean)
+        assert torch.allclose(dist.scale, scale)
+        assert torch.allclose(dist.var, scale ** 2)
+
+    def test_properties(self, device):
+        """Test Normal distribution properties."""
+        mean = torch.randn(2, 4, device=device)
+        scale = torch.ones(2, 4, device=device)
+        dist = Normal(mean, scale)
+
+        # Test covariance diagonal
+        expected_var = torch.ones(2, 4, device=device)
+        assert torch.allclose(dist.covariance_diagonal, expected_var)
+
+        # Test trace
+        expected_trace = torch.tensor([4.0, 4.0], device=device)
+        assert torch.allclose(dist.trace_covariance, expected_trace)
+
+        # Test log determinant
+        expected_logdet = torch.tensor([0.0, 0.0], device=device)
+        assert torch.allclose(dist.logdet_covariance, expected_logdet)
+
+    def test_addition(self, device):
+        """Test Normal distribution addition."""
+        mean1 = torch.randn(3, 2, device=device)
+        scale1 = torch.ones(3, 2, device=device)
+        dist1 = Normal(mean1, scale1)
+
+        mean2 = torch.randn(3, 2, device=device)
+        scale2 = torch.ones(3, 2, device=device)
+        dist2 = Normal(mean2, scale2)
+
+        result = dist1 + dist2
+
+        expected_mean = mean1 + mean2
+        expected_var = scale1**2 + scale2**2
+        expected_scale = torch.sqrt(expected_var)
+
+        assert torch.allclose(result.mean, expected_mean)
+        assert torch.allclose(result.scale, expected_scale)
\ No newline at end of file
diff --git a/tests/test_integration.py b/tests/test_integration.py
new file mode 100644
index 0000000..dc992aa
--- /dev/null
+++ b/tests/test_integration.py
@@ -0,0 +1,168 @@
+"""
+Integration tests for VBLL components.
+"""
+import pytest
+import torch
+import torch.nn as nn
+import numpy as np
+from vbll.layers.classification import DiscClassification
+from vbll.layers.regression import Regression, HetRegression
+from vbll.utils.distributions import Normal, DenseNormal
+
+
+class TestEndToEndClassification:
+    """End-to-end tests for classification."""
+
+    def test_simple_classification_training(self, sample_data_medium, classification_params):
+        """Test simple classification training loop."""
+        data = sample_data_medium
+        layer = DiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
+
+        # Training loop
+        for epoch in range(5):
+            optimizer.zero_grad()
+
+            result = layer(data['x'])
+            train_loss = result.train_loss_fn(data['y_classification'])
+
+            train_loss.backward()
+            optimizer.step()
+
+            # Check that the loss stays finite at every step
+            assert torch.isfinite(train_loss)
+
+    def test_classification_uncertainty(self, sample_data_medium, classification_params):
+        """Test classification uncertainty estimation."""
+        data = sample_data_medium
+        classification_params = classification_params.copy()
+        classification_params['return_ood'] = True
+
+        layer = DiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **classification_params
+        )
+
+        # Train for a few epochs
+        optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
+        for _ in range(3):
+            optimizer.zero_grad()
+            result = layer(data['x'])
+            train_loss = result.train_loss_fn(data['y_classification'])
+            train_loss.backward()
+            optimizer.step()
+
+        # Test uncertainty estimation
+        result = layer(data['x'])
+        ood_scores = result.ood_scores
+
+        assert ood_scores.shape == (data['batch_size'],)
+        assert torch.all(ood_scores >= 0) and torch.all(ood_scores <= 1)
+
+        # Test predictive probabilities
+        pred_probs = result.predictive.probs
+        assert pred_probs.shape == (data['batch_size'], data['output_features'])
+        assert torch.allclose(pred_probs.sum(-1), torch.ones(data['batch_size'], device=pred_probs.device))
+
+    def test_different_softmax_bounds(self, sample_data_small, classification_params):
+        """Test different softmax bounds in classification."""
+        data = sample_data_small
+
+        # Test Jensen bound
+        layer_jensen = DiscClassification(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            softmax_bound='jensen',
+            **classification_params
+        )
+
+        result_jensen = layer_jensen(data['x'])
+        loss_jensen = result_jensen.train_loss_fn(data['y_classification'])
+
+        assert torch.isfinite(loss_jensen)
+        assert result_jensen.predictive.probs.shape == (data['batch_size'], data['output_features'])
+
+    def test_classification_with_different_parameterizations(self, sample_data_small, classification_params):
+        """Test classification with different covariance parameterizations."""
+        data = sample_data_small
+
+        for param in ['diagonal', 'dense']:
+            layer = DiscClassification(
+                in_features=data['input_features'],
+                out_features=data['output_features'],
+                parameterization=param,
+                **classification_params
+            )
+
+            result = layer(data['x'])
+            loss = result.train_loss_fn(data['y_classification'])
+
+            assert torch.isfinite(loss)
+            assert result.predictive.probs.shape == (data['batch_size'], data['output_features'])
+
+
+class TestEndToEndRegression:
+    """End-to-end tests for regression."""
+
+    def test_simple_regression_training(self, sample_data_medium, regression_params):
+        """Test simple regression training loop."""
+        data = sample_data_medium
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
+
+        # Training loop
+        for epoch in range(5):
+            optimizer.zero_grad()
+
+            result = layer(data['x'])
+            train_loss = result.train_loss_fn(data['y_regression'])
+
+            train_loss.backward()
+            optimizer.step()
+
+            # Check that loss is finite
+            assert torch.isfinite(train_loss)
+
+    def test_regression_uncertainty(self, sample_data_medium, regression_params):
+        """Test regression uncertainty estimation."""
+        data = sample_data_medium
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        # Train for a few epochs
+        optimizer = torch.optim.Adam(layer.parameters(), lr=0.01)
+        for _ in range(3):
+            optimizer.zero_grad()
+            result = layer(data['x'])
+            train_loss = result.train_loss_fn(data['y_regression'])
+            train_loss.backward()
+            optimizer.step()
+
+        # Test uncertainty estimation
+        pred_dist = layer.predictive(data['x'])
+
+        # Test mean and variance
+        mean = pred_dist.mean
+        variance = pred_dist.scale ** 2
+
+        assert mean.shape == (data['batch_size'], data['output_features'])
+        assert variance.shape == (data['batch_size'], data['output_features'])
+        assert torch.all(variance >= 0)
+
+        # Test sampling
+        samples = pred_dist.sample((100,))
+        assert samples.shape == (100, data['batch_size'], data['output_features'])
\ No newline at end of file
diff --git a/tests/test_regression.py b/tests/test_regression.py
new file mode 100644
index 0000000..7ce5739
--- /dev/null
+++ b/tests/test_regression.py
@@ -0,0 +1,349 @@
+"""
+Tests for VBLL regression layers.
+"""
+import pytest
+import torch
+import torch.nn as nn
+import numpy as np
+from vbll.layers.regression import (
+    Regression, tRegression, HetRegression,
+    gaussian_kl, gamma_kl, expected_gaussian_kl
+)
+
+
+class TestRegression:
+    """Test standard VBLL Regression layer."""
+
+    def test_initialization(self, sample_data_small, regression_params):
+        """Test Regression initialization."""
+        data = sample_data_small
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        assert layer.regularization_weight == regression_params['regularization_weight']
+        assert layer.W_mean.shape == (data['output_features'], data['input_features'])
+
+    def test_forward_pass(self, sample_data_small, regression_params):
+        """Test forward pass of Regression."""
+        data = sample_data_small
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        result = layer(x)
+
+        # Check return type
+        assert hasattr(result, 'predictive')
+        assert hasattr(result, 'train_loss_fn')
+        assert hasattr(result, 'val_loss_fn')
+
+        # Check predictive distribution
+        pred_dist = result.predictive
+        assert hasattr(pred_dist, 'mean')
+        assert hasattr(pred_dist, 'scale')
+        assert pred_dist.mean.shape == (data['batch_size'], data['output_features'])
+        assert pred_dist.scale.shape == (data['batch_size'], data['output_features'])
+
+    def test_loss_functions(self, sample_data_small, regression_params):
+        """Test train and validation loss functions."""
+        data = sample_data_small
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        y = data['y_regression']
+        result = layer(x)
+
+        # Test train loss
+        train_loss = result.train_loss_fn(y)
+        assert isinstance(train_loss, torch.Tensor)
+        assert train_loss.requires_grad
+
+        # Test validation loss
+        val_loss = result.val_loss_fn(y)
+        assert isinstance(val_loss, torch.Tensor)
+        assert val_loss.requires_grad
+
+    def test_different_parameterizations(self, sample_data_small, regression_params, parameterizations):
+        """Test different covariance parameterizations."""
+        data = sample_data_small
+
+        for param in parameterizations:
+            if param == 'lowrank':
+                # Skip lowrank for small features to avoid issues
+                continue
+
+            layer = Regression(
+                in_features=data['input_features'],
+                out_features=data['output_features'],
+                parameterization=param,
+                **regression_params
+            )
+
+            x = data['x']
+            result = layer(x)
+
+            # Should work without errors
+            assert result.predictive.mean.shape == (data['batch_size'], data['output_features'])
+
+    def test_predictive_distribution(self, sample_data_small, regression_params):
+        """Test predictive distribution properties."""
+        data = sample_data_small
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        pred_dist = layer.predictive(x)
+
+        # Test sampling
+        samples = pred_dist.sample((10,))
+        assert samples.shape == (10, data['batch_size'], data['output_features'])
+
+        # Test log probability
+        y = data['y_regression']
+        log_prob = pred_dist.log_prob(y)
+        assert log_prob.shape == (data['batch_size'], data['output_features'])
+
+    def test_gradient_flow(self, sample_data_small, regression_params):
+        """Test that gradients flow properly."""
+        data = sample_data_small
+        layer = Regression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        y = data['y_regression']
+        result = layer(x)
+
+        train_loss = result.train_loss_fn(y)
+        train_loss.backward()
+
+        # Check that gradients exist for parameters
+        for param in layer.parameters():
+            if param.requires_grad:
+                assert param.grad is not None
+                assert not torch.all(param.grad == 0)
+
+
+class TesttRegression:
+    """Test t-Distribution Regression layer."""
+
+    def test_initialization(self, sample_data_small, regression_params):
+        """Test tRegression initialization."""
+        data = sample_data_small
+        layer = tRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        assert layer.W_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.regularization_weight == regression_params['regularization_weight']
+
+    def test_forward_pass(self, sample_data_small, regression_params):
+        """Test forward pass of tRegression."""
+        data = sample_data_small
+        layer = tRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        result = layer(x)
+
+        # Check return type
+        assert hasattr(result, 'predictive')
+        assert hasattr(result, 'train_loss_fn')
+        assert hasattr(result, 'val_loss_fn')
+
+        # Check predictive distribution
+        pred_dist = result.predictive
+        assert isinstance(pred_dist, torch.distributions.StudentT)
+        assert pred_dist.loc.shape == (data['batch_size'], data['output_features'])
+        assert pred_dist.scale.shape == (data['batch_size'], data['output_features'])
+
+    def test_noise_distribution(self, sample_data_small, regression_params):
+        """Test noise distribution properties."""
+        data = sample_data_small
+        layer = tRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        noise_dist = layer.noise
+        assert isinstance(noise_dist, torch.distributions.Gamma)
+
+        # Test noise prior
+        noise_prior = layer.noise_prior
+        assert isinstance(noise_prior, torch.distributions.Gamma)
+
+    def test_predictive_distribution(self, sample_data_small, regression_params):
+        """Test predictive distribution properties."""
+        data = sample_data_small
+        layer = tRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **regression_params
+        )
+
+        x = data['x']
+        pred_dist = layer.predictive(x)
+
+        # Test sampling
+        samples = pred_dist.sample((10,))
+        assert samples.shape == (10, data['batch_size'], data['output_features'])
+
+        # Test log probability
+        y = data['y_regression']
+        log_prob = pred_dist.log_prob(y)
+        assert log_prob.shape == (data['batch_size'], data['output_features'])
+
+    def test_different_parameterizations(self, sample_data_small, regression_params):
+        """Test different covariance parameterizations for tRegression."""
+        data = sample_data_small
+
+        for param in ['diagonal', 'dense']:
+            layer = tRegression(
+                in_features=data['input_features'],
+                out_features=data['output_features'],
+                parameterization=param,
+                **regression_params
+            )
+
+            x = data['x']
+            result = layer(x)
+
+            # Should work without errors
+            assert result.predictive.loc.shape == (data['batch_size'], data['output_features'])
+
+
+class TestHetRegression:
+    """Test Heteroscedastic Regression layer."""
+
+    def test_initialization(self, sample_data_small, het_regression_params):
+        """Test HetRegression initialization."""
+        data = sample_data_small
+        layer = HetRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_regression_params
+        )
+
+        assert layer.W_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.M_mean.shape == (data['output_features'], data['input_features'])
+        assert layer.regularization_weight == het_regression_params['regularization_weight']
+
+    def test_forward_pass(self, sample_data_small, het_regression_params):
+        """Test forward pass of HetRegression."""
+        data = sample_data_small
+        layer = HetRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_regression_params
+        )
+
+        x = data['x']
+        result = layer(x)
+
+        # Check return type
+        assert hasattr(result, 'predictive')
+        assert hasattr(result, 'train_loss_fn')
+        assert hasattr(result, 'val_loss_fn')
+
+        # Check predictive distribution
+        pred_dist = result.predictive
+        assert hasattr(pred_dist, 'mean')
+        assert hasattr(pred_dist, 'scale')
+        assert pred_dist.mean.shape == (data['batch_size'], data['output_features'])
+        assert pred_dist.scale.shape == (data['batch_size'], data['output_features'])
+
+    def test_consistent_variance(self, sample_data_small, het_regression_params):
+        """Test consistent variance sampling."""
+        data = sample_data_small
+        layer = HetRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_regression_params
+        )
+
+        x = data['x']
+
+        # Test with consistent_variance=True
+        result_consistent = layer(x, consistent_variance=True)
+        pred_consistent = result_consistent.predictive
+
+        # Test with consistent_variance=False
+        result_inconsistent = layer(x, consistent_variance=False)
+        pred_inconsistent = result_inconsistent.predictive
+
+        # Both should produce valid distributions
+        assert pred_consistent.mean.shape == (data['batch_size'], data['output_features'])
+        assert pred_inconsistent.mean.shape == (data['batch_size'], data['output_features'])
+
+    def test_M_distribution(self, sample_data_small, het_regression_params):
+        """Test M distribution for noise modeling."""
+        data = sample_data_small
+        layer = HetRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_regression_params
+        )
+
+        M_dist = layer.M
+        assert hasattr(M_dist, 'mean')
+        assert hasattr(M_dist, 'scale_tril')
+
+    def test_predictive_sample(self, sample_data_small, het_regression_params):
+        """Test predictive sampling."""
+        data = sample_data_small
+        layer = HetRegression(
+            in_features=data['input_features'],
+            out_features=data['output_features'],
+            **het_regression_params
+        )
+
+        x = data['x']
+
+        # Test consistent variance sampling
+        pred_consistent = layer.predictive_sample(x, consistent_variance=True)
+        assert pred_consistent.mean.shape == (data['batch_size'], data['output_features'])
+        assert pred_consistent.scale.shape == (data['batch_size'], data['output_features'])
+
+        # Test inconsistent variance sampling
+        pred_inconsistent = layer.predictive_sample(x, consistent_variance=False)
+        assert pred_inconsistent.mean.shape == (data['batch_size'], data['output_features'])
+        assert pred_inconsistent.scale.shape == (data['batch_size'], data['output_features'])
+
+    def test_different_parameterizations(self, sample_data_small, het_regression_params):
+        """Test different covariance parameterizations for HetRegression."""
+        data = sample_data_small
+
+        for param in ['diagonal', 'dense']:
+            layer = HetRegression(
+                in_features=data['input_features'],
+                out_features=data['output_features'],
+                parameterization=param,
+                **het_regression_params
+            )
+
+            x = data['x']
+            result = layer(x)
+
+            # Should work without errors
+            assert result.predictive.mean.shape == (data['batch_size'], data['output_features'])
\ No newline at end of file
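
As a companion to the "Adding New Tests" checklist in `tests/README.md`, the sketch below shows what a minimal new test file following those conventions could look like. It is illustrative only: the file name `tests/test_example.py` is hypothetical, and the test assumes the `sample_data_small` and `regression_params` fixtures from `tests/conftest.py` and the `Regression` layer exercised in `tests/test_regression.py`.

```python
"""Hypothetical tests/test_example.py: a minimal template for new tests."""
import pytest
import torch

from vbll.layers.regression import Regression


@pytest.mark.unit
def test_regression_mean_is_finite(sample_data_small, regression_params):
    """Follows the README checklist: `test_` prefix, a registered marker,
    a docstring, shared fixtures, and a numerical-stability check."""
    data = sample_data_small
    layer = Regression(
        in_features=data['input_features'],
        out_features=data['output_features'],
        **regression_params
    )

    pred = layer(data['x']).predictive

    # Numerical stability: the predictive mean should contain no NaN/inf.
    assert torch.isfinite(pred.mean).all()
```

Run it alone with `pytest tests/test_example.py`, or select by marker with `pytest -m unit`.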