diff --git a/bio2zarr/__main__.py b/bio2zarr/__main__.py index cab080b..1372865 100644 --- a/bio2zarr/__main__.py +++ b/bio2zarr/__main__.py @@ -17,6 +17,7 @@ def bio2zarr(): bio2zarr.add_command(cli.vcf2zarr_main) bio2zarr.add_command(cli.plink2zarr) bio2zarr.add_command(cli.vcfpartition) +bio2zarr.add_command(cli.tskit2zarr) if __name__ == "__main__": bio2zarr() diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py index cfe6458..a2ae88d 100644 --- a/bio2zarr/cli.py +++ b/bio2zarr/cli.py @@ -9,6 +9,7 @@ import tabulate from . import plink, provenance, vcf_utils +from . import tskit as tskit_mod from . import vcf as vcf_mod logger = logging.getLogger(__name__) @@ -630,3 +631,52 @@ def vcfpartition(vcfs, verbose, num_partitions, partition_size): ) for region in regions: click.echo(f"{region}\t{vcf_path}") + + +@click.command(name="convert") +@click.argument("ts_path", type=click.Path(exists=True)) +@click.argument("zarr_path", type=click.Path()) +@click.option("--contig-id", type=str, help="Contig/chromosome ID (default: '1')") +@click.option( + "--isolated-as-missing", is_flag=True, help="Treat isolated nodes as missing" +) +@variants_chunk_size +@samples_chunk_size +@verbose +@progress +@worker_processes +@force +def convert_tskit( + ts_path, + zarr_path, + contig_id, + isolated_as_missing, + variants_chunk_size, + samples_chunk_size, + verbose, + progress, + worker_processes, + force, +): + setup_logging(verbose) + check_overwrite_dir(zarr_path, force) + + tskit_mod.convert( + ts_path, + zarr_path, + contig_id=contig_id, + isolated_as_missing=isolated_as_missing, + variants_chunk_size=variants_chunk_size, + samples_chunk_size=samples_chunk_size, + worker_processes=worker_processes, + show_progress=progress, + ) + + +@version +@click.group() +def tskit2zarr(): + pass + + +tskit2zarr.add_command(convert_tskit) diff --git a/bio2zarr/tskit.py b/bio2zarr/tskit.py index eb68ad4..4dcceda 100644 --- a/bio2zarr/tskit.py +++ b/bio2zarr/tskit.py @@ -13,7 +13,7 @@ class TskitFormat(vcz.Source): def __init__( self, ts_path, - individuals_nodes, + individuals_nodes=None, sample_ids=None, contig_id=None, isolated_as_missing=False, @@ -25,6 +25,9 @@ def __init__( self.positions = self.ts.sites_position + if individuals_nodes is None: + individuals_nodes = self.ts.individuals_nodes + self._num_samples = individuals_nodes.shape[0] if self._num_samples < 1: raise ValueError("individuals_nodes must have at least one sample") @@ -213,8 +216,8 @@ def generate_schema( def convert( ts_path, zarr_path, - individuals_nodes, *, + individuals_nodes=None, sample_ids=None, contig_id=None, isolated_as_missing=False, @@ -225,7 +228,7 @@ def convert( ): tskit_format = TskitFormat( ts_path, - individuals_nodes, + individuals_nodes=individuals_nodes, sample_ids=sample_ids, contig_id=contig_id, isolated_as_missing=isolated_as_missing, diff --git a/pyproject.toml b/pyproject.toml index f847f5b..f838e08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ # colouredlogs pulls in humanfriendly", "cyvcf2", "bed_reader", + # TODO Using dev version of tskit for CI, FIXME before release + "tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python", ] requires-python = ">=3.10" classifiers = [ @@ -51,6 +53,7 @@ documentation = "https://sgkit-dev.github.io/bio2zarr/" [project.scripts] vcf2zarr = "bio2zarr.cli:vcf2zarr_main" vcfpartition = "bio2zarr.cli:vcfpartition" +tskit2zarr = "bio2zarr.cli:tskit2zarr_main" [project.optional-dependencies] dev = [ diff --git a/tests/data/ts/example.trees b/tests/data/ts/example.trees new file mode 100644 index 0000000..4910ec2 Binary files /dev/null and b/tests/data/ts/example.trees differ diff --git a/tests/test_cli.py b/tests/test_cli.py index ead3489..2cc133b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -61,6 +61,15 @@ local_alleles=False, ) +DEFAULT_TSKIT_CONVERT_ARGS = dict( + contig_id=None, + isolated_as_missing=False, + variants_chunk_size=None, + samples_chunk_size=None, + show_progress=True, + worker_processes=1, +) + DEFAULT_PLINK_CONVERT_ARGS = dict( variants_chunk_size=None, samples_chunk_size=None, @@ -635,6 +644,116 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response (self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS ) + @pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")]) + @mock.patch("bio2zarr.tskit.convert") + def test_convert_tskit(self, mocked, tmp_path, progress, flag): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path} {flag}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert len(result.stdout) == 0 + assert len(result.stderr) == 0 + args = dict(DEFAULT_TSKIT_CONVERT_ARGS) + args["show_progress"] = progress + mocked.assert_called_once_with( + ts_path, + str(zarr_path), + **args, + ) + + @pytest.mark.parametrize("response", ["y", "Y", "yes"]) + @mock.patch("bio2zarr.tskit.convert") + def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path}", + catch_exceptions=False, + input=response, + ) + assert result.exit_code == 0 + assert f"Do you want to overwrite {zarr_path}" in result.stdout + assert len(result.stderr) == 0 + mocked.assert_called_once_with( + ts_path, + str(zarr_path), + **DEFAULT_TSKIT_CONVERT_ARGS, + ) + + @pytest.mark.parametrize("response", ["n", "N", "No"]) + @mock.patch("bio2zarr.tskit.convert") + def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path}", + catch_exceptions=False, + input=response, + ) + assert result.exit_code == 1 + assert "Aborted" in result.stderr + mocked.assert_not_called() + + @pytest.mark.parametrize("force_arg", ["-f", "--force"]) + @mock.patch("bio2zarr.tskit.convert") + def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + zarr_path.mkdir() + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path} {force_arg}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert len(result.stdout) == 0 + assert len(result.stderr) == 0 + mocked.assert_called_once_with( + ts_path, + str(zarr_path), + **DEFAULT_TSKIT_CONVERT_ARGS, + ) + + @mock.patch("bio2zarr.tskit.convert") + def test_tskit_convert_with_options(self, mocked, tmp_path): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path} --contig-id chr1 " + "--isolated-as-missing -l 100 -w 50 -p 4", + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert len(result.stdout) == 0 + assert len(result.stderr) == 0 + + expected_args = dict(DEFAULT_TSKIT_CONVERT_ARGS) + expected_args["contig_id"] = "chr1" + expected_args["isolated_as_missing"] = True + expected_args["variants_chunk_size"] = 100 + expected_args["samples_chunk_size"] = 50 + expected_args["worker_processes"] = 4 + + mocked.assert_called_once_with( + ts_path, + str(zarr_path), + **expected_args, + ) + class TestVcfEndToEnd: vcf_path = "tests/data/vcf/sample.vcf.gz" @@ -908,10 +1027,36 @@ def test_part_size_multiple_vcfs(self): @pytest.mark.parametrize( - "cmd", [main.bio2zarr, cli.vcf2zarr_main, cli.plink2zarr, cli.vcfpartition] + "cmd", + [ + main.bio2zarr, + cli.vcf2zarr_main, + cli.plink2zarr, + cli.vcfpartition, + cli.tskit2zarr, + ], ) def test_version(cmd): runner = ct.CliRunner() result = runner.invoke(cmd, ["--version"], catch_exceptions=False) s = f"version {provenance.__version__}\n" assert result.stdout.endswith(s) + + +class TestTskitEndToEnd: + def test_convert(self, tmp_path): + ts_path = "tests/data/ts/example.trees" + zarr_path = tmp_path / "zarr" + runner = ct.CliRunner() + result = runner.invoke( + cli.tskit2zarr, + f"convert {ts_path} {zarr_path}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + result = runner.invoke( + cli.vcf2zarr_main, f"inspect {zarr_path}", catch_exceptions=False + ) + assert result.exit_code == 0 + # Arbitrary check + assert "variant_position" in result.stdout diff --git a/tests/test_core.py b/tests/test_core.py index 03f380f..5619fd3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -237,7 +237,7 @@ def test_examples(self, chunk_size, size, start, stop): # It works in CI on Linux, but it'll probably break at some point. # It's also necessary to update these numbers each time a new data # file gets added - ("tests/data", 5030777), + ("tests/data", 5045029), ("tests/data/vcf", 5018640), ("tests/data/vcf/sample.vcf.gz", 1089), ], diff --git a/tests/test_ts.py b/tests/test_ts.py index 724df7b..9c19750 100644 --- a/tests/test_ts.py +++ b/tests/test_ts.py @@ -32,12 +32,17 @@ def test_simple_tree_sequence(self, tmp_path): tree_sequence = tables.tree_sequence() tree_sequence.dump(tmp_path / "test.trees") + # Manually specify the individuals_nodes, other tests use + # ts individuals. ind_nodes = np.array([[0, 1], [2, 3]]) with tempfile.TemporaryDirectory() as tempdir: zarr_path = os.path.join(tempdir, "test_output.zarr") ts.convert( - tmp_path / "test.trees", zarr_path, ind_nodes, show_progress=False + tmp_path / "test.trees", + zarr_path, + individuals_nodes=ind_nodes, + show_progress=False, ) zroot = zarr.open(zarr_path, mode="r") assert zroot["variant_position"].shape == (3,) @@ -70,10 +75,12 @@ class TestTskitFormat: @pytest.fixture() def simple_ts(self, tmp_path): tables = tskit.TableCollection(sequence_length=100) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.individuals.add_row() + tables.individuals.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1 tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3 tables.edges.add_row(left=0, right=100, parent=4, child=0) @@ -133,7 +140,7 @@ def test_position_dtype_selection(self, tmp_path): ts_large.dump(ts_path_large) ind_nodes = np.array([[0], [1]]) - format_obj_small = ts.TskitFormat(ts_path_small, ind_nodes) + format_obj_small = ts.TskitFormat(ts_path_small, individuals_nodes=ind_nodes) schema_small = format_obj_small.generate_schema() position_field = next( @@ -141,7 +148,7 @@ def test_position_dtype_selection(self, tmp_path): ) assert position_field.dtype == "i1" - format_obj_large = ts.TskitFormat(ts_path_large, ind_nodes) + format_obj_large = ts.TskitFormat(ts_path_large, individuals_nodes=ind_nodes) schema_large = format_obj_large.generate_schema() position_field = next( @@ -151,10 +158,9 @@ def test_position_dtype_selection(self, tmp_path): def test_initialization(self, simple_ts): ts_path, tree_sequence = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) # Test with default parameters - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path) assert format_obj.path == ts_path assert format_obj.ts.num_sites == tree_sequence.num_sites assert format_obj.contig_id == "1" @@ -163,7 +169,6 @@ def test_initialization(self, simple_ts): # Test with custom parameters format_obj = ts.TskitFormat( ts_path, - ind_nodes, sample_ids=["ind1", "ind2"], contig_id="chr1", isolated_as_missing=True, @@ -176,8 +181,7 @@ def test_initialization(self, simple_ts): def test_basic_properties(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path) assert format_obj.num_records == format_obj.ts.num_sites assert format_obj.num_samples == 2 # Two individuals @@ -193,9 +197,8 @@ def test_basic_properties(self, simple_ts): def test_custom_sample_ids(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) custom_ids = ["sample_X", "sample_Y"] - format_obj = ts.TskitFormat(ts_path, ind_nodes, sample_ids=custom_ids) + format_obj = ts.TskitFormat(ts_path, sample_ids=custom_ids) assert format_obj.num_samples == 2 assert len(format_obj.samples) == 2 @@ -204,15 +207,13 @@ def test_custom_sample_ids(self, simple_ts): def test_sample_id_length_mismatch(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) # Wrong number of sample IDs with pytest.raises(ValueError, match="Length of sample_ids.*does not match"): - ts.TskitFormat(ts_path, ind_nodes, sample_ids=["only_one_id"]) + ts.TskitFormat(ts_path, sample_ids=["only_one_id"]) def test_schema_generation(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path) schema = format_obj.generate_schema() assert schema.dimensions["variants"].size == 3 @@ -234,15 +235,13 @@ def test_schema_generation(self, simple_ts): def test_iter_contig(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path) contig_indices = list(format_obj.iter_contig(1, 3)) assert contig_indices == [0, 0] def test_iter_field(self, simple_ts): ts_path, _ = simple_ts - ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path) positions = list(format_obj.iter_field("position", None, 0, 3)) assert positions == [10, 20, 30] positions = list(format_obj.iter_field("position", None, 1, 3)) @@ -288,7 +287,7 @@ def test_iter_field(self, simple_ts): def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts): ts_path, _ = simple_ts - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) shape = (2, 2) # (num_samples, max_ploidy) results = list(format_obj.iter_alleles_and_genotypes(0, 3, shape, 2)) @@ -314,7 +313,7 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts): # Test with node ID that doesn't exist in tree sequence (out of range) invalid_nodes = np.array([[10, 11], [12, 13]], dtype=np.int32) - format_obj = ts.TskitFormat(ts_path, invalid_nodes) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=invalid_nodes) shape = (2, 2) with pytest.raises( tskit.LibraryError, match="out of bounds" @@ -326,14 +325,14 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts): with pytest.raises( ValueError, match="individuals_nodes must have at least one sample" ): - format_obj = ts.TskitFormat(ts_path, empty_nodes) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=empty_nodes) # Test with all invalid nodes (-1) all_invalid = np.full((2, 2), -1, dtype=np.int32) with pytest.raises( ValueError, match="individuals_nodes must have at least one valid sample" ): - format_obj = ts.TskitFormat(ts_path, all_invalid) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=all_invalid) def test_isolated_as_missing(self, tmp_path): def insert_branch_sites(ts, m=1): @@ -366,7 +365,7 @@ def insert_branch_sites(ts, m=1): tree_sequence.dump(ts_path) ind_nodes = np.array([[0], [1], [3]]) format_obj_default = ts.TskitFormat( - ts_path, ind_nodes, isolated_as_missing=False + ts_path, individuals_nodes=ind_nodes, isolated_as_missing=False ) shape = (3, 1) # (num_samples, max_ploidy) results_default = list( @@ -382,7 +381,7 @@ def insert_branch_sites(ts, m=1): assert np.array_equal(gt_default, expected_gt_default) format_obj_missing = ts.TskitFormat( - ts_path, ind_nodes, isolated_as_missing=True + ts_path, individuals_nodes=ind_nodes, isolated_as_missing=True ) results_missing = list( format_obj_missing.iter_alleles_and_genotypes(0, 1, shape, 2) @@ -411,7 +410,7 @@ def test_genotype_dtype_selection(self, tmp_path): tree_sequence.dump(ts_path) ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) schema = format_obj.generate_schema() call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") assert call_genotype_spec.dtype == "i1" @@ -431,7 +430,7 @@ def test_genotype_dtype_selection(self, tmp_path): ts_path = tmp_path / "large_alleles.trees" tree_sequence.dump(ts_path) - format_obj = ts.TskitFormat(ts_path, ind_nodes) + format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) schema = format_obj.generate_schema() call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") assert call_genotype_spec.dtype == "i4"