Skip to content

Commit ab86815

Browse files
committed
Fix map_to_vcf_model tests
1 parent 22ef437 commit ab86815

File tree

2 files changed: +48 -37 lines changed

2 files changed: +48 -37 lines changed

python/tests/test_highlevel.py

Lines changed: 38 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -5747,23 +5747,15 @@ def test_specific_individuals(self):
57475747
result = ts.map_to_vcf_model(individuals=[1, 3])
57485748
assert isinstance(result, tskit.VcfModelMapping)
57495749
# Individual 1 has ploidy 2, individual 3 has ploidy 4
5750-
assert result.individuals_nodes.shape == (2, 4)
5751-
5752-
assert result.individuals_nodes[0, 0] == 1
5753-
assert result.individuals_nodes[0, 1] == 2
5754-
assert result.individuals_nodes[0, 2] == -1
5755-
assert result.individuals_nodes[0, 3] == -1
5756-
5757-
assert result.individuals_nodes[1, 0] == 6
5758-
assert result.individuals_nodes[1, 1] == 7
5759-
assert result.individuals_nodes[1, 2] == 8
5760-
assert result.individuals_nodes[1, 3] == 9
5750+
assert result.individuals_nodes.shape == (2, 5)
5751+
assert np.array_equal(result.individuals_nodes[0], [1, 2, -1, -1, -1])
5752+
assert np.array_equal(result.individuals_nodes[1], [6, 7, 8, 9, -1])
57615753

57625754
assert result.individuals_name.shape == (2,)
57635755
assert result.individuals_name[0] == "tsk_1"
57645756
assert result.individuals_name[1] == "tsk_3"
57655757

5766-
def test_individual_with_no_nodes_warning(self):
5758+
def test_individual_with_no_nodes(self):
57675759
tables = tskit.TableCollection(1.0)
57685760
# Individual with no nodes
57695761
tables.individuals.add_row()
@@ -5772,18 +5764,11 @@ def test_individual_with_no_nodes_warning(self):
57725764
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
57735765
ts = tables.tree_sequence()
57745766

5775-
with warnings.catch_warnings(record=True) as w:
5776-
result = ts.map_to_vcf_model()
5777-
assert len(w) == 1
5778-
assert "Individual 0 has no nodes" in str(w[0].message)
5779-
5780-
# Should only include individual 1
5781-
assert result.individuals_nodes.shape == (1, 1)
5782-
assert result.individuals_nodes[0, 0] == 0
5783-
assert result.individuals_name.shape == (1,)
5784-
assert result.individuals_name[0] == "tsk_1"
5767+
result = ts.map_to_vcf_model()
5768+
assert result.individuals_nodes.shape == (2, 1)
5769+
assert np.array_equal(result.individuals_nodes, [[-1], [0]])
57855770

5786-
def test_individual_with_no_nodes_error(self):
5771+
def test_individual_with_no_nodes_only(self):
57875772
tables = tskit.TableCollection(1.0)
57885773
# Individual with no nodes
57895774
tables.individuals.add_row()
@@ -5792,8 +5777,9 @@ def test_individual_with_no_nodes_error(self):
57925777
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
57935778
ts = tables.tree_sequence()
57945779

5795-
with pytest.raises(ValueError, match="Individual 0 has no nodes"):
5796-
ts.map_to_vcf_model(individuals=[0])
5780+
result = ts.map_to_vcf_model(individuals=[0])
5781+
assert result.individuals_nodes.shape == (1, 1)
5782+
assert np.array_equal(result.individuals_nodes, [[-1]])
57975783

57985784
def test_invalid_individual_id(self):
57995785
tables = tskit.TableCollection(1.0)
@@ -5807,19 +5793,31 @@ def test_invalid_individual_id(self):
58075793
with pytest.raises(ValueError, match="Invalid individual ID"):
58085794
ts.map_to_vcf_model(individuals=[1])
58095795

5810-
def test_mixed_sample_non_sample_warning(self):
5796+
def test_mixed_sample_non_sample_ordering(self):
58115797
tables = tskit.TableCollection(1.0)
58125798
tables.individuals.add_row()
58135799
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
58145800
tables.nodes.add_row(flags=0, time=0, individual=0) # Non-sample node
5801+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
5802+
tables.nodes.add_row(flags=0, time=0, individual=0) # Non-sample node
5803+
tables.individuals.add_row()
5804+
tables.nodes.add_row(flags=0, time=0, individual=1) # Non-sample node
5805+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
58155806
ts = tables.tree_sequence()
58165807

5817-
with warnings.catch_warnings(record=True) as w:
5818-
ts.map_to_vcf_model()
5819-
assert len(w) == 1
5820-
assert "Individual 0 has both sample and non-sample nodes" in str(
5821-
w[0].message
5822-
)
5808+
result = ts.map_to_vcf_model()
5809+
assert result.individuals_nodes.shape == (2, 4)
5810+
assert np.array_equal(
5811+
result.individuals_nodes,
5812+
np.array([[0, 2, -1, -1], [5, -1, -1, -1]]),
5813+
)
5814+
5815+
result = ts.map_to_vcf_model(include_non_sample_nodes=True)
5816+
assert result.individuals_nodes.shape == (2, 4)
5817+
assert np.array_equal(
5818+
result.individuals_nodes,
5819+
np.array([[0, 1, 2, 3], [4, 5, -1, -1]]),
5820+
)
58235821

58245822
def test_samples_without_individuals_warning(self):
58255823
tables = tskit.TableCollection(1.0)
@@ -5898,3 +5896,11 @@ def test_name_count_mismatch_error(self):
58985896
ValueError, match="number of individuals does not match the number of names"
58995897
):
59005898
ts.map_to_vcf_model(individual_names=["only_one_name"])
5899+
5900+
def test_all_individuals_no_nodes(self):
5901+
tables = tskit.TableCollection(1.0)
5902+
tables.individuals.add_row()
5903+
tables.individuals.add_row()
5904+
ts = tables.tree_sequence()
5905+
result = ts.map_to_vcf_model()
5906+
assert result.individuals_nodes.shape == (2, 0)

python/tskit/trees.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10608,12 +10608,17 @@ def map_to_vcf_model(
1060810608
raise ValueError("Invalid individual ID")
1060910609

1061010610
individuals_nodes = self.individuals_nodes[individuals]
10611-
if not include_non_sample_nodes:
10612-
individuals_nodes[
10613-
np.logical_not(
10614-
self.nodes_flags[individuals_nodes] & tskit.NODE_IS_SAMPLE
10611+
non_sample_nodes = np.logical_not(
10612+
self.nodes_flags[individuals_nodes] & tskit.NODE_IS_SAMPLE
10613+
)
10614+
if np.any(non_sample_nodes) and not include_non_sample_nodes:
10615+
individuals_nodes[non_sample_nodes] = -1
10616+
rows_to_reorder = np.any(non_sample_nodes, axis=1)
10617+
for i in np.where(rows_to_reorder)[0]:
10618+
row = individuals_nodes[i]
10619+
individuals_nodes[i] = np.concatenate(
10620+
[row[row != -1], row[row == -1]]
1061510621
)
10616-
] = -1
1061710622

1061810623
if individual_names is None:
1061910624
if name_metadata_key is not None:

0 commit comments

Comments
 (0)