Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,8 @@ lineage_name_by_abbreviation = {
"yam": "Yam",
}

clade_url_by_lineage_and_segment = {
"h1n1pdm": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H1N1pdm_HA/main/.auto-generated/clades.tsv",
},
"h3n2": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H3N2_HA/main/.auto-generated/clades.tsv",
},
"vic": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_B-Vic_HA/main/.auto-generated/clades.tsv",
}
}

subclade_url_by_lineage_and_segment = {
"h1n1pdm": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H1N1pdm_HA/main/.auto-generated/subclades.tsv",
"na": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H1N1pdm_NA/main/.auto-generated/subclades.tsv",
},
"h3n2": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H3N2_HA/main/.auto-generated/subclades.tsv",
"na": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_A-H3N2_NA/main/.auto-generated/subclades.tsv",
},
"vic": {
"ha": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_B-Vic_HA/main/.auto-generated/subclades.tsv",
"na": "https://raw.githubusercontent.com/influenza-clade-nomenclature/seasonal_B-Vic_NA/main/.auto-generated/subclades.tsv",
}
nextclade_dataset_by_lineage_and_segment = {
"h3n2_ha": "flu_h3n2_ha_broad",
}

include: "workflow/snakemake_rules/common.smk"
Expand Down
129 changes: 113 additions & 16 deletions scripts/table_to_node_data.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,129 @@
"""Create Augur-compatible node data JSON from a pandas data frame.
"""
import argparse
import sys

import pandas as pd
from augur.utils import write_json

from augur.utils import annotate_parents_for_tree, read_tree, write_json


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--table", help="table to convert to a node data JSON")
parser.add_argument("--table", required=True, help="table to convert to a node data JSON")
parser.add_argument("--tree", help="tree with named internal nodes that match the index column values in the given table. Only required for assigning branch labels.")
parser.add_argument("--index-column", default="strain", help="name of the column to use as an index")
parser.add_argument("--delimiter", default=",", help="separator between columns in the given table")
parser.add_argument("--columns", nargs="+", help="optional list of columns in the given table to include in the output. If not provided, all columns will be included.")
parser.add_argument("--node-name", default="nodes", help="name of the node data attribute in the JSON output")
parser.add_argument("--output", help="node data JSON file")
parser.add_argument("--branch-labels", nargs="+", help="optional map of column names to branch labels. You must supply this for each column you would like to label. By default the branch label key the same as the column name, but you may customise this via the COLUMN=LABEL syntax.")
parser.add_argument("--column-to-node-attribute", nargs="+", help="optional map of column names to node attributes in the node data JSON in the format of COLUMN=ATTRIBUTE")
parser.add_argument("--error-on-polyphyletic-labels", action="store_true", help="exit with an error if the script finds any polyphyletic branch labels")
parser.add_argument("--output", required=True, help="node data JSON file")

args = parser.parse_args()

if args.output is not None:
table = pd.read_csv(
args.table,
sep=args.delimiter,
index_col=args.index_column,
dtype=str,
delimiter = "," if args.table.endswith(".csv") else "\t"

table = pd.read_csv(
args.table,
sep=delimiter,
index_col=args.index_column,
dtype=str,
)

if args.columns:
table = table.loc[:, args.columns].copy()

# Optionally rename columns to new node attribute names.
column_to_node_attribute = {}
if args.column_to_node_attribute:
column_to_node_attribute = dict(
tuple(value.split("="))
for value in args.column_to_node_attribute
)

# # Convert columns that aren't strain names or labels to floats.
# for column in table.columns:
# if column != "strain" and not "label" in column:
# table[column] = table[column].astype(float)
table_dict = table.rename(columns=column_to_node_attribute).transpose().to_dict()
node_data = {
args.node_name: table_dict,
}

# Optionally annotate branch labels for internal nodes, if a tree is given.
found_polyphyletic_labels = False
if args.branch_labels:
if not args.tree:
print(
"ERROR: You must provide a Newick tree with named internal nodes (e.g., from augur refine) to assign branch labels.",
file=sys.stderr,
)
sys.exit(1)

# Load the tree.
tree = annotate_parents_for_tree(read_tree(args.tree))

# Parse branch columns and labels.
branch_label_by_column = {}
for branch_label in args.branch_labels:
if "=" in branch_label:
column, label = branch_label.split("=")
else:
column = label = branch_label

branch_label_by_column[column] = label

# For each branch label column, find all distinct values of the column
# and then find the first node in the tree with each value.
branches = {}
for column, label in branch_label_by_column.items():
# Get distinct values for this column.
branch_values = set(table[column].drop_duplicates().values)
labeled_branch_values = set()

# Map each node to its value.
value_by_node = dict(table[column].reset_index().values)

# Using a preorder traversal, find the first node in the tree with
# each distinct value.
for node in tree.find_clades():
node_value = value_by_node.get(node.name)
parent_node_value = value_by_node.get(getattr(node.parent, "name", None))

# Check whether the current node value is among the values we
# want to label.
if node_value in branch_values and node_value != parent_node_value:
# Check whether the current value has already been assigned
# to a branch. If it hasn't, assign the label to the branch.
# If it has been assigned and the current node represents
# another clade with the value, let the user know that the
# value could not be assigned monophyletically.
if node.name not in branches:
branches[node.name] = {"labels": {}}

if node_value not in labeled_branch_values:
branches[node.name]["labels"][label] = node_value
labeled_branch_values.add(node_value)
else:
print(
f"WARNING: The value {node_value} for the column {column} is not monophyletic ({node.name})",
file=sys.stderr,
)

branches[node.name]["labels"][label + "_by_node"] = node_value + ":" + node.name
# if node_value not in labeled_branch_values:
# if node.name not in branches:
# branches[node.name] = {"labels": {}}

# branches[node.name]["labels"][label] = node_value
# labeled_branch_values.add(node_value)
# elif node_value != parent_node_value:
# found_polyphyletic_labels = True
# print(
# f"WARNING: The value {node_value} for the column {column} is not monophyletic ({node.name})",
# file=sys.stderr,
# )

node_data["branches"] = branches

if found_polyphyletic_labels and args.error_on_polyphyletic_labels:
sys.exit(1)

table_dict = table.transpose().to_dict()
write_json({args.node_name: table_dict}, args.output)
write_json(node_data, args.output)
102 changes: 57 additions & 45 deletions workflow/snakemake_rules/core.smk
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ rule ancestral:
annotation = lambda w: f"{config['builds'][w.build_name]['annotation']}",
output:
node_data = build_dir + "/{build_name}/{segment}/muts.json",
sequences = build_dir + "/{build_name}/{segment}/ancestral_sequences.fasta",
translations_done = build_dir + "/{build_name}/{segment}/translations.done",
params:
inference = "joint",
Expand All @@ -298,10 +299,36 @@ rule ancestral:
--genes {params.genes} \
--translations "{params.input_translations}" \
--output-node-data {output.node_data} \
--output-sequences {output.sequences} \
--output-translations "{params.output_translations}" \
--inference {params.inference} 2>&1 | tee {log} && touch {output.translations_done}
"""

rule clades_for_ancestral_sequences:
input:
sequences = build_dir + "/{build_name}/{segment}/ancestral_sequences.fasta",
nextclade_dataset = lambda wildcards: f"nextclade_dataset/{config['builds'][wildcards.build_name]['lineage']}_{wildcards.segment}/",
output:
annotations = build_dir + "/{build_name}/{segment}/nextclade_ancestral_sequences.tsv",
conda: "../envs/nextstrain.yaml"
benchmark:
"benchmarks/clades_for_ancestral_sequences_{build_name}_{segment}.txt"
log:
"logs/clades_for_ancestral_sequences_{build_name}_{segment}.txt"
threads: 8
resources:
mem_mb=16000,
time="0:30:00",
shell:
r"""
nextclade run\
{input.sequences} \
-D {input.nextclade_dataset} \
--gap-alignment-side right \
--jobs {threads} \
--output-tsv {output.annotations} 2>&1 | tee {log}
"""

rule traits:
message:
"""
Expand Down Expand Up @@ -329,74 +356,59 @@ rule traits:
--confidence 2>&1 | tee {log}
"""

rule download_clades:
output:
clades="config/{lineage}/{segment}/clades.tsv",
conda: "../envs/nextstrain.yaml"
params:
url=lambda wildcards: clade_url_by_lineage_and_segment.get(wildcards.lineage, {}).get(wildcards.segment),
shell:
"""
curl -o {output.clades} "{params.url}"
"""

# Determine clades with HA mutations.
rule clades:
input:
nextclade_annotations = build_dir + "/{build_name}/ha/nextclade_ancestral_sequences.tsv",
tree = build_dir + "/{build_name}/ha/tree.nwk",
muts = build_dir + "/{build_name}/ha/muts.json",
clades = lambda wildcards: config["builds"][wildcards.build_name]["clades"],
output:
node_data = build_dir + "/{build_name}/ha/clades.json",
params:
index_column="seqName",
columns = ["clade"],
column_to_branch_label = ["clade=clade"],
column_to_node_attribute = ["clade=clade_membership"],
conda: "../envs/nextstrain.yaml"
benchmark:
"benchmarks/clades_{build_name}.txt"
"benchmarks/clades_{build_name}_ha.txt"
log:
"logs/clades_{build_name}.txt"
"logs/clades_{build_name}_ha.txt"
shell:
"""
augur clades \
r"""
python scripts/table_to_node_data.py \
--table {input.nextclade_annotations} \
--tree {input.tree} \
--mutations {input.muts} \
--clades {input.clades} \
--index-column {params.index_column:q} \
--columns {params.columns:q} \
--branch-labels {params.column_to_branch_label:q} \
--column-to-node-attribute {params.column_to_node_attribute:q} \
--output {output.node_data} 2>&1 | tee {log}
"""

rule download_subclades:
output:
subclades="config/{lineage}/{segment}/subclades.tsv",
conda: "../envs/nextstrain.yaml"
params:
url=lambda wildcards: subclade_url_by_lineage_and_segment.get(wildcards.lineage, {}).get(wildcards.segment),
shell:
"""
curl -o {output.subclades} "{params.url}"
"""

# Determine subclades for na and ha.
rule subclades:
input:
nextclade_annotations = build_dir + "/{build_name}/{segment}/nextclade_ancestral_sequences.tsv",
tree = build_dir + "/{build_name}/{segment}/tree.nwk",
muts = build_dir + "/{build_name}/{segment}/muts.json",
clades = lambda wildcards: config["builds"][wildcards.build_name].get("subclades"),
output:
node_data = build_dir + "/{build_name}/{segment}/subclades.json",
params:
membership_name = "subclade",
label_name = "Subclade",
index_column="seqName",
columns = lambda wildcards: ["subclade"] if wildcards.segment == "ha" else ["clade"],
column_to_branch_label = lambda wildcards: ["subclade=Subclade"] if wildcards.segment == "ha" else ["clade=Subclade"],
column_to_node_attribute = lambda wildcards: ["subclade=subclade"] if wildcards.segment == "ha" else ["clade=subclade"],
conda: "../envs/nextstrain.yaml"
benchmark:
"benchmarks/subclades_{build_name}_{segment}.txt"
log:
"logs/subclades_{build_name}_{segment}.txt"
shell:
"""
augur clades \
r"""
python scripts/table_to_node_data.py \
--table {input.nextclade_annotations} \
--tree {input.tree} \
--mutations {input.muts} \
--clades {input.clades} \
--membership-name {params.membership_name} \
--label-name {params.label_name} \
--index-column {params.index_column:q} \
--columns {params.columns:q} \
--branch-labels {params.column_to_branch_label:q} \
--column-to-node-attribute {params.column_to_node_attribute:q} \
--output {output.node_data} 2>&1 | tee {log}
"""

Expand Down Expand Up @@ -449,7 +461,7 @@ rule annotate_derived_haplotypes:
input:
tree=build_dir + "/{build_name}/ha/tree.nwk",
translations=build_dir + "/{build_name}/ha/translations.done",
clades=lambda wildcards: build_dir + "/{build_name}/ha/subclades.json" if "subclades" in config["builds"][wildcards.build_name] else build_dir + "/{build_name}/ha/clades.json",
clades=lambda wildcards: build_dir + "/{build_name}/ha/subclades.json",
output:
haplotypes=build_dir + "/{build_name}/ha/derived_haplotypes.json",
conda: "../envs/nextstrain.yaml"
Expand All @@ -460,8 +472,8 @@ rule annotate_derived_haplotypes:
params:
min_tips=config.get("haplotypes", {}).get("min_tips", 5),
alignment=build_dir + "/{build_name}/ha/translations/HA1_withInternalNodes.fasta",
clade_label_attribute=lambda wildcards: "Subclade" if "subclades" in config["builds"][wildcards.build_name] else "clade",
clade_node_attribute=lambda wildcards: "subclade" if "subclades" in config["builds"][wildcards.build_name] else "clade_membership"
clade_label_attribute="Subclade",
clade_node_attribute="subclade",
shell:
"""
python3 scripts/annotate_derived_haplotypes.py \
Expand Down
11 changes: 8 additions & 3 deletions workflow/snakemake_rules/select_strains.smk
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,15 @@ rule get_nextclade_dataset_for_lineage_and_segment:
output:
nextclade_dir=directory("nextclade_dataset/{lineage}_{segment}/"),
params:
dataset_name=lambda wildcards: nextclade_dataset_by_lineage_and_segment.get(
f"{wildcards.lineage}_{wildcards.segment}",
f"nextstrain/flu/{wildcards.lineage}/{wildcards.segment}"
),
nextclade_server_arg=lambda wildcards: f"--server={shquotewords(config['nextclade_server'])}" if config.get("nextclade_server") else "",
shell:
r"""
nextclade3 dataset get \
-n 'nextstrain/flu/{wildcards.lineage}/{wildcards.segment}' \
nextclade dataset get \
-n {params.dataset_name} \
{params.nextclade_server_arg} \
--output-dir {output.nextclade_dir}
"""
Expand Down Expand Up @@ -302,7 +306,8 @@ rule merge_nextclade_with_metadata:
--metadata-id-columns \
metadata={params.metadata_id} \
nextclade={params.nextclade_id} \
--output-metadata {output.merged} 2>&1 | tee {log}
--output-metadata /dev/stdout 2> {log} \
| csvtk mutate -t -f subclade -n subclade_nextclade > {output.merged}
"""

def get_subsample_input(w):
Expand Down