Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 94 additions & 76 deletions hail/python/hail/vds/variant_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,115 +270,133 @@ def reference_genome(self) -> ReferenceGenome:
@typecheck_method(check_data=bool)
def validate(self, *, check_data: bool = True):
"""Eagerly checks necessary representational properties of the VDS."""
self.__validate_ref_row_key()
self.__validate_var_row_key()
self.__validate_ref_col_key()
self.__validate_var_col_key()

rd = self.reference_data
vd = self.variant_data
self.__validate_len_field()
self.__validate_filters_fields()

rd_row_key = rd.row_key.dtype
if check_data:
self.__validate_data()

def __validate_ref_row_key(self):
rd_row_key = self.reference_data.row_key.dtype
if (
not isinstance(rd_row_key, hl.tstruct)
or len(rd_row_key) != 1
or not rd_row_key.fields[0] == 'locus'
or not isinstance(rd_row_key.types[0], hl.tlocus)
):
_validate_err(f"expect reference data to have a single row key 'locus' of type locus, found {rd_row_key}")
raise _validate_err(
f"expect reference data to have a single row key 'locus' of type locus, found {rd_row_key}"
)

vd_row_key = vd.row_key.dtype
def __validate_var_row_key(self):
vd_row_key = self.variant_data.row_key.dtype
if (
not isinstance(vd_row_key, hl.tstruct)
or len(vd_row_key) != 2
or not vd_row_key.fields == ('locus', 'alleles')
or not isinstance(vd_row_key.types[0], hl.tlocus)
or vd_row_key.types[1] != hl.tarray(hl.tstr)
):
_validate_err(
raise _validate_err(
f"expect variant data to have a row key {{'locus': locus<rg>, alleles: array<str>}}, found {vd_row_key}"
)

rd_col_key = rd.col_key.dtype
if not isinstance(rd_col_key, hl.tstruct) or len(rd_row_key) != 1 or rd_col_key.types[0] != hl.tstr:
_validate_err(f"expect reference data to have a single col key of type string, found {rd_col_key}")
def __validate_ref_col_key(self):
rd_col_key = self.reference_data.col_key.dtype
if not isinstance(rd_col_key, hl.tstruct) or len(rd_col_key) != 1 or rd_col_key.types[0] != hl.tstr:
raise _validate_err(f"expect reference data to have a single col key of type string, found {rd_col_key}")

vd_col_key = vd.col_key.dtype
def __validate_var_col_key(self):
vd_col_key = self.variant_data.col_key.dtype
if not isinstance(vd_col_key, hl.tstruct) or len(vd_col_key) != 1 or vd_col_key.types[0] != hl.tstr:
_validate_err(f"expect variant data to have a single col key of type string, found {vd_col_key}")
raise _validate_err(f"expect variant data to have a single col key of type string, found {vd_col_key}")

def __validate_filters_fields(self):
field = 'gvcf_filters'
ref_has = field in self.reference_data.entry
var_has = field in self.variant_data.entry
if ref_has == var_has:
return
if ref_has and not var_has:
raise _validate_err(f"reference data has '{field}' when variant data does not")
if var_has and not ref_has:
raise _validate_err(f"variant data has '{field}' when reference data does not")

def __validate_len_field(self):
rd = self.reference_data
end_exists = 'END' in rd.entry
len_exists = 'LEN' in rd.entry
if not (end_exists or len_exists):
_validate_err("expect at least one of 'END' or 'LEN' in entry of reference data")
raise _validate_err("expect at least one of 'END' or 'LEN' in entry of reference data")
if end_exists and rd.END.dtype != hl.tint32:
_validate_err("'END' field in entry of reference data must have type tint32")
raise _validate_err("'END' field in entry of reference data must have type tint32")
if len_exists and rd.LEN.dtype != hl.tint32:
_validate_err("'LEN' field in entry of reference data must have type tint32")
self._validate_filters_fields()

if check_data:
# check cols
ref_cols = rd.col_key.collect()
var_cols = vd.col_key.collect()
if len(ref_cols) != len(var_cols):
_validate_err(
f"mismatch in number of columns: reference data has {ref_cols} columns, variant data has {var_cols} columns"
)

if ref_cols != var_cols:
first_mismatch = 0
while ref_cols[first_mismatch] == var_cols[first_mismatch]:
first_mismatch += 1
_validate_err(
f"mismatch in columns keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]} at position {first_mismatch}"
)
raise _validate_err("'LEN' field in entry of reference data must have type tint32")

# check locus distinctness
n_rd_rows = rd.count_rows()
n_distinct = rd.distinct_by_row().count_rows()
def __validate_data(self):
rd = self.reference_data
vd = self.variant_data

if n_distinct != n_rd_rows:
_validate_err(
f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci'
)
# check cols
ref_cols = rd.col_key.collect()
var_cols = vd.col_key.collect()
if len(ref_cols) != len(var_cols):
raise _validate_err(
f"mismatch in number of columns: reference data has {ref_cols} columns, variant data has {var_cols} columns"
)

# check END field
end_exprs = dict(
missing_end=hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)),
end_before_position=hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)),
if ref_cols != var_cols:
first_mismatch = 0
while ref_cols[first_mismatch] == var_cols[first_mismatch]:
first_mismatch += 1
raise _validate_err(
f"mismatch in columns keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]} at position {first_mismatch}"
)
if VariantDataset.ref_block_max_length_field in rd.globals:
rbml = rd[VariantDataset.ref_block_max_length_field]
end_exprs['blocks_too_long'] = hl.agg.filter(
rd.END - rd.locus.position + 1 > rbml, hl.agg.take((rd.row_key, rd.col_key), 5)
)

res = rd.aggregate_entries(hl.struct(**end_exprs))
# check locus distinctness
n_rd_rows = rd.count_rows()
n_distinct = rd.distinct_by_row().count_rows()

if res.missing_end:
_validate_err(
'found records in reference data with missing END field\n '
+ '\n '.join(str(x) for x in res.missing_end)
)
if res.end_before_position:
_validate_err(
'found records in reference data with END before locus position\n '
+ '\n '.join(str(x) for x in res.end_before_position)
)
blocks_too_long = res.get('blocks_too_long', [])
if blocks_too_long:
_validate_err(
'found records in reference data with blocks larger than `ref_block_max_length`\n '
+ '\n '.join(str(x) for x in blocks_too_long)
)
if n_distinct != n_rd_rows:
raise _validate_err(
f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci'
)

def _validate_filters_fields(self):
field = 'gvcf_filters'
ref_has = field in self.reference_data.entry
var_has = field in self.variant_data.entry
if ref_has == var_has:
return
if ref_has and not var_has:
_validate_err(f"reference data has '{field}' when variant data does not")
if var_has and not ref_has:
_validate_err(f"variant data has '{field}' when reference data does not")
# check LEN field
# technically speaking, it's possible for the LEN field to be missing,
# (ex. if the VDS was created manually using its constructor), so,
# temporarily add it for validation (or no-op if it's already defined).
rd = VariantDataset._add_len(rd)
len_exprs = {
'missing_len': hl.agg.filter(hl.is_missing(rd.LEN), hl.agg.take((rd.row_key, rd.col_key), 5)),
'negative_len': hl.agg.filter(rd.LEN < 0, hl.agg.take((rd.row_key, rd.col_key), 5)),
}
if VariantDataset.ref_block_max_length_field in rd.globals:
rbml = rd[VariantDataset.ref_block_max_length_field]
len_exprs['blocks_too_long'] = hl.agg.filter(rd.LEN > rbml, hl.agg.take((rd.row_key, rd.col_key), 5))

res = rd.aggregate_entries(hl.struct(**len_exprs))

if res.missing_len:
raise _validate_err(
'found records in reference data with missing LEN field\n '
+ '\n '.join(str(x) for x in res.missing_len)
)
if res.negative_len:
raise _validate_err(
'found records in reference data with negative LEN\n ' + '\n '.join(str(x) for x in res.negative_len)
)
blocks_too_long = res.get('blocks_too_long', [])
if blocks_too_long:
raise _validate_err(
'found records in reference data with blocks larger than `ref_block_max_length`\n '
+ '\n '.join(str(x) for x in blocks_too_long)
)

def _same(self, other: 'VariantDataset'):
return self.reference_data._same(other.reference_data) and self.variant_data._same(other.variant_data)
Expand Down Expand Up @@ -434,4 +452,4 @@ def union_rows(*vdses):


def _validate_err(msg):
raise ValueError(f'VDS.validate: {msg}')
return ValueError(f'VDS.validate: {msg}')
4 changes: 2 additions & 2 deletions hail/python/test/hail/vds/test_vds.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def test_validate():
with pytest.raises(ValueError):
hl.vds.VariantDataset(
vds.reference_data.annotate_entries(
END=hl.or_missing(vds.reference_data.locus.position % 2 == 0, vds.reference_data.END)
LEN=hl.or_missing(vds.reference_data.locus.position % 2 == 0, vds.reference_data.LEN)
),
vds.variant_data,
).validate()

with pytest.raises(ValueError):
hl.vds.VariantDataset(
vds.reference_data.annotate_entries(END=vds.reference_data.END + 1),
vds.reference_data.annotate_entries(LEN=vds.reference_data.LEN + 1),
vds.variant_data,
).validate()

Expand Down