|
| 1 | +import logging |
| 2 | +import pathlib |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import tskit |
| 6 | + |
| 7 | +from bio2zarr import constants, core, vcz |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
class TskitFormat(vcz.Source):
    """Expose the sites of a tskit tree sequence as a ``vcz.Source``.

    The tree sequence is loaded from ``ts_path`` with ``tskit.load``. Every
    site is reported as one variant on a single contig, and genotypes are
    laid out per sample according to ``individuals_nodes``.

    Args:
        ts_path: Path passed to ``tskit.load``.
        individuals_nodes: 2D array-like of shape ``(num_samples,
            max_ploidy)``. Entries ``>= 0`` are node IDs whose genotypes are
            decoded; negative entries are skipped and leave the fill value in
            the corresponding genotype slot.
        sample_ids: Optional sequence of sample IDs, one per row of
            ``individuals_nodes``. Defaults to ``tsk_0``, ``tsk_1``, ...
        contig_id: ID of the single contig. Defaults to ``"1"``.
        isolated_as_missing: Forwarded to ``TreeSequence.variants``.

    Raises:
        ValueError: If ``individuals_nodes`` has no rows, is not 2D, contains
            no non-negative node ID, or if ``sample_ids`` has the wrong
            length.
    """

    def __init__(
        self,
        ts_path,
        individuals_nodes,
        sample_ids=None,
        contig_id=None,
        isolated_as_missing=False,
    ):
        self._path = ts_path
        self.ts = tskit.load(ts_path)
        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing

        self.positions = self.ts.sites_position

        # Accept any array-like (e.g. nested lists) and fail with a clear
        # message, rather than an AttributeError/IndexError, on bad shapes.
        individuals_nodes = np.asarray(individuals_nodes)
        if individuals_nodes.ndim == 0 or individuals_nodes.shape[0] < 1:
            raise ValueError("individuals_nodes must have at least one sample")
        if individuals_nodes.ndim != 2:
            raise ValueError(
                "individuals_nodes must be a 2D array of shape "
                f"(num_samples, max_ploidy), got {individuals_nodes.ndim} "
                "dimension(s)"
            )
        self._num_samples = individuals_nodes.shape[0]
        self.max_ploidy = individuals_nodes.shape[1]

        if sample_ids is None:
            sample_ids = [f"tsk_{j}" for j in range(self._num_samples)]
        elif len(sample_ids) != self._num_samples:
            raise ValueError(
                f"Length of sample_ids ({len(sample_ids)}) does not match "
                f"number of samples ({self._num_samples})"
            )

        self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]

        valid_mask = individuals_nodes >= 0
        # tskit_samples: sorted unique node IDs to decode.
        # genotype_indices: for each valid (sample, ploidy) slot, in row-major
        # order, the column of that node within tskit_samples.
        self.tskit_samples, self.genotype_indices = np.unique(
            individuals_nodes[valid_mask], return_inverse=True
        )
        if len(self.tskit_samples) < 1:
            raise ValueError("individuals_nodes must have at least one valid sample")
        # Row-major coordinates of the valid slots; same order as the boolean
        # mask selection used above.
        self.sample_indices, self.ploidy_indices = np.where(valid_mask)

    @property
    def path(self):
        """The path the tree sequence was loaded from."""
        return self._path

    @property
    def num_records(self):
        """Number of variants, i.e. the number of sites in the tree sequence."""
        return self.ts.num_sites

    @property
    def num_samples(self):
        """Number of samples, i.e. the number of rows of ``individuals_nodes``."""
        return self._num_samples

    @property
    def samples(self):
        """List of ``vcz.Sample`` objects, one per sample ID."""
        return self._samples

    @property
    def root_attrs(self):
        """Root attributes; always an empty dict for this source."""
        return {}

    @property
    def contigs(self):
        """Single-element list holding the one contig of this source."""
        return [vcz.Contig(id=self.contig_id)]

    def iter_contig(self, start, stop):
        """Yield the contig index (always 0) for each variant in [start, stop)."""
        yield from (0 for _ in range(start, stop))

    def iter_field(self, field_name, shape, start, stop):
        """Yield values of ``field_name`` for variants in [start, stop).

        Only ``"position"`` is supported; positions are yielded as ``int``.

        Raises:
            ValueError: If ``field_name`` is not ``"position"``.
        """
        if field_name == "position":
            for pos in self.positions[start:stop]:
                yield int(pos)
        else:
            raise ValueError(f"Unknown field {field_name}")

    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
        """Yield ``(alleles, (genotypes, phased))`` for variants in [start, stop).

        Args:
            start: Index of the first site to decode.
            stop: Index one past the last site to decode.
            shape: Shape of the per-variant genotype array; ``shape[:-1]`` is
                used for the phased array.
            num_alleles: Length of the per-variant allele array.

        Raises:
            ValueError: If a variant has a non-missing allele at an index
                ``>= num_alleles``.
        """
        # An empty range yields nothing, matching iter_contig/iter_field, and
        # avoids an IndexError on positions[start] (e.g. no sites at all).
        if start >= stop or start >= self.num_records:
            return

        # All genotypes in tskit are considered phased
        phased = np.ones(shape[:-1], dtype=bool)

        for variant in self.ts.variants(
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            right=self.positions[stop] if stop < self.num_records else None,
            samples=self.tskit_samples,
        ):
            gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            for i, allele in enumerate(variant.alleles):
                # None is returned by tskit in the case of a missing allele
                if allele is None:
                    continue
                # Explicit check rather than assert so it survives python -O.
                if i >= num_alleles:
                    raise ValueError(
                        f"Variant has an allele at index {i} but only "
                        f"{num_alleles} alleles are allocated"
                    )
                alleles[i] = allele

            # Scatter the decoded node genotypes into their (sample, ploidy)
            # slots; slots without a valid node keep the fill value.
            gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
                self.genotype_indices
            ]

            yield alleles, (gt, phased)

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """Build the ``vcz.VcfZarrSchema`` describing this tree sequence.

        Args:
            variants_chunk_size: Chunk size along the variants dimension;
                falls back to ``vcz.DEFAULT_VARIANT_CHUNK_SIZE`` when falsy.
            samples_chunk_size: Chunk size along the samples dimension;
                falls back to ``vcz.DEFAULT_SAMPLE_CHUNK_SIZE`` when falsy.

        Returns:
            A ``vcz.VcfZarrSchema`` with variants/samples/ploidy/alleles
            dimensions and the position, allele, contig and genotype arrays.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles: distinct states (ancestral plus
        # derived) seen at any single site.
        max_alleles = 0
        for site in self.ts.sites():
            states = {site.ancestral_state}
            states.update(mut.derived_state for mut in site.mutations)
            max_alleles = max(len(states), max_alleles)

        # Use the module logger (not the root logger) so these messages honour
        # the package's logging configuration like the one further down.
        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )

        dimensions = {
            "variants": vcz.VcfZarrDimension(
                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
            ),
            "samples": vcz.VcfZarrDimension(
                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
            ),
            "ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
            "alleles": vcz.VcfZarrDimension(size=max_alleles),
        }

        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            dimensions=dimensions,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.dimensions['variants'].chunk_size}, "
            f"{schema_instance.dimensions['samples'].chunk_size}"
        )

        # Position range drives the integer dtype chosen for variant_position.
        min_position = 0
        max_position = 0
        if self.ts.num_sites > 0:
            min_position = np.min(self.ts.sites_position)
            max_position = np.max(self.ts.sites_position)

        array_specs = [
            vcz.ZarrArraySpec(
                source="position",
                name="variant_position",
                dtype=core.min_int_dtype(min_position, max_position),
                dimensions=["variants"],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_allele",
                dtype="O",
                dimensions=["variants", "alleles"],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                dimensions=["variants"],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_phased",
                dtype="bool",
                dimensions=["variants", "samples"],
                description="Whether the genotype is phased",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype",
                dtype=core.min_int_dtype(constants.INT_FILL, max_alleles - 1),
                dimensions=["variants", "samples", "ploidy"],
                description="Genotype for each variant and sample",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_mask",
                dtype="bool",
                dimensions=["variants", "samples", "ploidy"],
                description="Mask for each genotype call",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
| 211 | + |
| 212 | + |
def convert(
    ts_path,
    zarr_path,
    individuals_nodes,
    *,
    sample_ids=None,
    contig_id=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=1,
    show_progress=False,
):
    """Write the tree sequence at ``ts_path`` to ``zarr_path`` via ``VcfZarrWriter``.

    Builds a ``TskitFormat`` source, generates its schema, then runs the
    writer's init / encode / finalise / index steps in order.

    Args:
        ts_path: Path of the tree sequence file to load.
        zarr_path: Output location; converted to ``pathlib.Path``.
        individuals_nodes: Forwarded to ``TskitFormat``.
        sample_ids: Forwarded to ``TskitFormat``.
        contig_id: Forwarded to ``TskitFormat``.
        isolated_as_missing: Forwarded to ``TskitFormat``.
        variants_chunk_size: Forwarded to ``TskitFormat.generate_schema``.
        samples_chunk_size: Forwarded to ``TskitFormat.generate_schema``.
        worker_processes: Forwarded to the encode step; also used to size
            the number of partitions.
        show_progress: Forwarded to the encode and finalise steps.
    """
    source = TskitFormat(
        ts_path,
        individuals_nodes,
        sample_ids=sample_ids,
        contig_id=contig_id,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )
    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(zarr_path))
    writer.init(
        source,
        # Rough heuristic: four partitions per worker (at least one overall)
        # to split work up enough to keep utilisation high.
        target_num_partitions=max(1, worker_processes * 4),
        schema=schema,
    )
    writer.encode_all_partitions(
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
    writer.finalise(show_progress)
    writer.create_index()
0 commit comments