Skip to content

Commit 33c66a3

Browse files
committed
Add tskit conversion module
1 parent 5d9864d commit 33c66a3

File tree

3 files changed

+689
-1
lines changed

3 files changed

+689
-1
lines changed

bio2zarr/plink.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def generate_schema(
128128
dtype="i1",
129129
dimensions=["variants", "samples", "ploidy"],
130130
description=None,
131-
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
131+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
132132
),
133133
vcz.ZarrArraySpec(
134134
name="call_genotype_mask",

bio2zarr/tskit.py

+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
import logging
2+
import pathlib
3+
4+
import numpy as np
5+
import tskit
6+
7+
from bio2zarr import constants, core, vcz
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class TskitFormat(vcz.Source):
13+
def __init__(
14+
self,
15+
ts_path,
16+
individuals_nodes,
17+
sample_ids=None,
18+
contig_id=None,
19+
isolated_as_missing=False,
20+
):
21+
self._path = ts_path
22+
self.ts = tskit.load(ts_path)
23+
self.contig_id = contig_id if contig_id is not None else "1"
24+
self.isolated_as_missing = isolated_as_missing
25+
26+
self.positions = self.ts.sites_position
27+
28+
self._num_samples = individuals_nodes.shape[0]
29+
if self._num_samples < 1:
30+
raise ValueError("individuals_nodes must have at least one sample")
31+
self.max_ploidy = individuals_nodes.shape[1]
32+
if sample_ids is None:
33+
sample_ids = [f"tsk_{j}" for j in range(self._num_samples)]
34+
elif len(sample_ids) != self._num_samples:
35+
raise ValueError(
36+
f"Length of sample_ids ({len(sample_ids)}) does not match "
37+
f"number of samples ({self._num_samples})"
38+
)
39+
40+
self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]
41+
42+
self.tskit_samples = np.unique(individuals_nodes[individuals_nodes >= 0])
43+
if len(self.tskit_samples) < 1:
44+
raise ValueError("individuals_nodes must have at least one valid sample")
45+
node_id_to_index = {node_id: i for i, node_id in enumerate(self.tskit_samples)}
46+
valid_mask = individuals_nodes >= 0
47+
self.sample_indices, self.ploidy_indices = np.where(valid_mask)
48+
self.genotype_indices = np.array(
49+
[node_id_to_index[node_id] for node_id in individuals_nodes[valid_mask]]
50+
)
51+
52+
@property
53+
def path(self):
54+
return self._path
55+
56+
@property
57+
def num_records(self):
58+
return self.ts.num_sites
59+
60+
@property
61+
def num_samples(self):
62+
return self._num_samples
63+
64+
@property
65+
def samples(self):
66+
return self._samples
67+
68+
@property
69+
def root_attrs(self):
70+
return {}
71+
72+
@property
73+
def contigs(self):
74+
return [vcz.Contig(id=self.contig_id)]
75+
76+
def iter_contig(self, start, stop):
77+
yield from (0 for _ in range(start, stop))
78+
79+
def iter_field(self, field_name, shape, start, stop):
80+
if field_name == "position":
81+
for pos in self.ts.sites_position[start:stop]:
82+
yield int(pos)
83+
else:
84+
raise ValueError(f"Unknown field {field_name}")
85+
86+
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
87+
# All genotypes in tskit are considered phased
88+
phased = np.ones(shape[:-1], dtype=bool)
89+
90+
for variant in self.ts.variants(
91+
isolated_as_missing=self.isolated_as_missing,
92+
left=self.positions[start],
93+
right=self.positions[stop] if stop < self.num_records else None,
94+
samples=self.tskit_samples,
95+
):
96+
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
97+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
98+
for i, allele in enumerate(variant.alleles):
99+
# None is returned by tskit in the case of a missing allele
100+
if allele is None:
101+
continue
102+
assert i < num_alleles
103+
alleles[i] = allele
104+
105+
gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
106+
self.genotype_indices
107+
]
108+
109+
yield alleles, (gt, phased)
110+
111+
def generate_schema(
112+
self,
113+
variants_chunk_size=None,
114+
samples_chunk_size=None,
115+
):
116+
n = self.num_samples
117+
m = self.ts.num_sites
118+
119+
# Determine max number of alleles
120+
max_alleles = 0
121+
for site in self.ts.sites():
122+
states = {site.ancestral_state}
123+
for mut in site.mutations:
124+
states.add(mut.derived_state)
125+
max_alleles = max(len(states), max_alleles)
126+
127+
logging.info(f"Scanned tskit with {n} samples and {m} variants")
128+
logging.info(
129+
f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
130+
)
131+
132+
dimensions = {
133+
"variants": vcz.VcfZarrDimension(
134+
size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
135+
),
136+
"samples": vcz.VcfZarrDimension(
137+
size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
138+
),
139+
"ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
140+
"alleles": vcz.VcfZarrDimension(size=max_alleles),
141+
}
142+
143+
schema_instance = vcz.VcfZarrSchema(
144+
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
145+
dimensions=dimensions,
146+
fields=[],
147+
)
148+
149+
logger.info(
150+
"Generating schema with chunks="
151+
f"{schema_instance.dimensions['variants'].chunk_size}, "
152+
f"{schema_instance.dimensions['samples'].chunk_size}"
153+
)
154+
155+
# Check if positions will fit in i4 (max ~2.1 billion)
156+
min_position = 0
157+
max_position = 0
158+
if self.ts.num_sites > 0:
159+
min_position = np.min(self.ts.sites_position)
160+
max_position = np.max(self.ts.sites_position)
161+
162+
array_specs = [
163+
vcz.ZarrArraySpec(
164+
source="position",
165+
name="variant_position",
166+
dtype=core.min_int_dtype(min_position, max_position),
167+
dimensions=["variants"],
168+
description="Position of each variant",
169+
),
170+
vcz.ZarrArraySpec(
171+
source=None,
172+
name="variant_allele",
173+
dtype="O",
174+
dimensions=["variants", "alleles"],
175+
description="Alleles for each variant",
176+
),
177+
vcz.ZarrArraySpec(
178+
source=None,
179+
name="variant_contig",
180+
dtype=core.min_int_dtype(0, len(self.contigs)),
181+
dimensions=["variants"],
182+
description="Contig/chromosome index for each variant",
183+
),
184+
vcz.ZarrArraySpec(
185+
source=None,
186+
name="call_genotype_phased",
187+
dtype="bool",
188+
dimensions=["variants", "samples"],
189+
description="Whether the genotype is phased",
190+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
191+
),
192+
vcz.ZarrArraySpec(
193+
source=None,
194+
name="call_genotype",
195+
dtype=core.min_int_dtype(constants.INT_FILL, max_alleles - 1),
196+
dimensions=["variants", "samples", "ploidy"],
197+
description="Genotype for each variant and sample",
198+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
199+
),
200+
vcz.ZarrArraySpec(
201+
source=None,
202+
name="call_genotype_mask",
203+
dtype="bool",
204+
dimensions=["variants", "samples", "ploidy"],
205+
description="Mask for each genotype call",
206+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
207+
),
208+
]
209+
schema_instance.fields = array_specs
210+
return schema_instance
211+
212+
213+
def convert(
214+
ts_path,
215+
zarr_path,
216+
individuals_nodes,
217+
*,
218+
sample_ids=None,
219+
contig_id=None,
220+
isolated_as_missing=False,
221+
variants_chunk_size=None,
222+
samples_chunk_size=None,
223+
worker_processes=1,
224+
show_progress=False,
225+
):
226+
tskit_format = TskitFormat(
227+
ts_path,
228+
individuals_nodes,
229+
sample_ids=sample_ids,
230+
contig_id=contig_id,
231+
isolated_as_missing=isolated_as_missing,
232+
)
233+
schema_instance = tskit_format.generate_schema(
234+
variants_chunk_size=variants_chunk_size,
235+
samples_chunk_size=samples_chunk_size,
236+
)
237+
zarr_path = pathlib.Path(zarr_path)
238+
vzw = vcz.VcfZarrWriter(TskitFormat, zarr_path)
239+
# Rough heuristic to split work up enough to keep utilisation high
240+
target_num_partitions = max(1, worker_processes * 4)
241+
vzw.init(
242+
tskit_format,
243+
target_num_partitions=target_num_partitions,
244+
schema=schema_instance,
245+
)
246+
vzw.encode_all_partitions(
247+
worker_processes=worker_processes,
248+
show_progress=show_progress,
249+
)
250+
vzw.finalise(show_progress)
251+
vzw.create_index()

0 commit comments

Comments
 (0)