Skip to content

Support Complex Numbers #3330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

prakash-shekhar
Copy link

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

As suggested by #3265

Changes

Core

  • Complex Tensor Kind: Added Complex as a new TensorKind
  • Backend Trait Extensions: Extended Backend trait with ComplexTensorPrimitive and ComplexElem types
  • Element Types: Added Complex32 and Complex64 element types
  • Data Support: Extended TensorData to support complex numbers

ComplexTensorOps Trait

  • Basic Operations: add, sub, mul, div, neg
  • Complex-Specific: conj (conjugate), real, imag, abs (magnitude), arg (phase)
  • Construction: from_parts (real/imag), from_polar (magnitude/phase)
  • Transcendental: exp, log, sin, cos, tan, sqrt, powc
  • Tensor Ops: reshape, transpose, creation functions

Fully Implemented:

  • burn-ndarray

Partially Implemented:

  • burn-autodiff: most operations implemented, but 6 conversion functions stubbed with todo!()

Stub Implementations:

  • burn-tch
  • burn-candle
  • burn-cubecl
  • burn-fusion
  • burn-router

Testing

cargo test complex

should pass 62 tests

- Add Complex as first-class TensorKind alongside Float, Int, Bool
- Add ComplexTensorPrimitive and ComplexElem to Backend trait
- Add complex tensor type aliases and exports
- Add ComplexTensorPrimitive support to NdArray backend
- Implement complex arithmetic and transcendental functions in NdArray backend
- Add autodiff backend wrapper for complex tensors
- Begin enabling support across backend ecosystem
- Add high-level Tensor<B, D, Complex> API with BasicOps and Numeric traits
- Add complex-specific methods: conj(), real(), imag(), magnitude(), phase()
- Add creation utilities: from_parts(), from_polar(), zeros(), ones()
- Start adding test suite covering operations
- Remove non-existent testgen_complex\!() macro call that was causing compilation errors
- Add ComplexTensorOps implementations for all backends (tch, candle, cubecl, fusion, router)
- Fix complex tensor assertion logic in CubeCL backend to avoid Float trait requirements
- Add missing transcendental functions (exp, log, sin, cos, tan, sqrt, powc) to Complex tensor API
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late response! Just a general comment:

I think you're on the right track given the separation for complex tensor ops. But since complex numbers are not just another dtype, as they introduce different constraints and rules, I think it would be best to not overload the core Backend trait.

Instead, it can be implemented as a backend extension (e.g., ComplexBackend). That way, not all backends have this additional restriction. It can live entirely outside of the core backend and tensor stuff.

Copy link

codecov bot commented Jul 7, 2025

Codecov Report

Attention: Patch coverage is 0.26762% with 1118 lines in your changes missing coverage. Please review.

Project coverage is 45.66%. Comparing base (81985bd) to head (a2fe37e).
Report is 34 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-tensor/src/tensor/api/complex.rs 0.00% 321 Missing ⚠️
crates/burn-ndarray/src/ops/complex_tensor.rs 0.00% 269 Missing ⚠️
crates/burn-fusion/src/ops/complex.rs 0.00% 98 Missing ⚠️
crates/burn-tensor/src/tensor/element/base.rs 0.00% 93 Missing ⚠️
crates/burn-autodiff/src/ops/complex_tensor.rs 0.00% 91 Missing ⚠️
crates/burn-cubecl/src/ops/complex_ops.rs 0.00% 66 Missing ⚠️
crates/burn-tch/src/ops/complex_tensor.rs 0.00% 66 Missing ⚠️
crates/burn-tensor/src/tensor/element/cast.rs 0.00% 48 Missing ⚠️
crates/burn-tensor/src/tensor/data.rs 10.34% 26 Missing ⚠️
...rates/burn-tensor/src/tensor/ops/complex_tensor.rs 0.00% 13 Missing ⚠️
... and 7 more

❌ Your patch check has failed because the patch coverage (0.26%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (45.66%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #3330       +/-   ##
===========================================
- Coverage   82.66%   45.66%   -37.00%     
===========================================
  Files         995      420      -575     
  Lines      127626    61936    -65690     
===========================================
- Hits       105498    28282    -77216     
- Misses      22128    33654    +11526     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants