Skip to content

Commit

Permalink
Work only on affected columns in the sinogram
Browse files Browse the repository at this point in the history
  • Loading branch information
namannimmo10 committed Mar 5, 2025
1 parent b52cb9c commit cc623ad
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,38 +254,45 @@ def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True):
"""
Remove large stripes.
"""
drop_ratio = max(min(drop_ratio, 0.8), 0) # = cp.clip(drop_ratio, 0.0, 0.8)
drop_ratio = max(min(drop_ratio, 0.8), 0)
(nrow, ncol) = sinogram.shape
ndrop = int(0.5 * drop_ratio * nrow)
sinosort = cp.sort(sinogram, axis=0)
sinosmooth = median_filter(sinosort, (1, size))

list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
listfact = list1 / list2
# Locate stripes

listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)

# Normalize
# Normalize only affected columns
if norm:
sinogram /= cp.tile(listfact, (nrow, 1))
sinogram[:, listmask > 0] /= listfact[None, listmask > 0]

# Identify affected columns
listxmiss = cp.where(listmask > 0.0)[0]
if listxmiss.size == 0:
return sinogram # No stripes, return early for efficiency

# Process only affected columns
sino_transposed = sinogram.T
ids_sort = cp.argsort(sino_transposed, axis=1)
sino_subset = sino_transposed[listxmiss] # Extract affected columns

# Apply sorting without explicit matindex
sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1)
# Sort only the required subset
ids_sort = cp.argsort(sino_subset, axis=1)
sino_sorted = cp.take_along_axis(sino_subset, ids_sort, axis=1)

# Smoothen sorted sinogram
sino_sorted[:, :] = cp.transpose(sinosmooth)
# Apply smoothing
sino_sorted[:, :] = cp.transpose(sinosmooth[:, listxmiss])

# Restore original order
ids_restore = cp.argsort(ids_sort, axis=1)
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1)

# Apply corrections only to affected columns
listxmiss = cp.where(listmask > 0.0)[0]
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
# Place back corrected data
sinogram[:, listxmiss] = sino_corrected.T

return sinogram

Expand Down

0 comments on commit cc623ad

Please sign in to comment.