Skip to content

Commit f832363

Browse files
committed
Add tskit CLI
1 parent 598c95f commit f832363

File tree

8 files changed

+261
-54
lines changed

8 files changed

+261
-54
lines changed

bio2zarr/__main__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def bio2zarr():
1717
bio2zarr.add_command(cli.vcf2zarr_main)
1818
bio2zarr.add_command(cli.plink2zarr)
1919
bio2zarr.add_command(cli.vcfpartition)
20+
bio2zarr.add_command(cli.tskit2zarr)
2021

2122
if __name__ == "__main__":
2223
bio2zarr()

bio2zarr/cli.py

+50
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from . import icf as icf_mod
1212
from . import plink, provenance, vcf_utils
13+
from . import tskit as tskit_mod
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -630,3 +631,52 @@ def vcfpartition(vcfs, verbose, num_partitions, partition_size):
630631
)
631632
for region in regions:
632633
click.echo(f"{region}\t{vcf_path}")
634+
635+
636+
@click.command(name="convert")
637+
@click.argument("ts_path", type=click.Path(exists=True))
638+
@click.argument("zarr_path", type=click.Path())
639+
@click.option("--contig-id", type=str, help="Contig/chromosome ID (default: '1')")
640+
@click.option(
641+
"--isolated-as-missing", is_flag=True, help="Treat isolated nodes as missing"
642+
)
643+
@variants_chunk_size
644+
@samples_chunk_size
645+
@verbose
646+
@progress
647+
@worker_processes
648+
@force
649+
def convert_tskit(
650+
ts_path,
651+
zarr_path,
652+
contig_id,
653+
isolated_as_missing,
654+
variants_chunk_size,
655+
samples_chunk_size,
656+
verbose,
657+
progress,
658+
worker_processes,
659+
force,
660+
):
661+
setup_logging(verbose)
662+
check_overwrite_dir(zarr_path, force)
663+
664+
tskit_mod.convert(
665+
ts_path,
666+
zarr_path,
667+
contig_id=contig_id,
668+
isolated_as_missing=isolated_as_missing,
669+
variants_chunk_size=variants_chunk_size,
670+
samples_chunk_size=samples_chunk_size,
671+
worker_processes=worker_processes,
672+
show_progress=progress,
673+
)
674+
675+
676+
@version
677+
@click.group()
678+
def tskit2zarr():
679+
pass
680+
681+
682+
tskit2zarr.add_command(convert_tskit)

bio2zarr/tskit.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TskitFormat(vcz.Source):
1313
def __init__(
1414
self,
1515
ts_path,
16-
individuals_nodes,
16+
individuals_nodes=None,
1717
sample_ids=None,
1818
contig_id=None,
1919
isolated_as_missing=False,
@@ -24,6 +24,13 @@ def __init__(
2424
self.isolated_as_missing = isolated_as_missing
2525

2626
self.positions = self.ts.sites_position
27+
if individuals_nodes is None:
28+
if self.ts.num_individuals == 0:
29+
raise ValueError(
30+
"No individuals found in the tree sequence, use individuals_nodes "
31+
"argument to specify the individuals nodes"
32+
)
33+
individuals_nodes = self.ts.individuals_nodes
2734

2835
self._num_samples = individuals_nodes.shape[0]
2936
if self._num_samples < 1:
@@ -213,8 +220,8 @@ def generate_schema(
213220
def convert(
214221
ts_path,
215222
zarr_path,
216-
individuals_nodes,
217223
*,
224+
individuals_nodes=None,
218225
sample_ids=None,
219226
contig_id=None,
220227
isolated_as_missing=False,
@@ -225,7 +232,7 @@ def convert(
225232
):
226233
tskit_format = TskitFormat(
227234
ts_path,
228-
individuals_nodes,
235+
individuals_nodes=individuals_nodes,
229236
sample_ids=sample_ids,
230237
contig_id=contig_id,
231238
isolated_as_missing=isolated_as_missing,

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ dependencies = [
2525
# colouredlogs pulls in humanfriendly
2626
"cyvcf2",
2727
"bed_reader",
28+
# TODO: Using the development version of tskit (git main) for CI only; replace with a released tskit version before release.
29+
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
2830
]
2931
requires-python = ">=3.9"
3032
classifiers = [
@@ -52,6 +54,7 @@ documentation = "https://sgkit-dev.github.io/bio2zarr/"
5254
[project.scripts]
5355
vcf2zarr = "bio2zarr.cli:vcf2zarr_main"
5456
vcfpartition = "bio2zarr.cli:vcfpartition"
57+
tskit2zarr = "bio2zarr.cli:tskit2zarr"  # must name the click group defined in bio2zarr/cli.py
5558

5659
[project.optional-dependencies]
5760
dev = [

tests/data/ts/example.trees

9.92 KB
Binary file not shown.

tests/test_cli.py

+146-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@
6161
local_alleles=False,
6262
)
6363

64+
DEFAULT_TSKIT_CONVERT_ARGS = dict(
65+
contig_id=None,
66+
isolated_as_missing=False,
67+
variants_chunk_size=None,
68+
samples_chunk_size=None,
69+
show_progress=True,
70+
worker_processes=1,
71+
)
72+
6473
DEFAULT_PLINK_CONVERT_ARGS = dict(
6574
variants_chunk_size=None,
6675
samples_chunk_size=None,
@@ -635,6 +644,116 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response
635644
(self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS
636645
)
637646

647+
@pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")])
648+
@mock.patch("bio2zarr.tskit.convert")
649+
def test_convert_tskit(self, mocked, tmp_path, progress, flag):
650+
ts_path = "tests/data/ts/example.trees"
651+
zarr_path = tmp_path / "zarr"
652+
runner = ct.CliRunner(mix_stderr=False)
653+
result = runner.invoke(
654+
cli.tskit2zarr,
655+
f"convert {ts_path} {zarr_path} {flag}",
656+
catch_exceptions=False,
657+
)
658+
assert result.exit_code == 0
659+
assert len(result.stdout) == 0
660+
assert len(result.stderr) == 0
661+
args = dict(DEFAULT_TSKIT_CONVERT_ARGS)
662+
args["show_progress"] = progress
663+
mocked.assert_called_once_with(
664+
ts_path,
665+
str(zarr_path),
666+
**args,
667+
)
668+
669+
@pytest.mark.parametrize("response", ["y", "Y", "yes"])
670+
@mock.patch("bio2zarr.tskit.convert")
671+
def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
672+
ts_path = "tests/data/ts/example.trees"
673+
zarr_path = tmp_path / "zarr"
674+
zarr_path.mkdir()
675+
runner = ct.CliRunner(mix_stderr=False)
676+
result = runner.invoke(
677+
cli.tskit2zarr,
678+
f"convert {ts_path} {zarr_path}",
679+
catch_exceptions=False,
680+
input=response,
681+
)
682+
assert result.exit_code == 0
683+
assert f"Do you want to overwrite {zarr_path}" in result.stdout
684+
assert len(result.stderr) == 0
685+
mocked.assert_called_once_with(
686+
ts_path,
687+
str(zarr_path),
688+
**DEFAULT_TSKIT_CONVERT_ARGS,
689+
)
690+
691+
@pytest.mark.parametrize("response", ["n", "N", "No"])
692+
@mock.patch("bio2zarr.tskit.convert")
693+
def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
694+
ts_path = "tests/data/ts/example.trees"
695+
zarr_path = tmp_path / "zarr"
696+
zarr_path.mkdir()
697+
runner = ct.CliRunner(mix_stderr=False)
698+
result = runner.invoke(
699+
cli.tskit2zarr,
700+
f"convert {ts_path} {zarr_path}",
701+
catch_exceptions=False,
702+
input=response,
703+
)
704+
assert result.exit_code == 1
705+
assert "Aborted" in result.stderr
706+
mocked.assert_not_called()
707+
708+
@pytest.mark.parametrize("force_arg", ["-f", "--force"])
709+
@mock.patch("bio2zarr.tskit.convert")
710+
def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg):
711+
ts_path = "tests/data/ts/example.trees"
712+
zarr_path = tmp_path / "zarr"
713+
zarr_path.mkdir()
714+
runner = ct.CliRunner(mix_stderr=False)
715+
result = runner.invoke(
716+
cli.tskit2zarr,
717+
f"convert {ts_path} {zarr_path} {force_arg}",
718+
catch_exceptions=False,
719+
)
720+
assert result.exit_code == 0
721+
assert len(result.stdout) == 0
722+
assert len(result.stderr) == 0
723+
mocked.assert_called_once_with(
724+
ts_path,
725+
str(zarr_path),
726+
**DEFAULT_TSKIT_CONVERT_ARGS,
727+
)
728+
729+
@mock.patch("bio2zarr.tskit.convert")
730+
def test_tskit_convert_with_options(self, mocked, tmp_path):
731+
ts_path = "tests/data/ts/example.trees"
732+
zarr_path = tmp_path / "zarr"
733+
runner = ct.CliRunner(mix_stderr=False)
734+
result = runner.invoke(
735+
cli.tskit2zarr,
736+
f"convert {ts_path} {zarr_path} --contig-id chr1 "
737+
"--isolated-as-missing -l 100 -w 50 -p 4",
738+
catch_exceptions=False,
739+
)
740+
assert result.exit_code == 0
741+
assert len(result.stdout) == 0
742+
assert len(result.stderr) == 0
743+
744+
expected_args = dict(DEFAULT_TSKIT_CONVERT_ARGS)
745+
expected_args["contig_id"] = "chr1"
746+
expected_args["isolated_as_missing"] = True
747+
expected_args["variants_chunk_size"] = 100
748+
expected_args["samples_chunk_size"] = 50
749+
expected_args["worker_processes"] = 4
750+
751+
mocked.assert_called_once_with(
752+
ts_path,
753+
str(zarr_path),
754+
**expected_args,
755+
)
756+
638757

639758
class TestVcfEndToEnd:
640759
vcf_path = "tests/data/vcf/sample.vcf.gz"
@@ -908,10 +1027,36 @@ def test_part_size_multiple_vcfs(self):
9081027

9091028

9101029
@pytest.mark.parametrize(
911-
"cmd", [main.bio2zarr, cli.vcf2zarr_main, cli.plink2zarr, cli.vcfpartition]
1030+
"cmd",
1031+
[
1032+
main.bio2zarr,
1033+
cli.vcf2zarr_main,
1034+
cli.plink2zarr,
1035+
cli.vcfpartition,
1036+
cli.tskit2zarr,
1037+
],
9121038
)
9131039
def test_version(cmd):
9141040
runner = ct.CliRunner(mix_stderr=False)
9151041
result = runner.invoke(cmd, ["--version"], catch_exceptions=False)
9161042
s = f"version {provenance.__version__}\n"
9171043
assert result.stdout.endswith(s)
1044+
1045+
1046+
class TestTskitEndToEnd:
1047+
def test_convert(self, tmp_path):
1048+
ts_path = "tests/data/ts/example.trees"
1049+
zarr_path = tmp_path / "zarr"
1050+
runner = ct.CliRunner(mix_stderr=False)
1051+
result = runner.invoke(
1052+
cli.tskit2zarr,
1053+
f"convert {ts_path} {zarr_path}",
1054+
catch_exceptions=False,
1055+
)
1056+
assert result.exit_code == 0
1057+
result = runner.invoke(
1058+
cli.vcf2zarr_main, f"inspect {zarr_path}", catch_exceptions=False
1059+
)
1060+
assert result.exit_code == 0
1061+
# Arbitrary check
1062+
assert "variant_position" in result.stdout

tests/test_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop):
237237
# It works in CI on Linux, but it'll probably break at some point.
238238
# It's also necessary to update these numbers each time a new data
239239
# file gets added
240-
("tests/data", 5030777),
240+
("tests/data", 5045029),
241241
("tests/data/vcf", 5018640),
242242
("tests/data/vcf/sample.vcf.gz", 1089),
243243
],

0 commit comments

Comments
 (0)