diff --git a/.gitignore b/.gitignore index a0232e3..7df47aa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ **/build/ **/__pycache__/ +*.npy +.DS_Store # Common editor backups ~$* diff --git a/.gitmodules b/.gitmodules index e392867..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +0,0 @@ -[submodule "PETSIRD"] - path = PETSIRD - url = https://github.com/ETSInitiative/PETSIRD - branch=main diff --git a/PETSIRD b/PETSIRD deleted file mode 160000 index ae3a048..0000000 --- a/PETSIRD +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ae3a0484a128f805a2143cc0d010d01557cf902f diff --git a/README.md b/README.md index 4e2c57e..0f29641 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,10 @@ ``` conda env create -f environment.yml conda activate petsird-analytic-simulator -cd python +cd python-hackathon-2025 ``` -## Simulate petsird LM data +## Simulate petsird LM data (work in progress) ``` python 01_analytic_petsird_lm_simulator.py @@ -43,7 +43,7 @@ python 01_analytic_petsird_lm_simulator.py -h value > 0 is given via `--num_epochs_mlem`. Otherwise it is skipped to save time. -## Run a listmode OSEM recon on the simulated +## Run a listmode OSEM recon on the simulated (work in progress) ``` python 02_lm_osem_recon_simulated_data.py diff --git a/environment.yml b/environment.yml index 9a56d59..8f56de7 100644 --- a/environment.yml +++ b/environment.yml @@ -1,13 +1,11 @@ name: petsird-analytic-simulator channels: - conda-forge - - defaults dependencies: - - h5py>=3.7.0 - - ipykernel>=6.19.2 - numpy>=1.24.3 - python>=3.11.3 - - matplotlib~=3.8.0 + - matplotlib>=3.8.0 - parallelproj>=1.10.0 - pymirc>=0.29 - - petsird~=0.2.1 + - petsird~=0.7.2 + - ipython diff --git a/python/.gitignore b/python/.gitignore index c44d932..24c7ae6 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,4 +1,2 @@ -*.bin -*.npy -tmp -my_lm_sim/* +data/* +.DS_Store diff --git a/python/01_analytic_petsird_lm_simulator.py b/python/01_analytic_petsird_lm_simulator.py index dc72bf3..c99e2f7 100755 --- a/python/01_analytic_petsird_lm_simulator.py +++ b/python/01_analytic_petsird_lm_simulator.py @@ -1,22 +1,33 @@ -"""analytic simulation of petsird v0.2 listmode data for a block PET scanner - we only simulate true events and ignore the effect of attenuation - however, we simulate the effect of crystal efficiencies and LOR symmetry group efficiencies +"""analytic simulation of petsird v0.7.2 listmode data for a block PET scanner +we only simulate true events and ignore the effect of attenuation +however, we simulate the effect of crystal efficiencies and LOR symmetry group efficiencies """ # %% -import array_api_compat.numpy as xp +from importlib.metadata import version + +# raise an error if petsird version is not at least 0.7.2 +petsird_version = tuple(map(int, version("petsird").split("."))) +if petsird_version < (0, 7, 2): + raise ImportError( + f"petsird version {petsird_version} is not supported, please install petsird >= 0.7.2" + ) + + +import numpy as np import argparse from array_api_compat import size from itertools import combinations import parallelproj -import petsird import matplotlib.pyplot as plt import math import json from pathlib import Path +import petsird + # %% def circular_distance(i_mod_1: int, i_mod_2: int, num_modules: int) -> int: @@ -60,26 +71,30 @@ def parse_float_tuple(arg): # parse the command line for the input parameters below parser = argparse.ArgumentParser() -parser.add_argument("--fname", type=str, default="simulated_lm_file.bin") -parser.add_argument("--output_dir", 
type=str, default="my_lm_sim") -parser.add_argument("--num_true_counts", type=int, default=int(4e6)) +parser.add_argument("--fname", type=str, default="simulated_petsird_lm_file.bin") +parser.add_argument("--output_dir", type=str, default=None) +# HACK TO make things faster by default +parser.add_argument("--num_true_counts", type=int, default=int(4e5)) +# parser.add_argument("--num_true_counts", type=int, default=int(4e6)) parser.add_argument("--skip_plots", action="store_true") parser.add_argument("--check_backprojection", default=False, action="store_true") parser.add_argument("--num_epochs_mlem", type=int, default=0) parser.add_argument("--skip_writing", default=False, action="store_true") -parser.add_argument("--fwhm_mm", type=float, default=1.5) -parser.add_argument("--tof_fwhm_mm", type=float, default=30.0) +parser.add_argument("--fwhm_mm", type=float, default=2.5) +parser.add_argument("--tof_fwhm_mm", type=float, default=20.0) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--uniform_crystal_eff", action="store_true") parser.add_argument("--uniform_sg_eff", action="store_true") -parser.add_argument("--img_shape", type=parse_int_tuple, default=(100, 100, 11)) -parser.add_argument("--voxel_size", type=parse_float_tuple, default=(1.0, 1.0, 1.0)) +parser.add_argument("--img_shape", type=parse_int_tuple, default=(55, 55, 19)) +parser.add_argument("--voxel_size", type=parse_float_tuple, default=(2.0, 2.0, 2.0)) parser.add_argument( "--phantom", type=str, - default="squares", - choices=["uniform_cylinder", "squares"], + default="points", + choices=["uniform_cylinder", "squares", "points"], ) +parser.add_argument("--num_time_blocks", type=int, default=3) +parser.add_argument("--event_block_duration", type=int, default=100) args = parser.parse_args() @@ -93,11 +108,19 @@ def parse_float_tuple(arg): tof_fwhm_mm = args.tof_fwhm_mm seed = args.seed phantom = args.phantom -output_dir = Path(args.output_dir) uniform_crystal_eff = args.uniform_crystal_eff uniform_sg_eff = args.uniform_sg_eff img_shape = args.img_shape voxel_size = args.voxel_size +num_time_blocks: int = args.num_time_blocks +event_block_duration: int = args.event_block_duration + +if args.output_dir is None: + output_dir = Path("data") / f"sim_{phantom}_{num_true_counts}_{seed}" +else: + output_dir = Path(args.output_dir) + +num_energy_bins: int = 1 if not output_dir.exists(): output_dir.mkdir(parents=True) @@ -108,14 +131,13 @@ def parse_float_tuple(arg): # %% # "fixed" input parameters -dev = "cpu" -xp.random.seed(args.seed) +np.random.seed(args.seed) # %% # input parameters related to the scanner geometry # number of LOR endpoints per block module in all 3 directions -block_shape = (10, 2, 3) +block_shape = (10, 2, 9) # spacing between LOR endpoints in a block module in all three directions (mm) block_spacing = (4.5, 10.0, 4.5) # radius of the scanner - distance from the center to the block modules (mm) @@ -127,19 +149,19 @@ def parse_float_tuple(arg): # %% # Setup of a modularized parallelproj PET scanner geometry -modules = [] +modules: list[parallelproj.BlockPETScannerModule] = [] # setup an affine transformation matrix to translate the block modules from the # center to the radius of the scanner -aff_mat_trans = xp.eye(4, dtype="float32", device=dev) +aff_mat_trans = np.eye(4, dtype="float32") aff_mat_trans[1, -1] = scanner_radius module_transforms = [] -for i, phi in enumerate(xp.linspace(0, 2 * xp.pi, num_blocks, endpoint=False)): +for i, phi in enumerate(np.linspace(0, 2 * np.pi, num_blocks, 
endpoint=False)): # setup an affine transformation matrix to rotate the block modules around the center # (of the "2" axis) - aff_mat_rot = xp.asarray( + aff_mat_rot = np.asarray( [ [math.cos(phi), -math.sin(phi), 0, 0], [math.sin(phi), math.cos(phi), 0, 0], @@ -147,15 +169,14 @@ def parse_float_tuple(arg): [0, 0, 0, 1], ], dtype="float32", - device=dev, ) module_transforms.append(aff_mat_rot @ aff_mat_trans) modules.append( parallelproj.BlockPETScannerModule( - xp, - dev, + np, + "cpu", block_shape, block_spacing, affine_transformation_matrix=module_transforms[i], @@ -164,7 +185,7 @@ def parse_float_tuple(arg): # create the scanner geometry from a list of identical block modules at # different locations in space -scanner = parallelproj.ModularizedPETScannerGeometry(modules) +scanner = parallelproj.ModularizedPETScannerGeometry(tuple(modules)) # %% # Setup of a parllelproj LOR descriptor that connectes LOR endpoints in modules @@ -179,25 +200,34 @@ def parse_float_tuple(arg): lor_desc = parallelproj.EqualBlockPETLORDescriptor( scanner, - xp.asarray(block_pairs, device=dev), + np.asarray(block_pairs), ) # %% # setup of the ground truth image used for the data simulation -img = xp.zeros(img_shape, dtype=xp.float32, device=dev) +img = np.zeros(img_shape, dtype=np.float32) if phantom == "uniform_cylinder": - tmp = xp.linspace(-1, 1, img_shape[0]) - X0, X1 = xp.meshgrid(tmp, tmp, indexing="ij") - disk = xp.astype(xp.sqrt(X0**2 + X1**2) < 0.7, "float32") - for i in range(img_shape[2]): + tmp = np.linspace(-1, 1, img_shape[0]) + X0, X1 = np.meshgrid(tmp, tmp, indexing="ij") + disk = np.astype(np.sqrt(X0**2 + X1**2) < 0.7, "float32") + for i in range(2, img_shape[2] - 2): img[..., i] = disk elif phantom == "squares": img[2:-12, 32:-20, 2:-1] = 3 img[24:-40, 36:-28, 4:-2] = 9 img[76:78, 68:72, :-2] = 18 img[14:20, 35:75, 5:-3] = 0 +elif phantom == "points": + img[img_shape[0] // 2, img_shape[1] // 2, img_shape[2] // 2] = 8 + img[img_shape[0] // 2, img_shape[1] // 6, img_shape[2] // 2] = 4 + img[img_shape[0] // 4, img_shape[1] // 2, img_shape[2] // 2] = 6 + + img[img_shape[0] // 2, img_shape[1] // 2, img_shape[2] // 6] = 8 + img[img_shape[0] // 2, img_shape[1] // 6, img_shape[2] // 6] = 4 + img[img_shape[0] // 4, img_shape[1] // 2, img_shape[2] // 6] = 6 + else: raise ValueError("Invalid phantom {phantom}") @@ -210,7 +240,7 @@ def parse_float_tuple(arg): # calculate the number of TOF bins # we set it to twice the image diagonal divided by the tof bin width # and make sure it is an odd number -num_tof_bins = int(2 * xp.sqrt(2) * img_shape[0] * voxel_size[0] / tof_bin_width) +num_tof_bins = int(np.sqrt(2) * img_shape[0] * voxel_size[0] / tof_bin_width) if num_tof_bins % 2 == 0: num_tof_bins += 1 @@ -223,35 +253,35 @@ def parse_float_tuple(arg): ) # check if the projector passes the adjointness test -assert proj.adjointness_test(xp, dev) +assert proj.adjointness_test(np, "cpu") # %% # setup a simple image space resolution model -sig = fwhm_mm / (2.35 * xp.asarray(voxel_size, device=dev)) +sig = fwhm_mm / (2.35 * np.asarray(voxel_size)) res_model = parallelproj.GaussianFilterOperator(img_shape, sigma=sig) # %% # setup the sensitivity sinogram consisting of the crystal efficiencies factors # and the LOR symmetry group efficiencies -tmp = xp.arange(proj.lor_descriptor.num_lorendpoints_per_block) -start_el, end_el = xp.meshgrid(tmp, tmp, indexing="ij") -start_el_arr = xp.reshape(start_el, (size(start_el),)) -end_el_arr = xp.reshape(end_el, (size(end_el),)) +tmp = 
np.arange(proj.lor_descriptor.num_lorendpoints_per_block) +start_el, end_el = np.meshgrid(tmp, tmp, indexing="ij") +start_el_arr = np.reshape(start_el, (size(start_el),)) +end_el_arr = np.reshape(end_el, (size(end_el),)) -nontof_sens_histo = xp.ones(proj.out_shape[:-1], dtype="float32", device=dev) +nontof_sens_histo = np.ones(proj.out_shape[:-1], dtype="float32") if uniform_crystal_eff: # crystal efficiencies are all 1 - det_el_efficiencies = xp.ones( - scanner.num_modules, lor_desc.num_lorendpoints_per_block, dtype="float32" + det_el_efficiencies = np.ones( + (scanner.num_modules, lor_desc.num_lorendpoints_per_block), dtype="float32" ) else: # simulate random crystal eff. uniformly distributed between 0.2 - 2.2 - det_el_efficiencies = 0.2 + 2 * xp.astype( - xp.random.rand(scanner.num_modules, lor_desc.num_lorendpoints_per_block), + det_el_efficiencies = 0.2 + 2 * np.astype( + np.random.rand(scanner.num_modules, lor_desc.num_lorendpoints_per_block), "float32", ) # multiply the det el eff. of the first module by 3 to introduce more variation @@ -292,8 +322,8 @@ def parse_float_tuple(arg): # %% # calculate the sensitivity image -sens_img = fwd_op.adjoint(xp.ones(fwd_op.out_shape, dtype=xp.float32, device=dev)) -xp.save(output_dir / "reference_sensitivity_image.npy", sens_img) +sens_img = fwd_op.adjoint(np.ones(fwd_op.out_shape, dtype=np.float32)) +np.save(output_dir / "reference_sensitivity_image.npy", sens_img) # %% # add poisson noise on the forward projection @@ -301,27 +331,27 @@ def parse_float_tuple(arg): scale_fac = num_true_counts / img_fwd_tof.sum() img *= scale_fac img_fwd_tof *= scale_fac - emission_data = xp.random.poisson(img_fwd_tof) + emission_data = np.random.poisson(img_fwd_tof) else: emission_data = img_fwd_tof # save the ground truth image -xp.save(output_dir / "ground_truth_image.npy", img) +np.save(output_dir / "ground_truth_image.npy", img) # %% if num_epochs_mlem > 0: - recon = xp.ones(img_shape, dtype=xp.float32, device=dev) + recon = np.ones(img_shape, dtype=np.float32) for i in range(num_epochs_mlem): print(f"{(i+1):03}/{num_epochs_mlem:03}", end="\r") - exp = xp.clip(fwd_op(recon), 1e-6, None) + exp = np.clip(fwd_op(recon), 1e-6, None) grad = fwd_op.adjoint((exp - emission_data) / exp) step = recon / sens_img recon -= step * grad print("") - xp.save( + np.save( output_dir / f"reference_histogram_mlem_{num_epochs_mlem}_epochs.npy", recon ) @@ -330,22 +360,22 @@ def parse_float_tuple(arg): if num_true_counts > 0: num_events = emission_data.sum() - event_start_block = xp.zeros(num_events, dtype="uint32", device=dev) - event_start_el = xp.zeros(num_events, dtype="uint32", device=dev) - event_end_block = xp.zeros(num_events, dtype="uint32", device=dev) - event_end_el = xp.zeros(num_events, dtype="uint32", device=dev) - event_tof_bin = xp.zeros(num_events, dtype="int32", device=dev) + event_start_block = np.zeros(num_events, dtype="uint32") + event_start_el = np.zeros(num_events, dtype="uint32") + event_end_block = np.zeros(num_events, dtype="uint32") + event_end_el = np.zeros(num_events, dtype="uint32") + event_tof_bin = np.zeros(num_events, dtype="int32") event_counter = 0 for ibp, block_pair in enumerate(proj.lor_descriptor.all_block_pairs): for it, tof_bin in enumerate( - xp.arange(proj.tof_parameters.num_tofbins) + np.arange(proj.tof_parameters.num_tofbins) - proj.tof_parameters.num_tofbins // 2 ): ss = emission_data[ibp, :, it] num_slice_events = ss.sum() - inds = xp.repeat(xp.arange(ss.shape[0]), ss) + inds = np.repeat(np.arange(ss.shape[0]), ss) # event start 
block event_start_block[event_counter : (event_counter + num_slice_events)] = ( @@ -353,14 +383,14 @@ def parse_float_tuple(arg): ) # event start element in block event_start_el[event_counter : (event_counter + num_slice_events)] = ( - xp.take(start_el_arr, inds) + np.take(start_el_arr, inds) ) # event end module event_end_block[event_counter : (event_counter + num_slice_events)] = ( block_pair[1] ) # event end element in block - event_end_el[event_counter : (event_counter + num_slice_events)] = xp.take( + event_end_el[event_counter : (event_counter + num_slice_events)] = np.take( end_el_arr, inds ) # event TOF bin - starting at 0 @@ -369,8 +399,8 @@ def parse_float_tuple(arg): event_counter += num_slice_events # shuffle lm_event_table along 0 axis - inds = xp.arange(num_events) - xp.random.shuffle(inds) + inds = np.arange(num_events) + np.random.shuffle(inds) event_start_block = event_start_block[inds] event_start_el = event_start_el[inds] @@ -381,7 +411,7 @@ def parse_float_tuple(arg): del inds # create the unsigned tof bin (the index to the tof bin edges) that we need to write - unsigned_event_tof_bin = xp.asarray( + unsigned_event_tof_bin = np.asarray( event_tof_bin + proj.tof_parameters.num_tofbins // 2, dtype="uint32" ) @@ -416,8 +446,8 @@ def parse_float_tuple(arg): fig_geom.show() fig2, ax2 = plt.subplots(1, 4, figsize=(12, 3), tight_layout=True) - vmin = float(xp.min(sens_img)) - vmax = float(xp.max(sens_img)) + vmin = float(np.min(sens_img)) + vmax = float(np.max(sens_img)) for i, sl in enumerate( [img_shape[2] // 4, img_shape[2] // 2, 3 * img_shape[2] // 4] ): @@ -445,7 +475,7 @@ def parse_float_tuple(arg): if check_backprojection and (num_true_counts > 0): histo_back = proj.adjoint(emission_data) - xp.save(output_dir / "histogram_backprojection_tof.npy", histo_back) + np.save(output_dir / "histogram_backprojection_tof.npy", histo_back) lm_back = parallelproj.joseph3d_back_tof_lm( xstart=scanner.get_lor_endpoints(event_start_block, event_start_el), @@ -453,14 +483,14 @@ def parse_float_tuple(arg): img_shape=img_shape, img_origin=proj.img_origin, voxsize=proj.voxel_size, - img_fwd=xp.ones(num_events, dtype=xp.float32, device=dev), + img_fwd=np.ones(num_events, dtype=np.float32), tofbin_width=proj.tof_parameters.tofbin_width, - sigma_tof=xp.asarray([proj.tof_parameters.sigma_tof]), - tofcenter_offset=xp.asarray([proj.tof_parameters.tofcenter_offset]), + sigma_tof=np.asarray([proj.tof_parameters.sigma_tof]), + tofcenter_offset=np.asarray([proj.tof_parameters.tofcenter_offset]), nsigmas=proj.tof_parameters.num_sigmas, tofbin=event_tof_bin, ) - xp.save(output_dir / "lm_backprojection_tof.npy", lm_back) + np.save(output_dir / "lm_backprojection_tof.npy", lm_back) lm_back_non_tof = parallelproj.joseph3d_back( xstart=scanner.get_lor_endpoints(event_start_block, event_start_el), @@ -468,200 +498,248 @@ def parse_float_tuple(arg): img_shape=img_shape, img_origin=proj.img_origin, voxsize=proj.voxel_size, - img_fwd=xp.ones(num_events, dtype=xp.float32, device=dev), + img_fwd=np.ones(num_events, dtype=np.float32), ) - xp.save(output_dir / "lm_backprojection_non_tof.npy", lm_back_non_tof) + np.save(output_dir / "lm_backprojection_non_tof.npy", lm_back_non_tof) + +################################################################################ +################################################################################ +################################################################################ +################################################################################ 
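For orientation before the PETSIRD-specific part: the sections below assemble the petsird (>= 0.7) scanner description from nested objects (ScannerGeometry -> ReplicatedDetectorModule -> DetectorModule -> ReplicatedBoxSolidVolume -> BoxSolidVolume). The following is a minimal, self-contained sketch of that nesting for a toy single-crystal, single-module scanner; the unit-cube corners and identity transforms are placeholder values, not the geometry that is actually built below.

```python
import numpy as np
import petsird

# toy crystal: a unit cube described by its 8 corners
corners = [
    petsird.Coordinate(c=np.array(v, dtype="float32"))
    for v in [(0, 0, 0), (0, 0, 1), (0, 1, 1), (0, 1, 0),
              (1, 0, 0), (1, 0, 1), (1, 1, 1), (1, 1, 0)]
]
crystal = petsird.BoxSolidVolume(
    shape=petsird.BoxShape(corners=corners), material_id=1
)

# crystals replicated inside one module (here: a single crystal, identity transform)
rep_volume = petsird.ReplicatedBoxSolidVolume(object=crystal)
rep_volume.transforms.append(
    petsird.RigidTransformation(matrix=np.eye(4, dtype="float32")[:-1, :])
)

# the module, replicated over the whole scanner (here: a single module)
detector_module = petsird.DetectorModule(detecting_elements=rep_volume)
rep_module = petsird.ReplicatedDetectorModule(object=detector_module)
rep_module.transforms.append(
    petsird.RigidTransformation(matrix=np.eye(4, dtype="float32")[:-1, :])
)

scanner_geometry = petsird.ScannerGeometry(replicated_modules=[rep_module])
```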
+################################################################################ # %% -# create the petsird header +# create ScannerGeometry + +# The top down hiearchy of the scanner geometry is as follows: +# ScannerGeometry(list[ReplicatedDetectorModule]) +# ReplicatedDetectorModule(list[RigidTransformation], DetectorModule) +# DetectorModule(ReplicatedBoxSolidVolume) +# ReplicatedBoxSolidVolume(list[RigidTransformation], BoxSolidVolume) + +crystal_centers = parallelproj.BlockPETScannerModule( + np, "cpu", block_shape, block_spacing +).lor_endpoints + +# crystal widths in all dimensions +cw0 = block_spacing[0] +cw1 = block_spacing[1] +cw2 = block_spacing[2] + +crystal_shape = petsird.BoxShape( + corners=[ + petsird.Coordinate(c=np.array((-cw0 / 2, -cw1 / 2, -cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((-cw0 / 2, -cw1 / 2, cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((-cw0 / 2, cw1 / 2, cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((-cw0 / 2, cw1 / 2, -cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((cw0 / 2, -cw1 / 2, -cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((cw0 / 2, -cw1 / 2, cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((cw0 / 2, cw1 / 2, cw2 / 2), dtype="float32")), + petsird.Coordinate(c=np.array((cw0 / 2, cw1 / 2, -cw2 / 2), dtype="float32")), + ] +) +crystal = petsird.BoxSolidVolume(shape=crystal_shape, material_id=1) -if num_true_counts > 0: - subject = petsird.Subject(id="42") - institution = petsird.Institution( - name="Ministry of Silly Walks", - address="42 Silly Walks Street, Silly Walks City", - ) +# setup the petsird geometry of a module / block - # create non geometry related scanner information +rep_volume = petsird.ReplicatedBoxSolidVolume(object=crystal) - num_energy_bins = 1 +for i_c, crystal_center in enumerate(crystal_centers): + translation_matrix = np.eye(4, dtype="float32")[:-1, :] + for j in range(3): + translation_matrix[j, -1] = crystal_center[j] + transform = petsird.RigidTransformation(matrix=translation_matrix) - # TOF bin edges (in mm) - tofBinEdges = xp.linspace( - -proj.tof_parameters.num_tofbins * proj.tof_parameters.tofbin_width / 2, - proj.tof_parameters.num_tofbins * proj.tof_parameters.tofbin_width / 2, - proj.tof_parameters.num_tofbins + 1, - dtype="float32", + rep_volume.transforms.append(transform) + +detector_module = petsird.DetectorModule(detecting_elements=rep_volume) + +# setup the PETSIRD scanner geometry +rep_module = petsird.ReplicatedDetectorModule(object=detector_module) + +for i in range(num_blocks): + transform = petsird.RigidTransformation(matrix=module_transforms[i][:-1, :]) + + rep_module.transforms.append( + petsird.RigidTransformation(matrix=module_transforms[i][:-1, :]) ) - energyBinEdges = xp.linspace(430, 650, num_energy_bins + 1, dtype="float32") +scanner_geometry = petsird.ScannerGeometry(replicated_modules=[rep_module]) - num_total_elements = proj.lor_descriptor.scanner.num_lor_endpoints +################################################################################ +################################################################################ +################################################################################ - # setup the symmetry group ID LUT - # we only create one symmetry group ID (1) and set the group ID to -1 for block - # block pairs that are not in coincidence +# %% +# setup of detection efficiencies - module_pair_sgid_lut = xp.full((num_blocks, num_blocks), -1, dtype="int32") +# The top down 
hierarchy of the detection efficiencies is as follows: +# petsird.DetectionEfficiencies(detection_bin_efficiencies = list[numpy.ndarray], -> one 2D table per module type +# module_pair_sgidlut = list[list[numpy.ndarray]] -> one 2D table per module type combination +# module_pair_efficiencies_vectors = list[list[list[ModulePairEfficiencies]]]) -> list of modulepair efficiency vectors per module type combination - for bp in proj.lor_descriptor.all_block_pairs: - # generate a random sgd - sgid = sgid_from_module_pair(bp[0], bp[1], num_blocks) - module_pair_sgid_lut[bp[0], bp[1]] = sgid +# the following only works for a scanner with one module type +# if there are more module types, we need a list of DetectionEfficiencies +# and list of list of module_pair_sgidlut and list of list of list of ModulePairEfficiencies +assert scanner_geometry.number_of_replicated_modules() == 1 - num_SGIDs = module_pair_sgid_lut.max() + 1 +# setup the symmetry group ID LUT +# we only create one symmetry group ID (1) and set the group ID to -1 for block +# block pairs that are not in coincidence - num_el_per_module = proj.lor_descriptor.scanner.num_lor_endpoints_per_module[0] +module_pair_sgid_lut = np.full((num_blocks, num_blocks), -1, dtype="int32") - module_pair_efficiencies_shape = ( - num_el_per_module, - num_energy_bins, - num_el_per_module, - num_energy_bins, - ) +for bp in proj.lor_descriptor.all_block_pairs: + # generate a random sgd + sgid = sgid_from_module_pair(bp[0], bp[1], num_blocks) + module_pair_sgid_lut[bp[0], bp[1]] = sgid - module_pair_efficiencies_vector = [] +num_SGIDs = module_pair_sgid_lut.max() + 1 - for sgid in range(num_SGIDs): - eff = module_pair_eff_from_sgd(sgid, uniform=uniform_sg_eff) - vals = xp.full(module_pair_efficiencies_shape, eff, dtype="float32", device=dev) +num_el_per_module = proj.lor_descriptor.scanner.num_lor_endpoints_per_module[0] - module_pair_efficiencies_vector.append( - petsird.ModulePairEfficiencies(values=vals, sgid=sgid) - ) +module_pair_efficiencies_shape = ( + num_el_per_module * num_energy_bins, + num_el_per_module * num_energy_bins, +) - det_effs = petsird.DetectionEfficiencies( - det_el_efficiencies=xp.reshape( - det_el_efficiencies, (size(det_el_efficiencies), 1) - ), - module_pair_sgidlut=module_pair_sgid_lut, - module_pair_efficiencies_vector=module_pair_efficiencies_vector, +module_pair_efficiencies_vector = [] + +for sgid in range(num_SGIDs): + eff = module_pair_eff_from_sgd(sgid, uniform=uniform_sg_eff) + vals = np.full(module_pair_efficiencies_shape, eff, dtype="float32") + + module_pair_efficiencies_vector.append( + petsird.ModulePairEfficiencies(values=vals, sgid=sgid) ) - # setup crystal box object - - crystal_centers = parallelproj.BlockPETScannerModule( - xp, dev, block_shape, block_spacing - ).lor_endpoints - - # crystal widths in all dimensions - cw0 = block_spacing[0] - cw1 = block_spacing[1] - cw2 = block_spacing[2] - - crystal_shape = petsird.BoxShape( - corners=[ - petsird.Coordinate( - c=xp.asarray((-cw0 / 2, -cw1 / 2, -cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((-cw0 / 2, -cw1 / 2, cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((-cw0 / 2, cw1 / 2, cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((-cw0 / 2, cw1 / 2, -cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((cw0 / 2, -cw1 / 2, -cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((cw0 / 2, -cw1 / 2, cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((cw0 / 2, 
cw1 / 2, cw2 / 2), dtype="float32") - ), - petsird.Coordinate( - c=xp.asarray((cw0 / 2, cw1 / 2, -cw2 / 2), dtype="float32") - ), - ] +# only correct for scanner with one module type +det_effs = petsird.DetectionEfficiencies( + detection_bin_efficiencies=[det_el_efficiencies.ravel()], + module_pair_sgidlut=[[module_pair_sgid_lut]], + module_pair_efficiencies_vectors=[[module_pair_efficiencies_vector]], +) + +################################################################################ +################################################################################ +################################################################################ + +# %% +# setup ScannerInformation and Header + +# TOF bin edges (in mm) +tofBinEdges = petsird.BinEdges( + edges=np.linspace( + -proj.tof_parameters.num_tofbins * proj.tof_parameters.tofbin_width / 2, + proj.tof_parameters.num_tofbins * proj.tof_parameters.tofbin_width / 2, + proj.tof_parameters.num_tofbins + 1, + dtype="float32", ) - crystal = petsird.BoxSolidVolume(shape=crystal_shape, material_id=1) +) - # setup the petsird geometry of a module / block +energyBinEdges = petsird.BinEdges( + edges=np.linspace(430, 650, num_energy_bins + 1, dtype="float32") +) - rep_volume = petsird.ReplicatedBoxSolidVolume(object=crystal) +# num_total_elements = proj.lor_descriptor.scanner.num_lor_endpoints + +# need energy bin info before being able to construct the detection efficiencies +# so we will construct a scanner without the efficiencies first +petsird_scanner = petsird.ScannerInformation( + model_name="PETSIRD_TEST", + scanner_geometry=scanner_geometry, + tof_bin_edges=[[tofBinEdges]], # list of list for all module type combinations + tof_resolution=[ + [2.35 * proj.tof_parameters.sigma_tof] + ], # FWHM in mm, list of list for all module type combinations + event_energy_bin_edges=[energyBinEdges], # list for all module types + energy_resolution_at_511=[0.11], # as fraction of 511, list for all module types + detection_efficiencies=det_effs, +) - for i_c, crystal_center in enumerate(crystal_centers): - translation_matrix = xp.eye(4, dtype="float32")[:-1, :] - for j in range(3): - translation_matrix[j, -1] = crystal_center[j] - transform = petsird.RigidTransformation(matrix=translation_matrix) +petsird_scanner.coincidence_policy = petsird.CoincidencePolicy.REJECT_MULTIPLES +petsird_scanner.delayed_coincidences_are_stored = False +petsird_scanner.triple_events_are_stored = False - rep_volume.transforms.append(transform) - rep_volume.ids.append(i_c) +################################################################################ +################################################################################ +################################################################################ - detector_module = petsird.DetectorModule( - detecting_elements=[rep_volume], detecting_element_ids=[0] - ) +# %% +# create the petsird header - # setup the PETSIRD scanner geometry - rep_module = petsird.ReplicatedDetectorModule(object=detector_module) +subject = petsird.Subject(id="42") +institution = petsird.Institution( + name="Ministry of Silly Walks", + address="42 Silly Walks Street, Silly Walks City", +) - for i in range(num_blocks): - transform = petsird.RigidTransformation(matrix=module_transforms[i][:-1, :]) +header = petsird.Header( + exam=petsird.ExamInformation(subject=subject, institution=institution), + scanner=petsird_scanner, +) - rep_module.ids.append(i) - rep_module.transforms.append( - 
petsird.RigidTransformation(matrix=module_transforms[i][:-1, :]) - ) - scanner_geometry = petsird.ScannerGeometry(replicated_modules=[rep_module], ids=[0]) - - # need energy bin info before being able to construct the detection efficiencies - # so we will construct a scanner without the efficiencies first - petsird_scanner = petsird.ScannerInformation( - model_name="PETSIRD_TEST", - scanner_geometry=scanner_geometry, - tof_bin_edges=tofBinEdges, - tof_resolution=2.35 * proj.tof_parameters.sigma_tof, # FWHM in mm - energy_bin_edges=energyBinEdges, - energy_resolution_at_511=0.11, # as fraction of 511 - event_time_block_duration=1, # ms - ) +################################################################################ +################################################################################ +################################################################################ - petsird_scanner.detection_efficiencies = det_effs +# %% +# create petsird coincidence events - all in one timeblock without energy information - header = petsird.Header( - exam=petsird.ExamInformation(subject=subject, institution=institution), - scanner=petsird_scanner, - ) +num_el_per_block = proj.lor_descriptor.num_lorendpoints_per_block - # %% - # create petsird coincidence events - all in one timeblock without energy information - - num_el_per_block = proj.lor_descriptor.num_lorendpoints_per_block - - det_ID_start = event_start_block * num_el_per_block + event_start_el - det_ID_end = event_end_block * num_el_per_block + event_end_el - - # %% - # write petsird data - - if not skip_writing: - print(f"Writing LM file to {str(output_dir / fname)}") - with petsird.BinaryPETSIRDWriter(str(output_dir / fname)) as writer: - writer.write_header(header) - for i_t in range(1): - start = i_t * header.scanner.event_time_block_duration - - time_block_prompt_events = [ - petsird.CoincidenceEvent( - detector_ids=[det_ID_start[i], det_ID_end[i]], - tof_idx=unsigned_event_tof_bin[i], - energy_indices=[0, 0], - ) - for i in range(num_events) - ] - - # Normally we'd write multiple blocks, but here we have just one, so let's write a tuple with just one element - writer.write_time_blocks( - ( - petsird.TimeBlock.EventTimeBlock( - petsird.EventTimeBlock( - start=start, prompt_events=time_block_prompt_events - ) - ), - ) +# split the data into chuncks such that we loop over chunks (time blocks) +# every chunk is an array of shape (num_events_per_chunk, 3) +# the first column is the start detector ID, the second column is the end detector ID, +# and the third column is the unsigned TOF bin index + +chunked_data = np.array_split( + np.array( + [ + event_start_block * num_el_per_block + event_start_el, + event_end_block * num_el_per_block + event_end_el, + unsigned_event_tof_bin, + ] + ).T, + num_time_blocks, +) + +# %% +# write petsird data + +if not skip_writing: + print(f"Writing LM file to {str(output_dir / fname)}") + with petsird.BinaryPETSIRDWriter(str(output_dir / fname)) as writer: + writer.write_header(header) + for i_t, data_chunk in enumerate(chunked_data): + + print(f"Writing time block {i_t + 1}/{len(chunked_data)}") + print( + "First 5 events (start / stop detection element, unsigned tofbin number):" + ) + print(data_chunk[:5, :]) + print() + + time_block_prompt_events = [ + petsird.CoincidenceEvent( + detection_bins=[x[0], x[1]], + tof_idx=x[2], ) + for x in data_chunk + ] + + # Normally we'd write multiple blocks, but here we have just one, so let's write a tuple with just one element + writer.write_time_blocks( + ( + 
petsird.TimeBlock.EventTimeBlock( + petsird.EventTimeBlock( + prompt_events=[[time_block_prompt_events]], + time_interval=petsird.TimeInterval( + start=i_t * event_block_duration, + stop=(i_t + 1) * event_block_duration, + ), + ) + ), + ) + ) diff --git a/python/02_lm_osem_recon_simulated_data.py b/python/02_lm_osem_recon_simulated_data.py deleted file mode 100644 index dfeb772..0000000 --- a/python/02_lm_osem_recon_simulated_data.py +++ /dev/null @@ -1,303 +0,0 @@ -import array_api_compat.numpy as xp -import matplotlib.pyplot as plt -import petsird -import parallelproj - -from petsird_helpers import ( - get_module_and_element, - get_detection_efficiency, -) - -from utils import ( - parse_float_tuple, - parse_int_tuple, - mult_transforms, - transform_BoxShape, - draw_BoxShape, -) - -from pathlib import Path -import argparse - -# %% - -parser = argparse.ArgumentParser() -parser.add_argument("--lm_fname", type=str, default="my_lm_sim/simulated_lm_file.bin") -parser.add_argument("--num_epochs", type=int, default=5) -parser.add_argument("--num_subsets", type=int, default=20) -parser.add_argument("--img_shape", type=parse_int_tuple, default=(100, 100, 11)) -parser.add_argument("--voxel_size", type=parse_float_tuple, default=(1.0, 1.0, 1.0)) -parser.add_argument("--fwhm_mm", type=float, default=1.5) -parser.add_argument("--output_dir", type=str, default="my_lm_sim") - -args = parser.parse_args() - -lm_fname = args.lm_fname -num_epochs = args.num_epochs -num_subsets = args.num_subsets -img_shape = args.img_shape -voxel_size = args.voxel_size -fwhm_mm = args.fwhm_mm -output_dir = Path(args.output_dir) - -dev = "cpu" - -if not output_dir.exists(): - output_dir.mkdir(parents=True) - -# %% -if not Path(lm_fname).exists(): - raise FileNotFoundError( - f"{args.lm_fname} not found. Create it first using the generator." - ) - -# %% -# read the scanner geometry - - -reader = petsird.BinaryPETSIRDReader(lm_fname) -header = reader.read_header() - -# %% -# check whether we only have 1 type of module -assert ( - len(header.scanner.scanner_geometry.replicated_modules) == 1 -), "Only scanners with 1 module type supported yet" - -# %% -# lists where we store the detecting element coordinates and transforms for each module -# the list has one entry per module - -det_element_center_list = [] - -# %% -# read the LOR endpoint coordinates for each detecting element in each crystal -# we assume that the LOR endpoint corresponds to the center of the BoxShape - -fig_scanner = plt.figure(figsize=(8, 8), tight_layout=True) -ax_scanner = fig_scanner.add_subplot(111, projection="3d") - -for rep_module in header.scanner.scanner_geometry.replicated_modules: - det_el = rep_module.object.detecting_elements - - num_modules = len(rep_module.transforms) - - for i_mod, mod_transform in enumerate(rep_module.transforms): - for rep_volume in det_el: - - det_element_centers = xp.zeros( - (len(rep_volume.transforms), 3), dtype="float32" - ) - - num_el_per_module = len(rep_volume.transforms) - - for i_el, el_transform in enumerate(rep_volume.transforms): - - combined_transform = mult_transforms([mod_transform, el_transform]) - transformed_boxshape = transform_BoxShape( - combined_transform, rep_volume.object.shape - ) - - transformed_boxshape_vertices = xp.array( - [c.c for c in transformed_boxshape.corners] - ) - - det_element_centers[i_el, ...] 
= transformed_boxshape_vertices.mean( - axis=0 - ) - - # visualize the detecting elements - draw_BoxShape(ax_scanner, transformed_boxshape) - if i_el == 0 or i_el == len(rep_volume.transforms) - 1: - ax_scanner.text( - float(transformed_boxshape_vertices[0][0]), - float(transformed_boxshape_vertices[0][1]), - float(transformed_boxshape_vertices[0][2]), - f"{i_el:02}/{i_mod:02}", - fontsize=7, - ) - - det_element_center_list.append(det_element_centers) - -# %% -# create a list of the element detection efficiencies per module -# this is a simple re-ordering of the detection efficiencies array which -# makes the access easier -# we assume that all modules have the same number of detecting elements -det_el_efficiencies = [ - header.scanner.detection_efficiencies.det_el_efficiencies[ - i * num_el_per_module : (i + 1) * num_el_per_module, 0 - ] - for i in range(num_modules) -] - -num_tofbins = len(header.scanner.tof_bin_edges) - 1 -tofbin_width = header.scanner.tof_bin_edges[1] - header.scanner.tof_bin_edges[0] -sigma_tof = header.scanner.tof_resolution / 2.35 - -tof_params = parallelproj.TOFParameters( - num_tofbins=num_tofbins, tofbin_width=tofbin_width, sigma_tof=sigma_tof -) - -assert num_tofbins % 2 == 1, "Number of TOF bins must be odd" -# %% -# calculate the sensitivity image -print("Calculating sensitivity image") - -# we loop through the symmetric group ID look up table to see which module pairs -# are in coincidence - - -sig = fwhm_mm / (2.35 * xp.asarray(voxel_size, device=dev)) -res_model = parallelproj.GaussianFilterOperator(img_shape, sigma=sig) - -sens_img = xp.zeros(img_shape, dtype="float32") - -for i in range(num_modules): - for j in range(num_modules): - sgid = header.scanner.detection_efficiencies.module_pair_sgidlut[i, j] - - if sgid >= 0: - print(f"mod1 {i:03}, mod2 {j:03}, SGID {sgid:03}", end="\r") - - start_det_el = det_element_center_list[i] - end_det_el = det_element_center_list[j] - - # create an array of that contains all possible combinations of start and end detecting element coordinates - # these define all possible LORs between the two modules - start_coords = xp.repeat(start_det_el, len(end_det_el), axis=0) - end_coords = xp.tile(end_det_el, (len(start_det_el), 1)) - - proj = parallelproj.ListmodePETProjector( - start_coords, end_coords, img_shape, voxel_size - ) - proj.tof_parameters = tof_params - - # get the module pair efficiencies - asumming that we only use 1 energy bin - module_pair_eff = ( - header.scanner.detection_efficiencies.module_pair_efficiencies_vector[ - sgid - ].values[:, 0, :, 0] - ).ravel() - - start_el_eff = xp.repeat(det_el_efficiencies[i], len(end_det_el), axis=0) - end_el_eff = xp.tile(det_el_efficiencies[j], (len(start_det_el))) - - for tofbin in xp.arange(-(num_tofbins // 2), num_tofbins // 2 + 1): - # print(tofbin) - proj.event_tofbins = xp.full( - start_coords.shape[0], tofbin, dtype="int32" - ) - proj.tof = True - sens_img += proj.adjoint(start_el_eff * end_el_eff * module_pair_eff) - -print("") - -# for some reason we have to divide the sens image by the number of TOF bins -# right now unclear why that is -sens_img = res_model.adjoint(sens_img) - -# %% -# read all coincidence events -print("Reading LM events") - -num_prompts = 0 -event_counter = 0 -num_tof_bins = header.scanner.number_of_tof_bins() - -xstart = [] -xend = [] -tof_bin = [] -effs = [] - -for i_time_block, time_block in enumerate(reader.read_time_blocks()): - if isinstance(time_block, petsird.TimeBlock.EventTimeBlock): - num_prompts += 
len(time_block.value.prompt_events) - - for i_event, event in enumerate(time_block.value.prompt_events): - event_mods_and_els = get_module_and_element( - header.scanner.scanner_geometry, event.detector_ids - ) - - event_start_coord = det_element_center_list[event_mods_and_els[0].module][ - event_mods_and_els[0].el - ] - xstart.append(event_start_coord) - - event_end_coord = det_element_center_list[event_mods_and_els[1].module][ - event_mods_and_els[1].el - ] - xend.append(event_end_coord) - - # get the event efficiencies - effs.append(get_detection_efficiency(header.scanner, event)) - # get the signed event TOF bin (0 is the central bin) - tof_bin.append(event.tof_idx - num_tof_bins // 2) - - # visualize the first 5 events in the time block - if i_time_block == 0 and i_event < 5: - ax_scanner.plot( - [event_start_coord[0], event_end_coord[0]], - [event_start_coord[1], event_end_coord[1]], - [event_start_coord[2], event_end_coord[2]], - ) - - event_counter += 1 - -reader.close() - -xstart = xp.asarray(xstart, device=dev) -xend = xp.asarray(xend, device=dev) -effs = xp.asarray(effs, device=dev) -tof_bin = xp.asarray(tof_bin, device=dev) - -# %% -# set the x, y, z limits of the scanner plot -xmin = xp.asarray([x.min(axis=0) for x in det_element_center_list]).min(axis=0) -xmax = xp.asarray([x.max(axis=0) for x in det_element_center_list]).max(axis=0) -r = (xmax - xmin).max() - -ax_scanner.set_xlabel("x0") -ax_scanner.set_ylabel("x1") -ax_scanner.set_zlabel("x2") -ax_scanner.set_xlim(xmin.min() - 0.05 * r, xmax.max() + 0.05 * r) -ax_scanner.set_ylim(xmin.min() - 0.05 * r, xmax.max() + 0.05 * r) -ax_scanner.set_zlim(xmin.min() - 0.05 * r, xmax.max() + 0.05 * r) - -fig_scanner.savefig(output_dir / "scanner_geometry.png") -fig_scanner.show() - -# %% -# run a LM OSEM recon -recon = xp.ones(img_shape, dtype="float32") - -lm_subset_projs = [] -subset_slices = [slice(i, None, num_subsets) for i in range(num_subsets)] - -for i_subset, sl in enumerate(subset_slices): - lm_subset_projs.append( - parallelproj.ListmodePETProjector( - xstart[sl, :], xend[sl, :], img_shape, voxel_size - ) - ) - lm_subset_projs[i_subset].tof_parameters = tof_params - lm_subset_projs[i_subset].event_tofbins = tof_bin[sl] - lm_subset_projs[i_subset].tof = True - -for i_epoch in range(num_epochs): - for i_subset, sl in enumerate(subset_slices): - print( - f"it {(i_epoch +1):03} / {num_epochs:03}, ss {(i_subset+1):03} / {num_subsets:03}", - end="\r", - ) - lm_exp = effs[sl] * lm_subset_projs[i_subset](res_model(recon)) - tmp = num_subsets * res_model( - lm_subset_projs[i_subset].adjoint(effs[sl] / lm_exp) - ) - recon *= tmp / sens_img - -print("") - -opath = output_dir / f"lm_osem_{num_epochs}_{num_subsets}.npy" -xp.save(opath, recon) -print(f"LM OSEM recon saved to {opath}") diff --git a/python/02_reconstruct_petsird.py b/python/02_reconstruct_petsird.py new file mode 100644 index 0000000..6c91105 --- /dev/null +++ b/python/02_reconstruct_petsird.py @@ -0,0 +1,341 @@ +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +import parallelproj +import petsird + +# for the parallelproj recon we can use cupy as array backend if available +# otherwise we fall back to numpy +try: + import cupy as xp +except ModuleNotFoundError: + import numpy as xp +import argparse + +print(f"Using {xp.__name__} for parallelproj reconstructions") + +from utils import ( + get_all_detector_centers, + read_listmode_prompt_events, + backproject_efficiencies, +) + + +# %% 
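# Typical invocation (assuming the simulator above was run with its defaults, which
# write the listmode file to data/sim_points_400000_0/simulated_petsird_lm_file.bin):
#
#   python 02_reconstruct_petsird.py data/sim_points_400000_0/simulated_petsird_lm_file.bin --num_epochs 5 --num_subsets 20
#
# If --img_shape is not given, the image shape is derived from the scanner bounding box below.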
+################################################################################ +#### PARSE THE COMMAND LINE #################################################### +################################################################################ + +parser = argparse.ArgumentParser( + description="PETSIRD analytic simulator reconstruction" +) +parser.add_argument("fname", type=str, help="Path to the PETSIRD listmode file") +parser.add_argument( + "--img_shape", + type=int, + nargs=3, + default=None, + help="Shape of the image to be reconstructed", +) +parser.add_argument( + "--voxel_size", + type=float, + nargs=3, + default=[2.0, 2.0, 2.0], + help="Voxel size in mm", +) +parser.add_argument( + "--fwhm_mm", + type=float, + default=2.5, + help="FWHM in mm for Gaussian filter for resolution model", +) +parser.add_argument( + "--store_energy_bins", action="store_true", help="Whether to store energy bins" +) +parser.add_argument("--num_epochs", type=int, default=5, help="Number of OSEM epochs") +parser.add_argument( + "--num_subsets", type=int, default=20, help="Number of OSEM subsets" +) +parser.add_argument( + "--verbose", action="store_true", help="Whether to print verbose output" +) +parser.add_argument( + "--unity_sens_img", + action="store_true", + help="Whether to skip sensitivity image calculation and use unity image", +) +parser.add_argument( + "--unity_effs", + action="store_true", + help="Whether to use unity efficiencies for LORs", +) +parser.add_argument( + "--non-tof", action="store_true", help="Whether to disable TOF in recon" +) + +args = parser.parse_args() + +fname = args.fname +img_shape: list[int] | None = args.img_shape +voxel_size = tuple(args.voxel_size) +fwhm_mm = args.fwhm_mm +store_energy_bins = args.store_energy_bins +num_epochs = args.num_epochs +num_subsets = args.num_subsets +verbose = args.verbose +unity_sens_img = args.unity_sens_img +unity_effs = args.unity_effs +tof = not args.non_tof + +# %% +################################################################################ +#### READ PETSIRD HEADER ####################################################### +################################################################################ + +reader = petsird.BinaryPETSIRDReader(fname) +header: petsird.Header = reader.read_header() +scanner_info: petsird.ScannerInformation = header.scanner +scanner_geom: petsird.ScannerGeometry = scanner_info.scanner_geometry + +num_replicated_modules = scanner_geom.number_of_replicated_modules() +print(f"Scanner with {num_replicated_modules} types of replicated modules.") + +# %% +################################################################################ +### READ DETECTOR CENTERS AND VISUALIZE ######################################## +################################################################################ + +# Create a new figure +fig = plt.figure() +ax = fig.add_subplot(111, projection="3d") + +print("Calculating all detector centers ...") +all_detector_centers = get_all_detector_centers(scanner_geom, ax=ax) + + +if not ax is None: + min_coords = all_detector_centers[0].reshape(-1, 3).min(0) + max_coords = all_detector_centers[0].reshape(-1, 3).max(0) + + ax.set_xlim3d([min_coords.min(), max_coords.max()]) + ax.set_ylim3d([min_coords.min(), max_coords.max()]) + ax.set_zlim3d([min_coords.min(), max_coords.max()]) + + for detector_centers in all_detector_centers: + ax.scatter( + detector_centers[:, :, 0].ravel(), + detector_centers[:, :, 1].ravel(), + detector_centers[:, :, 2].ravel(), + s=0.5, + c="k", + alpha=0.3, + 
) + fig.show() + +# %% +################################################################################ +### DETERMINE IMAGE ORIGIN AND IMAGE SHAPE IF NOT GIVEN ####################### +################################################################################ + + +if img_shape is not None: + img_shape = tuple(img_shape) +else: + # get the bounding box of the scanner detection elements + scanner_bbox = all_detector_centers[0].reshape(-1, 3).max(0) - all_detector_centers[ + 0 + ].reshape(-1, 3).min(0) + + i_ax = int( + np.argmin( + np.array( + [ + np.abs(scanner_bbox[1] - scanner_bbox[2]), + np.abs(scanner_bbox[0] - scanner_bbox[2]), + np.abs(scanner_bbox[0] - scanner_bbox[1]), + ] + ) + ) + ) + + img_shape = (0.53 * scanner_bbox / np.array(voxel_size)).astype(int) + img_shape[i_ax] = int(scanner_bbox[i_ax] / voxel_size[i_ax]) + if img_shape[i_ax] % 2 == 0: + img_shape[i_ax] += 1 # make sure the image shape is odd + img_shape = tuple(img_shape.tolist()) + +# calculate the scanner iso center to set the image origin that we need for the projectors +scanner_iso_center = xp.asarray(all_detector_centers[0].reshape(-1, 3).mean(0)) +img_origin = scanner_iso_center - 0.5 * (xp.asarray(img_shape) - 1) * xp.asarray( + voxel_size +) + +if verbose: + print(f"Image shape: {img_shape}") + print(f"Image origin: {img_origin}") + print(f"Voxel size: {voxel_size}") + +# %% +################################################################################ +### CALCULATE THE SENSITIVTY IMAGE ############################################# +################################################################################ + +sig = fwhm_mm / (2.35 * np.asarray(voxel_size)) +res_model = parallelproj.GaussianFilterOperator(img_shape, sigma=sig) + +if unity_sens_img: + print("Using ones as sensitivity image ...") + sens_img = np.ones(img_shape, dtype="float32") +else: + print("Calculating sensitivity image ...") + sens_img: xp.ndarray = backproject_efficiencies( + scanner_info, + all_detector_centers, + img_shape, + voxel_size, + verbose=verbose, + tof=tof, + xp=xp, + ) + + # apply adjoint of image-based resolution model + sens_img = res_model.adjoint(sens_img) + +# %% +################################################################################ +### CHECK WHETHER THE CALC SENS IMAGE MATCHES THE REF. 
SENSE IMAGE ############# +################################################################################ + +ref_sens_img_path = Path(fname).parent / "reference_sensitivity_image.npy" + +if ref_sens_img_path.exists(): + print(f"loading reference sensitivity image from {ref_sens_img_path}") + ref_sens_img = np.load(ref_sens_img_path) + if ref_sens_img.shape == sens_img.shape: + if np.allclose(np.asarray(sens_img), ref_sens_img): + print( + f"calculated sensitivity image matches reference image {ref_sens_img_path}" + ) + else: + print( + f"calculated sensitivity image does NOT match reference image {ref_sens_img_path}" + ) + +# %% +################################################################################ +### READ THE LM PROMPT EVENTS AND CONVERT TO COODINATES ######################## +################################################################################ + +print("Reading prompt events from listmode file ...") + +coords0, coords1, signed_tof_bins, effs, energy_idx0, energy_idx1 = ( + read_listmode_prompt_events( + reader, + header, + all_detector_centers, + store_energy_bins=True, + verbose=verbose, + unity_effs=unity_effs, + ) +) + +print(signed_tof_bins.min(), signed_tof_bins.max()) + +# %% +################################################################################ +### VISUALIZE GEOMETRY AND FIRST 5 EVENTS ###################################### +################################################################################ +if not ax is None: + for i in range(5): + ax.plot( + [coords0[i, 0], coords1[i, 0]], + [coords0[i, 1], coords1[i, 1]], + [coords0[i, 2], coords1[i, 2]], + ) + fig.show() + +# %% +################################################################################ +### PARALLELRPOJ BACKPROJECTIONS ############################################### +################################################################################ + +proj = parallelproj.ListmodePETProjector( + xp.asarray(coords0).copy(), + xp.asarray(coords1).copy(), + img_shape, + voxel_size, + img_origin=img_origin, +) + +non_tof_backproj = proj.adjoint(xp.ones(coords0.shape[0], dtype="float32")) + +#### HACK assumes same TOF parameters for all module type pairs +sigma_tof = scanner_info.tof_resolution[0][0] / 2.35 +tof_bin_edges = scanner_info.tof_bin_edges[0][0].edges +num_tofbins = tof_bin_edges.size - 1 +tofbin_width = float(tof_bin_edges[1] - tof_bin_edges[0]) + +tof_params = parallelproj.TOFParameters( + num_tofbins=num_tofbins, tofbin_width=tofbin_width, sigma_tof=sigma_tof +) + +proj.tof_parameters = tof_params +proj.event_tofbins = xp.asarray(signed_tof_bins).copy() +proj.tof = True + +tof_backproj = proj.adjoint(xp.ones(coords0.shape[0], dtype="float32")) + +del proj +# %% +################################################################################ +### PARALLELRPOJ LM OSEM RECONSTRUCTION ######################################## +################################################################################ + +print("Starting parallelproj LM OSEM reconstruction ...") + +lm_subset_projs = [] +subset_slices = [slice(i, None, num_subsets) for i in range(num_subsets)] + +# init recon, sens and eff arrays and covert to xp (numpy or cupy) arrays +recon = xp.ones(img_shape, dtype="float32") +effs = xp.asarray(effs, dtype="float32") + +for i_subset, sl in enumerate(subset_slices): + lm_subset_projs.append( + parallelproj.ListmodePETProjector( + xp.asarray(coords0[sl, :]).copy(), + xp.asarray(coords1[sl, :]).copy(), + img_shape, + voxel_size, + img_origin=img_origin, + ) + ) 
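    # each subset projector uses the same TOF parameters as the full projector above,
    # plus the signed TOF bins of its own subset of events (attached below when TOF is on)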
+ + #### HACK assumes same TOF parameters for all module type pairs + if tof: + lm_subset_projs[i_subset].tof_parameters = tof_params + lm_subset_projs[i_subset].event_tofbins = xp.asarray(signed_tof_bins[sl]).copy() + lm_subset_projs[i_subset].tof = True + +for i_epoch in range(num_epochs): + for i_subset, sl in enumerate(subset_slices): + print( + f"it {(i_epoch +1):03} / {num_epochs:03}, ss {(i_subset+1):03} / {num_subsets:03}", + end="\r", + ) + lm_exp = effs[sl] * lm_subset_projs[i_subset](res_model(recon)) + tmp = num_subsets * res_model.adjoint( + lm_subset_projs[i_subset].adjoint(effs[sl] / lm_exp) + ) + recon *= tmp / sens_img + +opath = Path(fname).parent / f"lm_osem_{num_epochs}_{num_subsets}.npy" +xp.save(opath, recon) +print(f"LM OSEM recon saved to {opath}") + +# %% +# SHOW RECON +import pymirc.viewer as pv + +vi = pv.ThreeAxisViewer([parallelproj.to_numpy_array(x) for x in [recon, sens_img]]) diff --git a/python/README.md b/python/README.md deleted file mode 100644 index b5d6ff5..0000000 --- a/python/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# PETSIRD basic Python example - -As you can now install the `petsird` package from PyPI, you likely no longer -need this repository and can just use `pip install petsird`. -You can of course use the Python package generated from the local PETSIRD -clone (`cd python; pip install .`). See -https://github.com/ETSInitiative/PETSIRD/tree/main/python#readme -for more information. - diff --git a/python/petsird_helpers.py b/python/petsird_helpers.py deleted file mode 100644 index 62a85c8..0000000 --- a/python/petsird_helpers.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Preliminary helpers for PETSIRD data -""" - -# Copyright (C) 2024 University College London -# -# SPDX-License-Identifier: Apache-2.0 - -import typing -from dataclasses import dataclass - -import petsird - - -def get_num_det_els(scanner_geometry: petsird.ScannerGeometry) -> int: - """Compute total number of detecting elements in the scanner""" - num_det_els = 0 - for rep_module in scanner_geometry.replicated_modules: - det_els = rep_module.object.detecting_elements - for rep_volume in det_els: - num_det_els += len(rep_volume.transforms) * len( - rep_module.transforms) - return num_det_els - - -@dataclass -class ModuleAndElement: - """ - Dataclass to store a module ID and element ID, where the latter runs - over all detecting volumes in the module - """ - - module: int - el: int - - -def get_module_and_element( - scanner_geometry: petsird.ScannerGeometry, - scanner_det_ids: typing.Iterable[int]) -> list[ModuleAndElement]: - """Find ModuleAndElement for a list of detector_ids""" - assert len(scanner_geometry.replicated_modules) == 1 - rep_module = scanner_geometry.replicated_modules[0] - assert len(rep_module.object.detecting_elements) == 1 - num_el_per_module = len(rep_module.object.detecting_elements[0].ids) - - return [ - ModuleAndElement(module=det // num_el_per_module, - el=det % num_el_per_module) for det in scanner_det_ids - ] - - -def get_detection_efficiency(scanner: petsird.ScannerInformation, - event: petsird.CoincidenceEvent) -> float: - """Compute the detection efficiency for a coincidence event""" - eff = 1 - - # per det_el efficiencies - det_el_efficiencies = scanner.detection_efficiencies.det_el_efficiencies - if det_el_efficiencies is not None: - eff *= (det_el_efficiencies[event.detector_ids[0], - event.energy_indices[0]] * - det_el_efficiencies[event.detector_ids[1], - event.energy_indices[1]]) - - # per module-pair efficiencies - module_pair_efficiencies_vector = ( - 
scanner.detection_efficiencies.module_pair_efficiencies_vector) - if module_pair_efficiencies_vector is not None: - module_pair_SGID_LUT = scanner.detection_efficiencies.module_pair_sgidlut - assert module_pair_SGID_LUT is not None - mod_and_els = get_module_and_element(scanner.scanner_geometry, - event.detector_ids) - assert len(scanner.scanner_geometry.replicated_modules) == 1 - SGID = module_pair_SGID_LUT[mod_and_els[0].module, - mod_and_els[1].module] - assert SGID >= 0 - module_pair_efficiencies = module_pair_efficiencies_vector[SGID] - assert module_pair_efficiencies.sgid == SGID - eff *= module_pair_efficiencies.values[ - mod_and_els[0].el, - event.energy_indices[0], - mod_and_els[1].el, - event.energy_indices[1], - ] - - return eff diff --git a/python/utils.py b/python/utils.py index fb4e61e..2bebded 100644 --- a/python/utils.py +++ b/python/utils.py @@ -1,83 +1,503 @@ -import petsird -import numpy.typing as npt +from importlib.metadata import version + +# raise an error if petsird version is not at least 0.7.2 +petsird_version = tuple(map(int, version("petsird").split("."))) +if petsird_version < (0, 7, 2): + raise ImportError( + f"petsird version {petsird_version} is not supported, please install petsird >= 0.7.2" + ) + + import numpy as np +import warnings + +import parallelproj +import petsird +import petsird.helpers.geometry -import matplotlib.pyplot as plt +from types import ModuleType from mpl_toolkits.mplot3d.art3d import Poly3DCollection +from petsird.helpers import ( + expand_detection_bin, + get_detection_efficiency, +) -def parse_int_tuple(arg): - return tuple(map(int, arg.split(","))) +def draw_BoxShape(ax, box: petsird.BoxShape) -> None: + vertices = np.array([c.c for c in box.corners]) + edges = [ + [vertices[j] for j in [0, 1, 2, 3]], + [vertices[j] for j in [4, 5, 6, 7]], + [vertices[j] for j in [0, 1, 5, 4]], + [vertices[j] for j in [2, 3, 7, 6]], + [vertices[j] for j in [1, 2, 6, 5]], + [vertices[j] for j in [4, 7, 3, 0]], + ] + box_poly = Poly3DCollection(edges, alpha=0.1, linewidths=0.25, edgecolors="r") + ax.add_collection3d(box_poly) -def parse_float_tuple(arg): - return tuple(map(float, arg.split(","))) +def get_all_detector_centers( + scanner_geometry: petsird.ScannerGeometry, ax=None +) -> list[np.ndarray]: + # a list containing the center of all detecting elements for all modules + # every element of the list corresponds to one module type + # for every module type, we have an numpy array of shape (num_modules, num_det_els, 3) + # for a given module type, module number and detector el number, we can access the center of the detector element with + # all_det_el_centers[module_type][module_number, detector_el_number, :] -def transform_to_mat44( - transform: petsird.RigidTransformation, -) -> npt.NDArray[np.float32]: - return np.vstack([transform.matrix, [0, 0, 0, 1]]) + all_det_el_centers = [] + # draw all crystals + for rep_module in scanner_geometry.replicated_modules: + det_els = rep_module.object.detecting_elements + det_el_centers = np.zeros( + (len(rep_module.transforms), len(det_els.transforms), 3) + ) + for i_mod, mod_transform in enumerate(rep_module.transforms): -def mat44_to_transform(mat: npt.NDArray[np.float32]) -> petsird.RigidTransformation: - return petsird.RigidTransformation(matrix=mat[0:3, :]) + for i_det_el, transform in enumerate(det_els.transforms): + transformed_boxshape = ( + petsird.helpers.geometry.transform_BoxShape( + petsird.helpers.geometry.mult_transforms( + [mod_transform, transform] + ), + det_els.object.shape, + ), + )[0] + 
+                transformed_boxshape_vertices = np.array(
+                    [c.c for c in transformed_boxshape.corners]
+                )
 
-def coordinate_to_homogeneous(coord: petsird.Coordinate) -> npt.NDArray[np.float32]:
-    return np.hstack([coord.c, 1])
+                det_el_centers[i_mod, i_det_el, :] = transformed_boxshape_vertices.mean(
+                    axis=0
+                )
+                if ax is not None:
+                    draw_BoxShape(
+                        ax,
+                        transformed_boxshape,
+                    )
+        all_det_el_centers.append(det_el_centers)
 
-def homogeneous_to_coordinate(
-    hom_coord: npt.NDArray[np.float32],
-) -> petsird.Coordinate:
-    return petsird.Coordinate(c=hom_coord[0:3])
+    return all_det_el_centers
 
-def mult_transforms(
-    transforms: list[petsird.RigidTransformation],
-) -> petsird.RigidTransformation:
-    """multiply rigid transformations"""
-    mat = np.array(
-        ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)),
-        dtype="float32",
-    )
+def backproject_efficiencies(
+    scanner_info: petsird.ScannerInformation,
+    all_detector_centers: list[np.ndarray],
+    img_shape: tuple[int, int, int],
+    voxel_size: tuple[float, float, float],
+    tof: bool = False,
+    verbose: bool = False,
+    xp: ModuleType = np,
+) -> np.ndarray:
 
-    for t in reversed(transforms):
-        mat = np.matmul(transform_to_mat44(t), mat)
-    return mat44_to_transform(mat)
+    scanner_geom: petsird.ScannerGeometry = scanner_info.scanner_geometry
+    num_replicated_modules = scanner_geom.number_of_replicated_modules()
 
+    # read detection element / module pair efficiencies
 
-def mult_transforms_coord(
-    transforms: list[petsird.RigidTransformation], coord: petsird.Coordinate
-) -> petsird.Coordinate:
-    """apply list of transformations to coordinate"""
-    # TODO better to multiply with coordinates in sequence, as first multiplying the matrices
-    hom = np.matmul(
-        transform_to_mat44(mult_transforms(transforms)),
-        coordinate_to_homogeneous(coord),
+    # get the detection bin efficiencies for all module types
+    # index via: det_bin_effs = all_detection_bin_effs[rep_mod_type]
+    # which returns a 1D array of shape (num_detection_bins,) = (num_det_els_in_module * num_energy_bins,)
+    all_detection_bin_effs: list[petsird.DetectionBinEfficiencies] | None = (
+        scanner_info.detection_efficiencies.detection_bin_efficiencies
     )
-
-def transform_BoxShape(
-    transform: petsird.RigidTransformation, box_shape: petsird.BoxShape
-) -> petsird.BoxShape:
-
-    return petsird.BoxShape(
-        corners=[mult_transforms_coord([transform], c) for c in box_shape.corners]
+    # get the symmetry group ID LUTs for all module types
+    # index via: 2D_SGID_LUT = all_module_pair_sgidlut[rep_mod_type 1][rep_mod_type 2]
+    # which returns a 2D array of shape (num_modules, num_modules)
+    all_module_pair_sgidluts: list[list[petsird.ModulePairSGIDLUT]] | None = (
+        scanner_info.detection_efficiencies.module_pair_sgidlut
     )
+    # get all module pair efficiencies vectors
+    # index via: mod_pair_effs = module_pair_efficiencies_vectors[rep_mod_type 1][rep_mod_type 2][sgid]
+    # which returns a 2D array of shape (num_detection_bins, num_detection_bins)
+    all_module_pair_efficiency_vectors: (
+        list[list[list[petsird.ModulePairEfficiencies]]] | None
+    ) = scanner_info.detection_efficiencies.module_pair_efficiencies_vectors
 
-def draw_BoxShape(ax, box: petsird.BoxShape) -> None:
-    vertices = np.array([c.c for c in box.corners])
-    edges = [
-        [vertices[j] for j in [0, 1, 2, 3]],
-        [vertices[j] for j in [4, 5, 6, 7]],
-        [vertices[j] for j in [0, 1, 5, 4]],
-        [vertices[j] for j in [2, 3, 7, 6]],
-        [vertices[j] for j in [1, 2, 6, 5]],
-        [vertices[j] for j in [4, 7, 3, 0]],
+    # number of modules for every module type
+    all_num_modules: list[int] = [
+        len(x.transforms) for x in scanner_geom.replicated_modules
+    ]
+    # number of detecting elements for every module type
+    all_num_det_els: list[int] = [
+        x.object.detecting_elements.number_of_objects()
+        for x in scanner_geom.replicated_modules
     ]
-    box = Poly3DCollection(edges, alpha=0.1, linewidths=0.1, edgecolors=plt.cm.tab10(0))
-    ax.add_collection3d(box)
+
+    if all_detection_bin_effs is None:
+        all_detection_bin_effs = [
+            np.ones(x * y, dtype="float32")
+            for (x, y) in zip(all_num_modules, all_num_det_els)
+        ]
+        warnings.warn(
+            "No detection bin efficiencies found in scanner information - assuming all detection efficiencies are 1.0.",
+        )
+
+    if all_module_pair_sgidluts is None:
+        # no SGID LUTs are given, assume that all module pairs are in SGID 0
+
+        all_module_pair_sgidluts = []
+        for num_mod1 in all_num_modules:
+            tmp_list = []
+            for num_mod2 in all_num_modules:
+                lut = np.full((num_mod1, num_mod2), -1, dtype="int")
+                for i in range(num_mod1):
+                    for j in range(i + 1, num_mod2):
+                        lut[i, j] = 0
+
+                tmp_list.append(lut)
+            all_module_pair_sgidluts.append(tmp_list)
+
+        warnings.warn(
+            "No module pair SGID LUTs found in scanner information. Assuming all module pairs are in SGID 0.",
+            UserWarning,
+        )
+
+    if all_module_pair_efficiency_vectors is None:
+
+        all_module_pair_efficiency_vectors = []
+        for num_det_els1 in all_num_det_els:
+            tmp_list = []
+            for num_det_els2 in all_num_det_els:
+                # create a dummy efficiency vector with all ones
+                # here the values array has shape (num_det_els1, num_det_els2),
+                # i.e. we assume a single energy bin for each module type
+                tmp_list.append(
+                    [
+                        petsird.ModulePairEfficiencies(
+                            values=np.ones(
+                                (num_det_els1, num_det_els2), dtype="float32"
+                            )
+                        )
+                    ]
+                )
+            all_module_pair_efficiency_vectors.append(tmp_list)
+
+        warnings.warn(
+            "No module pair efficiencies vectors found in scanner information. Assuming all ones.",
+            UserWarning,
+        )
+
+    # %%
+    # generate the sensitivity image
+
+    if verbose:
+        print("Generating sensitivity image")
+    sens_img = xp.zeros(img_shape, dtype="float32")
+
+    for mod_type_1 in range(num_replicated_modules):
+        num_modules_1 = len(scanner_geom.replicated_modules[mod_type_1].transforms)
+
+        energy_bin_edges_1 = scanner_info.event_energy_bin_edges[mod_type_1].edges
+        num_energy_bins_1 = energy_bin_edges_1.size - 1
+
+        det_bin_effs_1 = all_detection_bin_effs[mod_type_1].reshape(
+            num_modules_1, -1, num_energy_bins_1
+        )
+
+        for mod_type_2 in range(num_replicated_modules):
+            num_modules_2 = len(scanner_geom.replicated_modules[mod_type_2].transforms)
+
+            energy_bin_edges_2 = scanner_info.event_energy_bin_edges[mod_type_2].edges
+            num_energy_bins_2 = energy_bin_edges_2.size - 1
+            det_bin_effs_2 = all_detection_bin_effs[mod_type_2].reshape(
+                num_modules_2, -1, num_energy_bins_2
+            )
+
+            if verbose:
+                print(
+                    f"Module type {mod_type_1} with {num_modules_1} modules vs. module type {mod_type_2} with {num_modules_2} modules"
+                )
+
+            # sigma TOF (mm) for module type combination
+            sigma_tof = scanner_info.tof_resolution[mod_type_1][mod_type_2] / 2.35
+            tof_bin_edges = scanner_info.tof_bin_edges[mod_type_1][mod_type_2].edges
+            num_tofbins = tof_bin_edges.size - 1
+            tofbin_width = float(tof_bin_edges[1] - tof_bin_edges[0])
+
+            # raise an error if the tof_bin_edges are non-equidistant (within 0.1%)
+            if not np.allclose(
+                np.diff(tof_bin_edges), tof_bin_edges[1] - tof_bin_edges[0], rtol=0.001
+            ):
+                raise ValueError(
+                    f"TOF bin edges for module types {mod_type_1} and {mod_type_2} are not equidistant."
+                )
+
+            for i_mod_1 in range(num_modules_1):
+
+                # get the row of the SGID LUT for the current module type pair
+                sgids = all_module_pair_sgidluts[mod_type_1][mod_type_2][i_mod_1].copy()
+
+                # neglect the lower triangle of the SGID LUT to make sure we
+                # only back-project every module pair once
+                sgids[: (i_mod_1 + 1)] = -1
+
+                num_mods_in_coincidence = np.count_nonzero(sgids >= 0)
+
+                if num_mods_in_coincidence == 0:
+                    continue
+
+                start_coords = np.zeros(
+                    (
+                        num_mods_in_coincidence,
+                        det_bin_effs_1.shape[1] * det_bin_effs_2.shape[1],
+                        3,
+                    ),
+                    dtype="float32",
+                )
+
+                end_coords = np.zeros(
+                    (
+                        num_mods_in_coincidence,
+                        det_bin_effs_1.shape[1] * det_bin_effs_2.shape[1],
+                        3,
+                    ),
+                    dtype="float32",
+                )
+
+                to_be_back_projected = np.zeros(
+                    (
+                        num_energy_bins_1,
+                        num_energy_bins_2,
+                        num_mods_in_coincidence,
+                        det_bin_effs_1.shape[1] * det_bin_effs_2.shape[1],
+                    ),
+                    dtype="float32",
+                )
+
+                i_coinc = 0
+
+                for i_mod_2, sgid in enumerate(sgids):
+
+                    if sgid < 0:
+                        # skip module pairs that are not in coincidence
+                        continue
+
+                    # if the symmetry group ID (sgid) is non-negative, the module pair is in coincidence
+                    if verbose:
+                        print(
+                            f"  Module pair ({mod_type_1}, {i_mod_1}) vs. ({mod_type_2}, {i_mod_2}) with SGID {sgid}"
+                        )
+
+                    # 2D array containing the 3 coordinates of all detecting elements of the start module
+                    start_det_coords = all_detector_centers[mod_type_1][i_mod_1, :, :]
+                    # 2D array containing the 3 coordinates of all detecting elements of the end module
+                    end_det_coords = all_detector_centers[mod_type_2][i_mod_2, :, :]
+
+                    # 2D array of start coordinates of all LORs connecting all detecting elements
+                    # of the start module with all detecting elements of the end module
+                    start_coords[i_coinc, :, :] = np.repeat(
+                        start_det_coords, end_det_coords.shape[0], axis=0
+                    )
+
+                    # 2D array of end coordinates of all LORs connecting all detecting elements
+                    # of the start module with all detecting elements of the end module
+                    end_coords[i_coinc, :, :] = np.tile(
+                        end_det_coords, (start_det_coords.shape[0], 1)
+                    )
+
+                    # 2D array of shape (num_detection_bins, num_detection_bins) =
+                    # (num_det_els * num_energy_bins, num_det_els * num_energy_bins)
+                    module_pair_efficiencies = all_module_pair_efficiency_vectors[
+                        mod_type_1
+                    ][mod_type_2][sgid].values
+
+                    module_pair_efficiencies = module_pair_efficiencies.reshape(
+                        module_pair_efficiencies.shape[0] // num_energy_bins_1,
+                        num_energy_bins_1,
+                        module_pair_efficiencies.shape[1] // num_energy_bins_2,
+                        num_energy_bins_2,
+                    )
+
+                    for i_e_1 in range(num_energy_bins_1):
+                        for i_e_2 in range(num_energy_bins_2):
+                            if verbose:
+                                print(f"    Energy bin pair ({i_e_1}, {i_e_2})")
+
+                            # get the detection bin efficiencies for the start module
+                            # 1D array of shape (num_det_els,)
+                            start_det_bin_effs = det_bin_effs_1[i_mod_1, :, i_e_1]
+                            # get the detection bin efficiencies for the end module
+                            # 1D array of shape (num_det_els,)
+                            end_det_bin_effs = det_bin_effs_2[i_mod_2, :, i_e_2]
+
+                            # (non-TOF) sensitivity values to be back-projected
+                            ##########
+                            # in case of modeled attenuation, multiply them as well
+                            ##########
+                            to_be_back_projected[i_e_1, i_e_2, i_coinc, :] = (
+                                np.outer(
+                                    start_det_bin_effs,
+                                    end_det_bin_effs,  # multiplied start and end det els. effs for all LORs
+                                ).ravel()
+                                * module_pair_efficiencies[
+                                    :,
+                                    i_e_1,
+                                    :,
+                                    i_e_2,  # module pair effs for current module pair and energy bin pair
+                                ].ravel()
+                            )
+
+                    # clean up the projector (stores many coordinates ...)
+                    # del proj
+                    i_coinc += 1
+
+                # set up an LM projector that we use for the sensitivity image calculation
+                start_coords = xp.asarray(start_coords.reshape(-1, 3))
+                end_coords = xp.asarray(end_coords.reshape(-1, 3))
+
+                proj = parallelproj.ListmodePETProjector(
+                    start_coords,
+                    end_coords,
+                    img_shape,
+                    voxel_size,
+                )
+
+                if tof:
+                    proj.tof_parameters = parallelproj.TOFParameters(
+                        num_tofbins=num_tofbins,
+                        tofbin_width=tofbin_width,
+                        sigma_tof=sigma_tof,
+                    )
+
+                if verbose:
+                    print(
+                        f"backprojecting all LORs starting at module {i_mod_1}\n"
+                    )
+
+                for i_e_1 in range(num_energy_bins_1):
+                    for i_e_2 in range(num_energy_bins_2):
+                        if tof:
+                            for signed_tofbin in np.arange(
+                                -(num_tofbins // 2), num_tofbins // 2 + 1
+                            ):
+                                proj.event_tofbins = xp.full(
+                                    start_coords.shape[0],
+                                    signed_tofbin,
+                                    dtype="int32",
+                                )
+                                proj.tof = True
+                                sens_img += proj.adjoint(
+                                    xp.asarray(
+                                        to_be_back_projected[i_e_1, i_e_2, :, :].ravel()
+                                    )
+                                )
+                        else:
+                            proj.tof = False
+                            sens_img += proj.adjoint(
+                                xp.asarray(
+                                    to_be_back_projected[i_e_1, i_e_2, :, :].ravel()
+                                )
+                            )
+
+    return sens_img
+
+
+def read_listmode_prompt_events(
+    reader: petsird.BinaryPETSIRDReader,
+    header: petsird.Header,
+    all_detector_centers: list[np.ndarray],
+    store_energy_bins: bool = True,
+    unity_effs: bool = False,
+    verbose: bool = False,
+    flip_tofbin_sign: bool = False,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+
+    scanner_info: petsird.ScannerInformation = header.scanner
+    scanner_geom: petsird.ScannerGeometry = scanner_info.scanner_geometry
+    num_replicated_modules = scanner_geom.number_of_replicated_modules()
+
+    if verbose:
+        print("\nReading prompt events from time blocks ...")
+
+    i_t = 0
+
+    ## list of dictionaries, each dictionary contains the prompt detection bins for a time block for all module type and energy bin combinations
+    # all_prompt_detection_bins: list[dict[tuple[int, int], np.ndarray]] = []
+
+    coords0 = []
+    coords1 = []
+    effs = []
+    signed_tof_bins = []
+    energy_idx0 = []
+    energy_idx1 = []
+
+    for time_block in reader.read_time_blocks():
+        if isinstance(time_block, petsird.TimeBlock.EventTimeBlock):
+            start_time = time_block.value.time_interval.start
+            stop_time = time_block.value.time_interval.stop
+            if verbose:
+                print(
+                    f"Processing time block {i_t} with time interval {start_time} ... {stop_time}"
{stop_time}" + ) + + # time_block_prompt_detection_bins = dict() + + for mtype0 in range(num_replicated_modules): + for mtype1 in range(num_replicated_modules): + tof_bin_edges = scanner_info.tof_bin_edges[mtype0][mtype1].edges + num_tofbins = tof_bin_edges.size - 1 + + for event in time_block.value.prompt_events[mtype0][mtype1]: + expanded_det_bin0 = expand_detection_bin( + scanner_info, mtype0, event.detection_bins[0] + ) + expanded_det_bin1 = expand_detection_bin( + scanner_info, mtype1, event.detection_bins[1] + ) + + coords0.append( + all_detector_centers[mtype0][ + expanded_det_bin0.module_index, + expanded_det_bin0.element_index, + ] + ) + coords1.append( + all_detector_centers[mtype1][ + expanded_det_bin1.module_index, + expanded_det_bin1.element_index, + ] + ) + + if flip_tofbin_sign: + signed_tof_bins.append(-(event.tof_idx - num_tofbins // 2)) + else: + signed_tof_bins.append(event.tof_idx - num_tofbins // 2) + + if unity_effs: + effs.append(1.0) + else: + effs.append( + get_detection_efficiency( + scanner_info, + petsird.TypeOfModulePair((mtype0, mtype1)), + event, + ) + ) + + if store_energy_bins: + energy_idx0.append(expanded_det_bin0.energy_index) + energy_idx1.append(expanded_det_bin1.energy_index) + + # all_prompt_detection_bins.append(time_block_prompt_detection_bins) + i_t += 1 + + # convert lists to numpy arrays + coords0 = np.array(coords0, dtype="float32") + coords1 = np.array(coords1, dtype="float32") + signed_tof_bins = np.array(signed_tof_bins, dtype="int16") + effs = np.array(effs, dtype="float32") + energy_idx0 = np.array(energy_idx0, dtype="uint16") + energy_idx1 = np.array(energy_idx1, dtype="uint16") + + return coords0, coords1, signed_tof_bins, effs, energy_idx0, energy_idx1 + + +# %%