Skip to content

WIP: ENH: enable complex dtype support unconditionally #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np

from ._c99_config import _have_c99_complex
from ._extensions._dwt import downcoef as _downcoef
from ._extensions._dwt import dwt_axis, dwt_single, idwt_axis, idwt_single
from ._extensions._dwt import dwt_coeff_len as _dwt_coeff_len
Expand Down Expand Up @@ -161,12 +160,6 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
array([-0.70710678, -0.70710678, -0.70710678])

"""
if not _have_c99_complex and np.iscomplexobj(data):
data = np.asarray(data)
cA_r, cD_r = dwt(data.real, wavelet, mode, axis)
cA_i, cD_i = dwt(data.imag, wavelet, mode, axis)
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)

# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.asarray(data, dtype=dt, order='C')
Expand Down Expand Up @@ -241,17 +234,6 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
raise ValueError("At least one coefficient parameter must be "
"specified.")

# for complex inputs: compute real and imaginary separately then combine
if not _have_c99_complex and (np.iscomplexobj(cA) or np.iscomplexobj(cD)):
if cA is None:
cD = np.asarray(cD)
cA = np.zeros_like(cD)
elif cD is None:
cA = np.asarray(cA)
cD = np.zeros_like(cA)
return (idwt(cA.real, cD.real, wavelet, mode, axis) +
1j*idwt(cA.imag, cD.imag, wavelet, mode, axis))

if cA is not None:
dt = _check_dtype(cA)
cA = np.asarray(cA, dtype=dt, order='C')
Expand Down Expand Up @@ -328,9 +310,6 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1):
upcoef

"""
if not _have_c99_complex and np.iscomplexobj(data):
return (downcoef(part, data.real, wavelet, mode, level) +
1j*downcoef(part, data.imag, wavelet, mode, level))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.asarray(data, dtype=dt, order='C')
Expand Down Expand Up @@ -387,9 +366,6 @@ def upcoef(part, coeffs, wavelet, level=1, take=0):
array([ 1., 2., 3., 4., 5., 6.])

"""
if not _have_c99_complex and np.iscomplexobj(coeffs):
return (upcoef(part, coeffs.real, wavelet, level, take) +
1j*upcoef(part, coeffs.imag, wavelet, level, take))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(coeffs)
coeffs = np.asarray(coeffs, dtype=dt, order='C')
Expand Down
293 changes: 146 additions & 147 deletions pywt/_extensions/_dwt.pyx

Large diffs are not rendered by default.

18 changes: 5 additions & 13 deletions pywt/_extensions/_pywt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,17 @@ cimport numpy as np

np.import_array()

include "config.pxi"

ctypedef Py_ssize_t pywt_index_t

ctypedef fused data_t:
np.float32_t
np.float64_t

cdef int have_c99_complex
IF HAVE_C99_CPLX:
ctypedef fused cdata_t:
np.float32_t
np.float64_t
np.complex64_t
np.complex128_t
have_c99_complex = 1
ELSE:
ctypedef data_t cdata_t
have_c99_complex = 0
ctypedef fused cdata_t:
np.float32_t
np.float64_t
np.complex64_t
np.complex128_t

cdef public class Wavelet [type WaveletType, object WaveletObject]:
cdef wavelet.DiscreteWavelet* w
Expand Down
161 changes: 78 additions & 83 deletions pywt/_extensions/_swt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ cimport numpy as np
from .common cimport pywt_index_t
from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype

include "config.pxi"

np.import_array()


Expand Down Expand Up @@ -99,21 +97,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
&cD[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
IF HAVE_C99_CPLX:
if cdata_t is np.complex128_t:
cD = np.zeros(output_len, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_swt_d(&data[0], data_size, wavelet.w,
&cD[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex64_t:
cD = np.zeros(output_len, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_swt_d(&data[0], data_size, wavelet.w,
&cD[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex128_t:
cD = np.zeros(output_len, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_swt_d(&data[0], data_size, wavelet.w,
&cD[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex64_t:
cD = np.zeros(output_len, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_swt_d(&data[0], data_size, wavelet.w,
&cD[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")

# alloc memory, decompose A
if cdata_t is np.float64_t:
Expand All @@ -130,21 +127,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level,
&cA[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
IF HAVE_C99_CPLX:
if cdata_t is np.complex128_t:
cA = np.zeros(output_len, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_swt_a(&data[0], data_size, wavelet.w,
&cA[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex64_t:
cA = np.zeros(output_len, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_swt_a(&data[0], data_size, wavelet.w,
&cA[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex128_t:
cA = np.zeros(output_len, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_swt_a(&data[0], data_size, wavelet.w,
&cA[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")
elif cdata_t is np.complex64_t:
cA = np.zeros(output_len, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_swt_a(&data[0], data_size, wavelet.w,
&cA[0], output_len, i)
if retval < 0:
raise RuntimeError("C swt failed.")

data = cA
if not trim_approx:
Expand Down Expand Up @@ -253,58 +249,57 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" % retval)
elif data.dtype == np.complex128:
cA = np.zeros(output_shape, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_downcoef_axis(
<double complex *> data.data, data_info,
<double complex *> cA.data, output_info,
wavelet.w, axis,
common.COEF_APPROX, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
cD = np.zeros(output_shape, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_downcoef_axis(
<double complex *> data.data, data_info,
<double complex *> cD.data, output_info,
wavelet.w, axis,
common.COEF_DETAIL, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
elif data.dtype == np.complex64:
cA = np.zeros(output_shape, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_downcoef_axis(
<float complex *> data.data, data_info,
<float complex *> cA.data, output_info,
wavelet.w, axis,
common.COEF_APPROX, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
cD = np.zeros(output_shape, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_downcoef_axis(
<float complex *> data.data, data_info,
<float complex *> cD.data, output_info,
wavelet.w, axis,
common.COEF_DETAIL, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)

IF HAVE_C99_CPLX:
if data.dtype == np.complex128:
cA = np.zeros(output_shape, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_downcoef_axis(
<double complex *> data.data, data_info,
<double complex *> cA.data, output_info,
wavelet.w, axis,
common.COEF_APPROX, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
cD = np.zeros(output_shape, dtype=np.complex128)
with nogil:
retval = c_wt.double_complex_downcoef_axis(
<double complex *> data.data, data_info,
<double complex *> cD.data, output_info,
wavelet.w, axis,
common.COEF_DETAIL, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
elif data.dtype == np.complex64:
cA = np.zeros(output_shape, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_downcoef_axis(
<float complex *> data.data, data_info,
<float complex *> cA.data, output_info,
wavelet.w, axis,
common.COEF_APPROX, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
cD = np.zeros(output_shape, dtype=np.complex64)
with nogil:
retval = c_wt.float_complex_downcoef_axis(
<float complex *> data.data, data_info,
<float complex *> cD.data, output_info,
wavelet.w, axis,
common.COEF_DETAIL, common.MODE_PERIODIZATION,
i, common.SWT_TRANSFORM)
if retval:
raise RuntimeError(
"C wavelet transform failed with error code %d" %
retval)
if retval == -5:
raise TypeError("Array must be floating point, not {}"
.format(data.dtype))
Expand Down
9 changes: 4 additions & 5 deletions pywt/_extensions/c/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

#pragma once

#ifdef HAVE_C99_COMPLEX
/* For templating, we need typedefs without spaces for complex types. */
typedef float _Complex float_complex;
typedef double _Complex double_complex;
#endif
/* For templating, we need typedefs without spaces for complex types. */
/* FIXME: needs more portable complex types here */
typedef float _Complex float_complex;
typedef double _Complex double_complex;

/* ##### Typedefs ##### */

Expand Down
22 changes: 10 additions & 12 deletions pywt/_extensions/c/convolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,17 @@
#undef REAL_TYPE
#undef TYPE

#ifdef HAVE_C99_COMPLEX
#define TYPE float_complex
#define REAL_TYPE float
#include "convolution.template.c"
#undef REAL_TYPE
#undef TYPE
#define TYPE float_complex
#define REAL_TYPE float
#include "convolution.template.c"
#undef REAL_TYPE
#undef TYPE

#define TYPE double_complex
#define REAL_TYPE double
#include "convolution.template.c"
#undef REAL_TYPE
#undef TYPE
#endif
#define TYPE double_complex
#define REAL_TYPE double
#include "convolution.template.c"
#undef REAL_TYPE
#undef TYPE

#endif /* REAL_TYPE */
#endif /* TYPE */
24 changes: 11 additions & 13 deletions pywt/_extensions/c/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@
#undef REAL_TYPE
#undef TYPE

#ifdef HAVE_C99_COMPLEX
#define TYPE float_complex
#define REAL_TYPE float
#include "convolution.template.h"
#undef REAL_TYPE
#undef TYPE

#define TYPE double_complex
#define REAL_TYPE double
#include "convolution.template.h"
#undef REAL_TYPE
#undef TYPE
#endif
#define TYPE float_complex
#define REAL_TYPE float
#include "convolution.template.h"
#undef REAL_TYPE
#undef TYPE

#define TYPE double_complex
#define REAL_TYPE double
#include "convolution.template.h"
#undef REAL_TYPE
#undef TYPE

#endif /* REAL_TYPE */
#endif /* TYPE */
22 changes: 10 additions & 12 deletions pywt/_extensions/c/wt.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,17 @@
#undef REAL_TYPE
#undef TYPE

#ifdef HAVE_C99_COMPLEX
#define TYPE float_complex
#define REAL_TYPE float
#include "wt.template.c"
#undef REAL_TYPE
#undef TYPE
#define TYPE float_complex
#define REAL_TYPE float
#include "wt.template.c"
#undef REAL_TYPE
#undef TYPE

#define TYPE double_complex
#define REAL_TYPE double
#include "wt.template.c"
#undef REAL_TYPE
#undef TYPE
#endif
#define TYPE double_complex
#define REAL_TYPE double
#include "wt.template.c"
#undef REAL_TYPE
#undef TYPE

#endif /* REAL_TYPE */
#endif /* TYPE */
Loading