diff --git a/dataset/config/complex_traits.csv b/dataset/config/complex_traits.csv new file mode 100644 index 0000000..6226cb7 --- /dev/null +++ b/dataset/config/complex_traits.csv @@ -0,0 +1,120 @@ +trait +AD +AFib +AG +Age_at_Menarche +Age_at_Menopause +AID +Alb +ALP +ALT +Alzheimer +AMD +ApoA +ApoB +AST +Asthma +Balding_Type4 +Baso +Benign_Neoplasms +BFP +Blood_Clot_Lung +BMI +BrC +BW +Ca +CAD +Carpal_Tunnel_Syndrome +Cataract +CD +Cholelithiasis +Cirrhosis +College +COPD +Coxarthrosis +CRC +CRP +DBP +Depression +Diverticulosis +DVT +eBMD +EduYears +eGFR +eGFRcys +Eosino +FedUp_Feelings +FEV1FVC +Fibroblastic_Disorders +GGT +Glaucoma +Glucose +Guilty_Feelings +Hb +HbA1c +HDLC +Height +Ht +Hypothyroidism +IBD +IGF1 +Inguinal_Hernia +Insomnia +Irritability +IS +LDLC +LipoA +Loneliness +LOY +LuC +Lym +Malignant_Neoplasms +MAP +MCH +MCHC +MCP +MCV +MI +Migraine +Miserableness +Mono +Mood_Swings +Morning_Person +NAP +Nervous_Feelings +Neuroticism +Neutro +P +Plt +PP +PrC +RA +RBC +Risk_Taking +SBP +sCr +Sensitivity +SHBG +SkC +Smoking_CPD +Smoking_Ever_Never +Suffer_from_Nerves +T1D +T2D +TBil +TC +Tense +Testosterone_F +Testosterone_M +TG +TP +UA +UF +Urea +Urolithiasis +Varicose_Veins +VitD +WBC +WHRadjBMI +Worrier +Worry_Too_Long diff --git a/dataset/config/config.yaml b/dataset/config/config.yaml index aac5d3c..3b2cb05 100644 --- a/dataset/config/config.yaml +++ b/dataset/config/config.yaml @@ -1,4 +1,5 @@ annotation_url: "http://ftp.ensembl.org/pub/release-107/gtf/homo_sapiens/Homo_sapiens.GRCh38.107.chr.gtf.gz" +genome_url: "http://ftp.ensembl.org/pub/release-107/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna_sm.primary_assembly.fa.gz" clinvar: release: 20251019 diff --git a/dataset/workflow/Snakefile b/dataset/workflow/Snakefile index d276bdf..c750dda 100644 --- a/dataset/workflow/Snakefile +++ b/dataset/workflow/Snakefile @@ -9,3 +9,13 @@ include: "rules/gnomad.smk" include: "rules/hgmd.smk" include: "rules/mendelian_traits.smk" include: "rules/smedley_et_al.smk" +include: "rules/complex_traits.smk" +include: "rules/ldscore.smk" + + +rule all: + input: + "results/dataset/mendelian_traits_matched_9/test.parquet", + "results/feature_performance/mendelian_traits_matched_9.parquet", + "results/dataset/complex_traits_matched_9/test.parquet", + "results/feature_performance/complex_traits_matched_9.parquet", diff --git a/dataset/workflow/rules/common.smk b/dataset/workflow/rules/common.smk index 8fda337..896c1b0 100644 --- a/dataset/workflow/rules/common.smk +++ b/dataset/workflow/rules/common.smk @@ -1,9 +1,10 @@ +from gpn.data import Genome import pandas as pd import polars as pl from cyvcf2 import VCF from sklearn.metrics import average_precision_score -from traitgym.intervals import add_exon, add_tss +from traitgym.intervals import add_exon, add_tss, build_dataset from traitgym.matching import match_features from traitgym.variants import ( COORDINATES, @@ -12,4 +13,12 @@ from traitgym.variants import ( filter_snp, filter_chroms, lift_hg19_to_hg38, + check_ref_alt, ) + + +rule download_genome: + output: + "results/genome.fa.gz", + shell: + "wget {config[genome_url]} -O {output}" diff --git a/dataset/workflow/rules/complex_traits.smk b/dataset/workflow/rules/complex_traits.smk new file mode 100644 index 0000000..d629fc0 --- /dev/null +++ b/dataset/workflow/rules/complex_traits.smk @@ -0,0 +1,270 @@ +FINEMAPPING_URL = "https://huggingface.co/datasets/gonzalobenegas/finucane-ukbb-finemapping/resolve/main" +FINEMAPPING_METHODS = ["SuSiE", "FINEMAP"] +FINEMAPPING_PIP_DIFF_THRESHOLD = 0.05 
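
The `FINEMAPPING_PIP_DIFF_THRESHOLD` above drives the SuSiE/FINEMAP consensus in `complex_traits_combine_methods` further down: when both methods report a PIP, the two are averaged only if they agree within 0.05; otherwise the consensus PIP is nulled (the `any_null_pip` guard in `complex_traits_aggregate_traits` then keeps such variants out of the negative set). When only one method reports a PIP, that value is used as-is. A minimal polars sketch with toy values, not part of the workflow:

```python
import polars as pl

# Hypothetical per-variant PIPs from the two fine-mapping methods.
toy = pl.DataFrame(
    {"pip_susie": [0.95, 0.90, 0.02], "pip_finemap": [0.93, 0.60, 0.03]}
)
consensus = toy.with_columns(
    pl.when((pl.col("pip_susie") - pl.col("pip_finemap")).abs() <= 0.05)
    .then((pl.col("pip_susie") + pl.col("pip_finemap")) / 2)
    .otherwise(pl.lit(None))
    .alias("pip")
)
print(consensus)  # pip: 0.94, null, 0.025 -- the disagreeing variant is nulled
```
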
+HIGH_PIP_THRESHOLD = 0.9 +LOW_PIP_THRESHOLD = 0.01 +COMPLEX_TRAITS = pl.read_csv("config/complex_traits.csv")["trait"].to_list() + + +rule complex_traits_download_all_finemapping: + input: + expand( + "results/complex_traits/finemapping/{trait}/{method}.parquet", + trait=COMPLEX_TRAITS, + method=FINEMAPPING_METHODS, + ), + + +rule complex_traits_download_finemapping: + output: + "results/complex_traits/finemapping/{trait}/{method}.parquet", + wildcard_constraints: + method="|".join(FINEMAPPING_METHODS), + run: + url = f"{FINEMAPPING_URL}/UKBB.{wildcards.trait}.{wildcards.method}.tsv.bgz" + ( + pl.read_csv( + url, + separator="\t", + null_values=["NA"], + schema_overrides={"chromosome": pl.String}, + columns=[ + "chromosome", + "position", + "allele1", + "allele2", + "rsid", + "pip", + ], + ) + .rename( + { + "chromosome": "chrom", + "position": "pos", + "allele1": "ref", + "allele2": "alt", + } + ) + .pipe(filter_snp) + .write_parquet(output[0]) + ) + + +rule complex_traits_combine_methods: + input: + susie="results/complex_traits/finemapping/{trait}/SuSiE.parquet", + finemap="results/complex_traits/finemapping/{trait}/FINEMAP.parquet", + output: + "results/complex_traits/finemapping/{trait}/combined.parquet", + run: + pip_defined_in_both_methods = ( + pl.col("pip_susie").is_not_null() & pl.col("pip_finemap").is_not_null() + ) + ( + pl.read_parquet(input.susie) + .join( + pl.read_parquet(input.finemap), + on=COORDINATES, + how="full", + suffix="_finemap", + ) + .rename({"pip": "pip_susie", "rsid": "rsid_susie"}) + .with_columns( + *(pl.coalesce(col, f"{col}_finemap").alias(col) for col in COORDINATES), + pl.coalesce("rsid_susie", "rsid_finemap").alias("rsid"), + pl.when(pip_defined_in_both_methods) + .then( + pl.when( + (pl.col("pip_susie") - pl.col("pip_finemap")).abs() + <= FINEMAPPING_PIP_DIFF_THRESHOLD + ) + .then((pl.col("pip_susie") + pl.col("pip_finemap")) / 2) + .otherwise(pl.lit(None)) + ) + .otherwise(pl.coalesce("pip_susie", "pip_finemap")) + .alias("pip"), + ) + .select([*COORDINATES, "rsid", "pip"]) + .write_parquet(output[0]) + ) + + +rule complex_traits_aggregate_traits: + input: + expand( + "results/complex_traits/finemapping/{trait}/combined.parquet", + trait=COMPLEX_TRAITS, + ), + output: + "results/complex_traits/finemapping/aggregated.parquet", + run: + any_null_pip = pl.col("pip").is_null().any() + ( + pl.concat( + [ + pl.read_parquet(path).with_columns(trait=pl.lit(trait)) + for path, trait in zip(input, COMPLEX_TRAITS) + ] + ) + .with_columns( + pl.when(pl.col("pip") > HIGH_PIP_THRESHOLD) + .then(pl.col("trait")) + .otherwise(pl.lit(None)) + .alias("trait") + ) + .group_by(COORDINATES) + .agg( + pl.col("rsid").first(), + pl.col("pip").max(), + any_null_pip.alias("any_null_pip"), + pl.col("trait").drop_nulls().unique(), + ) + .with_columns(pl.col("trait").list.sort().list.join(",").alias("traits")) + .with_columns( + pl.when(pl.col("pip") > HIGH_PIP_THRESHOLD) + .then(pl.lit(True)) + .when((pl.col("pip") < LOW_PIP_THRESHOLD) & ~pl.col("any_null_pip")) + .then(pl.lit(False)) + .otherwise(pl.lit(None)) + .alias("label") + ) + .filter(pl.col("label").is_not_null()) + .drop(["trait", "any_null_pip"]) + .sort(COORDINATES) + .write_parquet(output[0]) + ) + + +# Liftover from hg19 to hg38 drops one positive: chr10:17891705 (rs1556465893, AST/Alb/TP). +# This region doesn't exist in hg38 (dbSNP has no GRCh38 mapping for this variant). 
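
For the labeling in `complex_traits_aggregate_traits` above: a variant becomes a positive when its max PIP across traits exceeds `HIGH_PIP_THRESHOLD`, and a negative only when its max PIP is below `LOW_PIP_THRESHOLD` and no consensus PIP was null for any trait; everything else is discarded. A toy illustration with hypothetical values, not part of the workflow:

```python
import polars as pl

# Aggregated variants: max PIP across traits plus the null-PIP guard.
toy = pl.DataFrame(
    {"pip": [0.95, 0.005, 0.5, 0.005], "any_null_pip": [False, False, False, True]}
)
labeled = toy.with_columns(
    pl.when(pl.col("pip") > 0.9)
    .then(pl.lit(True))
    .when((pl.col("pip") < 0.01) & ~pl.col("any_null_pip"))
    .then(pl.lit(False))
    .otherwise(pl.lit(None))
    .alias("label")
)
print(labeled)  # label: true, false, null, null -- only rows 1-2 survive the filter
```

The surviving variants are then lifted to hg38, allele-checked, and annotated by the rule below.
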
+rule complex_traits_annotate: + input: + "results/complex_traits/finemapping/aggregated.parquet", + "results/ldscore/UKBB.EUR.ldscore.parquet", + genome="results/genome.fa.gz", + consequences=expand("results/consequences/{chrom}.parquet", chrom=CHROMS), + output: + "results/complex_traits/annotated.parquet", + run: + ldscore = pl.read_parquet(input[1], columns=COORDINATES + ["MAF", "ld_score"]) + genome = Genome(input.genome) + V = ( + pl.read_parquet(input[0]) + .join(ldscore, on=COORDINATES, how="left") + # Drops ~34 high-PIP variants with very low MAF not present in LD score file + # (e.g., rs115142852, rs553424940, rs534716024) + .filter(pl.col("ld_score").is_not_null()) + .pipe(lift_hg19_to_hg38) + .filter(pl.col("pos") != -1) + .pipe(filter_chroms) + .pipe(check_ref_alt, genome) + .sort(COORDINATES) + ) + results = [] + for path, chrom in zip(input.consequences, CHROMS): + chrom_variants = V.filter(pl.col("chrom") == chrom).lazy() + consequences_lf = pl.scan_parquet(path) + joined = chrom_variants.join( + consequences_lf, + on=COORDINATES, + how="left", + maintain_order="left", + ).collect(engine="streaming") + results.append(joined) + pl.concat(results).write_parquet(output[0]) + + +rule complex_traits_full_consequence_counts: + input: + "results/complex_traits/annotated.parquet", + output: + "results/complex_traits/full_consequence_counts.parquet", + run: + ( + pl.read_parquet(input[0], columns=["label", "consequence"]) + .filter(pl.col("label")) + .get_column("consequence") + .value_counts() + .sort("count", descending=True) + .write_parquet(output[0]) + ) + + +rule complex_traits_dataset_all: + input: + "results/complex_traits/annotated.parquet", + "results/intervals/exon.parquet", + "results/intervals/tss.parquet", + output: + "results/dataset/complex_traits_all/test.parquet", + run: + build_dataset( + pl.read_parquet(input[0]), + pl.read_parquet(input[1]), + pl.read_parquet(input[2]), + config["exclude_consequences"], + config["exon_proximal_dist"], + config["tss_proximal_dist"], + config["consequence_groups"], + ).write_parquet(output[0]) + + +rule complex_traits_dataset_matched: + input: + "results/dataset/complex_traits_all/test.parquet", + output: + "results/dataset/complex_traits_matched_{k}/test.parquet", + run: + V = pl.read_parquet(input[0]) + ( + match_features( + V.filter(pl.col("label")), + V.filter(~pl.col("label")), + ["tss_dist", "exon_dist", "MAF", "ld_score"], + ["chrom", "consequence_final"], + int(wildcards.k), + ).write_parquet(output[0]) + ) + + +rule complex_traits_matched_feature_performance: + input: + "results/dataset/complex_traits_matched_{k}/test.parquet", + output: + "results/feature_performance/complex_traits_matched_{k}.parquet", + run: + V = pl.read_parquet(input[0]) + features = ["tss_dist", "exon_dist", "MAF", "ld_score"] + # Sign: +1 if higher value predicts positive, -1 if lower value predicts positive + sign = {"tss_dist": -1, "exon_dist": -1, "MAF": 1, "ld_score": -1} + rows = [] + + for feature in features: + auprc = average_precision_score(V["label"], sign[feature] * V[feature]) + rows.append( + { + "feature": feature, + "consequence_final": "all", + "auprc": auprc, + "n_pos": V["label"].sum(), + "n_neg": (~V["label"]).sum(), + } + ) + + # Per consequence_final + for consequence in V["consequence_final"].unique().sort(): + subset = V.filter(pl.col("consequence_final") == consequence) + auprc = average_precision_score( + subset["label"], sign[feature] * subset[feature] + ) + rows.append( + { + "feature": feature, + "consequence_final": 
consequence, + "auprc": auprc, + "n_pos": subset["label"].sum(), + "n_neg": (~subset["label"]).sum(), + } + ) + + pl.DataFrame(rows).write_parquet(output[0]) diff --git a/other/workflow/rules/data/ldscore.smk b/dataset/workflow/rules/ldscore.smk similarity index 67% rename from other/workflow/rules/data/ldscore.smk rename to dataset/workflow/rules/ldscore.smk index d3b6c4c..3309585 100644 --- a/other/workflow/rules/data/ldscore.smk +++ b/dataset/workflow/rules/ldscore.smk @@ -16,7 +16,7 @@ rule ldscore_convert: "python3 workflow/scripts/ht2tsv.py {input} {output}" -# (base) shelob:~/projects/functionality-prediction$ zcat results/ldscore/UKBB.EUR.ldscore.tsv.bgz | head +# $ zcat results/ldscore/UKBB.EUR.ldscore.tsv.bgz | head # locus alleles rsid AF ld_score # 1:11063 ["T","G"] rs561109771 4.7982e-05 5.7386e+00 # 1:13259 ["G","A"] rs562993331 2.7798e-04 5.0488e+00 @@ -32,11 +32,10 @@ rule ldscore_convert: rule ldscore_process: input: "results/ldscore/UKBB.EUR.ldscore.tsv.bgz", - "results/genome.fa.gz", output: "results/ldscore/UKBB.EUR.ldscore.parquet", run: - V = ( + ( pl.read_csv( input[0], separator="\t", @@ -55,7 +54,7 @@ rule ldscore_process: pl.when(pl.col("AF") < 0.5) .then(pl.col("AF")) .otherwise(1 - pl.col("AF")) - .alias("maf"), + .alias("MAF"), ) .with_columns( pl.col("locus").struct.field("chrom"), @@ -64,34 +63,8 @@ rule ldscore_process: pl.col("alleles").struct.field("alt"), ) .drop(["locus", "alleles"]) - .select(COORDINATES + ["maf", "ld_score"]) - .to_pandas() + .pipe(filter_snp) + .select(COORDINATES + ["MAF", "ld_score"]) + .sort(COORDINATES) + .write_parquet(output[0]) ) - print(V) - V = filter_snp(V) - print(V.shape) - V = lift_hg19_to_hg38(V) - V = V[V.pos != -1] - print(V.shape) - genome = Genome(input[1]) - V = check_ref_alt(V, genome) - print(V.shape) - V = sort_variants(V) - V.to_parquet(output[0], index=False) - - -rule ldscore_feature: - input: - "results/dataset/{dataset}/test.parquet", - "results/ldscore/UKBB.EUR.ldscore.parquet", - output: - "results/dataset/{dataset}/features/LDScore.parquet", - run: - V = pd.read_parquet(input[0]) - ldscore = pd.read_parquet(input[1], columns=COORDINATES + ["ld_score"]).rename( - columns={"ld_score": "score"} - ) - V = V.merge(ldscore, on=COORDINATES, how="left") - print(f"{V.score.isna().sum()=}") - print(V.groupby("label").score.mean()) - V[["score"]].to_parquet(output[0], index=False) diff --git a/dataset/workflow/rules/mendelian_traits.smk b/dataset/workflow/rules/mendelian_traits.smk index 4c79678..d1d1a5e 100644 --- a/dataset/workflow/rules/mendelian_traits.smk +++ b/dataset/workflow/rules/mendelian_traits.smk @@ -67,34 +67,22 @@ rule mendelian_traits_dataset_all: output: "results/dataset/mendelian_traits_all/test.parquet", run: - exon = pl.read_parquet(input[2]) - tss = pl.read_parquet(input[3]) - V = ( - pl.concat( - [ - pl.read_parquet(input[0]).with_columns(label=pl.lit(True)), - pl.read_parquet(input[1]).with_columns(label=pl.lit(False)), - ], - how="diagonal_relaxed", - ) - .filter(~pl.col("consequence").is_in(config["exclude_consequences"])) - # order is important, tss_proximal overrides exon_proximal - .pipe(add_exon, exon, config["exon_proximal_dist"]) - .pipe(add_tss, tss, config["tss_proximal_dist"]) - ) - consequence_final_pos = V.filter("label")["consequence_final"].unique() - V = V.filter(pl.col("consequence_final").is_in(consequence_final_pos)) - consequence_to_group = { - c: group - for group, consequences in config["consequence_groups"].items() - for c in consequences - } - V = V.with_columns( - 
pl.col("consequence_final") - .replace(consequence_to_group) - .alias("consequence_group") + V = pl.concat( + [ + pl.read_parquet(input[0]).with_columns(label=pl.lit(True)), + pl.read_parquet(input[1]).with_columns(label=pl.lit(False)), + ], + how="diagonal_relaxed", ) - V.sort(COORDINATES).write_parquet(output[0]) + build_dataset( + V, + pl.read_parquet(input[2]), + pl.read_parquet(input[3]), + config["exclude_consequences"], + config["exon_proximal_dist"], + config["tss_proximal_dist"], + config["consequence_groups"], + ).write_parquet(output[0]) rule mendelian_traits_dataset_matched: diff --git a/eval/config/config.yaml b/eval/config/config.yaml index 5ccf213..8a39638 100644 --- a/eval/config/config.yaml +++ b/eval/config/config.yaml @@ -28,7 +28,7 @@ datasets: # mendelian_traits_all: "../dataset/results/dataset/mendelian_traits_all/test.parquet" models: - - CADD.plus.score + # - CADD.plus.score # let's only do it for complex traits, for now - GPN-MSA_LLR.minus.score - GPN-Star-V_LLR.minus.llr_calibrated - GPN-Star-M_LLR.minus.llr_calibrated diff --git a/eval/workflow/rules/common.smk b/eval/workflow/rules/common.smk index 9077310..1fa25b0 100644 --- a/eval/workflow/rules/common.smk +++ b/eval/workflow/rules/common.smk @@ -1,4 +1,6 @@ import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import polars as pl import seaborn as sns diff --git a/other/workflow/rules/common.smk b/other/workflow/rules/common.smk index f1ccec0..59c9937 100644 --- a/other/workflow/rules/common.smk +++ b/other/workflow/rules/common.smk @@ -81,158 +81,6 @@ tissues = pd.read_csv("config/gtex_tissues.txt", header=None).values.ravel() subsets = ["all"] + config["consequence_subsets"] -def add_tss(V: pl.DataFrame, tss: pl.DataFrame) -> pl.DataFrame: - """Add tss_dist column with distance to nearest TSS. - - Also overrides consequence to "tss_proximal" for non-exonic variants - within tss_proximal_dist of a TSS. - """ - V_pd = V.to_pandas() - tss_pd = tss.to_pandas() - V_pd["start"] = V_pd.pos - 1 - V_pd["end"] = V_pd.pos - V_pd = ( - bf.closest(V_pd, tss_pd) - .rename(columns={"distance": "tss_dist", "gene_id_": "tss_closest_gene_id"}) - .drop(columns=["start", "end", "chrom_", "start_", "end_"]) - ) - tss_proximal_dist = config["tss_proximal_dist"] - mask = V_pd.original_consequence.isin(NON_EXONIC) & ( - V_pd.tss_dist <= tss_proximal_dist - ) - V_pd.loc[mask, "consequence"] = "tss_proximal" - return pl.from_pandas(V_pd) - - -def add_exon(V: pl.DataFrame, exon: pl.DataFrame) -> pl.DataFrame: - """Add exon_dist column with distance to nearest exon. - - Also overrides consequence to "exon_proximal" for intron_variant - within exon_proximal_dist of an exon. - """ - V_pd = V.to_pandas() - exon_pd = exon.to_pandas() - V_pd["start"] = V_pd.pos - 1 - V_pd["end"] = V_pd.pos - V_pd = ( - bf.closest(V_pd, exon_pd) - .rename(columns={"distance": "exon_dist", "gene_id_": "exon_closest_gene_id"}) - .drop(columns=["start", "end", "chrom_", "start_", "end_"]) - ) - exon_proximal_dist = config["exon_proximal_dist"] - mask = (V_pd.original_consequence == "intron_variant") & ( - V_pd.exon_dist <= exon_proximal_dist - ) - V_pd.loc[mask, "consequence"] = "exon_proximal" - return pl.from_pandas(V_pd) - - -def add_cre(V: pl.DataFrame, cre: pl.DataFrame) -> pl.DataFrame: - """Add CRE-based consequence annotations. - - Creates consequence_cre column, leaving original consequence unchanged. 
- """ - # Filter CRE to chromosomes present in variants - chroms = V["chrom"].unique() - cre = cre.filter(pl.col("chrom").is_in(chroms)) - - # Handle case where no CRE intervals overlap with variant chromosomes - if cre.is_empty(): - return V.with_columns(pl.col("consequence").alias("consequence_cre")) - - # Prepare variant positions as 0-based half-open intervals [pos-1, pos) - V = V.with_columns( - (pl.col("pos") - 1).alias("start"), - pl.col("pos").alias("end"), - ) - - # Prepare CRE intervals with expanded flanks - flank_dist = config["cre_flank_dist"] # 500bp - cre_flank = cre.with_columns( - (pl.col("start") - flank_dist).clip(lower_bound=0).alias("start"), - (pl.col("end") + flank_dist).alias("end"), - (pl.col("cre_class") + "_flank").alias("cre_class"), - ) - - # Combine core + flank intervals (core has higher priority) - cre_combined = pl.concat([cre_flank, cre]) - - # Single overlap operation using COITree (O(log n) lookups) - # use_zero_based=True for 0-based half-open intervals (BED format) - overlaps = pb.overlap( - V.select(["chrom", "start", "end"]), - cre_combined.select(["chrom", "start", "end", "cre_class"]), - suffixes=("", "_cre"), - output_type="polars.DataFrame", - use_zero_based=True, - ) - - # Assign consequence based on CRE class priority - # Core classes override flank; within each, earlier in CRE_CLASSES = higher priority - # Lower priority number = higher priority (will be selected) - priority = {c: i for i, c in enumerate(CRE_CLASSES)} - priority.update( - {f"{c}_flank": i + len(CRE_CLASSES) for i, c in enumerate(CRE_CLASSES)} - ) - - best_cre = ( - overlaps.with_columns( - pl.col("cre_class_cre").replace_strict(priority).alias("priority") - ) - .group_by(["chrom", "start", "end"]) - .agg(pl.col("cre_class_cre").sort_by("priority").first().alias("cre_class")) - ) - - # Join back to variants and update consequence for non-exonic - return ( - V.join( - best_cre, on=["chrom", "start", "end"], how="left", maintain_order="left" - ) - .with_columns( - pl.when( - pl.col("consequence").is_in(NON_EXONIC) - & pl.col("cre_class").is_not_null() - ) - .then(pl.col("cre_class")) - .otherwise(pl.col("consequence")) - .alias("consequence_cre") - ) - .drop(["start", "end", "cre_class"]) - ) - - -def filter_snp(V: pd.DataFrame) -> pd.DataFrame: - return V[V.ref.isin(NUCLEOTIDES) & V.alt.isin(NUCLEOTIDES)] - - -def filter_chroms(V: pd.DataFrame) -> pd.DataFrame: - return V[V.chrom.isin(CHROMS)] - - -def lift_hg19_to_hg38(V: pd.DataFrame) -> pd.DataFrame: - converter = get_lifter("hg19", "hg38") - - def get_new_pos(v): - try: - res = converter[v.chrom][v.pos] - assert len(res) == 1 - chrom, pos, strand = res[0] - assert chrom.replace("chr", "") == v.chrom - return pos - except: - return -1 - - V.pos = V.apply(get_new_pos, axis=1) - return V - - -def sort_variants(V: pd.DataFrame) -> pd.DataFrame: - V.chrom = pd.Categorical(V.chrom, categories=CHROMS, ordered=True) - V = V.sort_values(COORDINATES) - V.chrom = V.chrom.astype(str) - return V - - def check_ref(V, genome): V = V[V.apply(lambda v: v.ref == genome.get_nuc(v.chrom, v.pos).upper(), axis=1)] return V @@ -249,244 +97,6 @@ def check_ref_alt(V: pd.DataFrame, genome: Genome) -> pd.DataFrame: return V -def match_features( - pos: pl.DataFrame, - neg: pl.DataFrame, - continuous_features: list[str], - categorical_features: list[str], - k: int, - scale: bool = True, - seed: int | None = 42, -) -> pl.DataFrame: - """Match positive samples to k negative samples based on features. 
- - For each unique combination of categorical features, finds the k closest - negative samples for each positive sample based on continuous features. - - Args: - pos: Positive samples DataFrame. - neg: Negative samples DataFrame. - continuous_features: Columns to use for distance-based matching. - categorical_features: Columns for exact matching (grouping). - k: Number of negative samples to match per positive sample. - scale: Whether to scale continuous features before matching. - seed: Random seed for reproducibility. - - Returns: - DataFrame with matched samples and match_group column. - """ - # Convert to pandas for internal processing - pos = pos.to_pandas() - neg = neg.to_pandas() - - # Scale continuous features if requested, storing in separate columns - if scale and len(continuous_features) > 0: - scaler = RobustScaler() - all_data = pd.concat([pos[continuous_features], neg[continuous_features]]) - scaler.fit(all_data) - scaled_features = [f"{c}_scaled" for c in continuous_features] - pos[scaled_features] = scaler.transform(pos[continuous_features]) - neg[scaled_features] = scaler.transform(neg[continuous_features]) - match_features_cols = scaled_features - else: - match_features_cols = continuous_features - scaled_features = [] - - pos = pos.set_index(categorical_features) - neg = neg.set_index(categorical_features) - res_pos = [] - res_neg = [] - for x in tqdm(pos.index.drop_duplicates()): - pos_x = pos.loc[[x]].reset_index() - try: - neg_x = neg.loc[[x]].reset_index() - except KeyError: - print(f"WARNING: no match for {x}") - continue - if len(pos_x) * k > len(neg_x): - print("WARNING: subsampling positive set") - pos_x = pos_x.sample(len(neg_x) // k, random_state=seed) - if len(continuous_features) == 0: - neg_x = neg_x.sample(len(pos_x) * k, random_state=seed) - else: - neg_x = _find_closest(pos_x, neg_x, match_features_cols, k) - res_pos.append(pos_x) - res_neg.append(neg_x) - res_pos = pd.concat(res_pos, ignore_index=True) - res_pos["match_group"] = np.arange(len(res_pos)) - res_neg = pd.concat(res_neg, ignore_index=True) - res_neg["match_group"] = np.repeat(res_pos.match_group.values, k) - res = pd.concat([res_pos, res_neg], ignore_index=True) - # Drop temporary scaled columns - if scaled_features: - res = res.drop(columns=scaled_features) - return pl.from_pandas(res) - - -def _find_closest( - pos: pd.DataFrame, neg: pd.DataFrame, cols: list[str], k: int -) -> pd.DataFrame: - """Find k closest negative samples for each positive sample.""" - D = cdist(pos[cols], neg[cols]) - closest = [] - for i in range(len(pos)): - js = np.argsort(D[i])[:k].tolist() - closest += js - D[:, js] = np.inf # ensure they cannot be picked up again - return neg.iloc[closest] - - -rule download_genome: - output: - "results/genome.fa.gz", - shell: - "wget -O {output} {config[genome_url]}" - - -rule download_annotation: - output: - "results/annotation.gtf.gz", - shell: - "wget -O {output} {config[annotation_url]}" - - -rule get_tss: - input: - "results/annotation.gtf.gz", - output: - "results/intervals/tss.parquet", - run: - annotation = load_table(input[0]) - tx = annotation.query('feature=="transcript"').copy() - tx["gene_id"] = tx.attribute.str.extract(r'gene_id "([^;]*)";') - tx["transcript_biotype"] = tx.attribute.str.extract( - r'transcript_biotype "([^;]*)";' - ) - tx = tx[tx.transcript_biotype == "protein_coding"] - tss = tx.copy() - tss[["start", "end"]] = tss.progress_apply( - lambda w: (w.start, w.start + 1) if w.strand == "+" else (w.end - 1, w.end), - axis=1, - result_type="expand", - ) 
- tss = tss[["chrom", "start", "end", "gene_id"]] - print(tss) - tss.to_parquet(output[0], index=False) - - -rule get_exon: - input: - "results/annotation.gtf.gz", - output: - "results/intervals/exon.parquet", - run: - annotation = load_table(input[0]) - exon = annotation.query('feature=="exon"').copy() - exon["gene_id"] = exon.attribute.str.extract(r'gene_id "([^;]*)";') - exon["transcript_biotype"] = exon.attribute.str.extract( - r'transcript_biotype "([^;]*)";' - ) - exon = exon[exon.transcript_biotype == "protein_coding"] - exon = exon[["chrom", "start", "end", "gene_id"]].drop_duplicates() - exon = exon[exon.chrom.isin(CHROMS)] - exon = exon.sort_values(["chrom", "start", "end"]) - print(exon) - exon.to_parquet(output[0], index=False) - - -rule make_ensembl_vep_input: - input: - "{anything}.parquet", - output: - #temp("{anything}.ensembl_vep.input.tsv.gz"), - "{anything}.ensembl_vep.input.tsv.gz", - threads: workflow.cores - run: - df = pd.read_parquet(input[0]) - df["start"] = df.pos - df["end"] = df.start - df["allele"] = df.ref + "/" + df.alt - df["strand"] = "+" - df.to_csv( - output[0], - sep="\t", - header=False, - index=False, - columns=["chrom", "start", "end", "allele", "strand"], - ) - - -# additional snakemake args (SCF): -# --sdm apptainer --apptainer-args "--bind /scratch/users/gbenegas" -# or in savio: -# --sdm apptainer --apptainer-args "--bind /global/scratch/projects/fc_songlab/gbenegas" -rule install_ensembl_vep_cache: - output: - directory("results/ensembl_vep_cache"), - container: - "docker://ensemblorg/ensembl-vep:release_109.1" - threads: workflow.cores - shell: - "INSTALL.pl -c {output} -a cf -s homo_sapiens -y GRCh38" - - -rule run_ensembl_vep: - input: - "{anything}.ensembl_vep.input.tsv.gz", - "results/ensembl_vep_cache", - output: - #temp("{anything}.ensembl_vep.output.tsv.gz"), - #temp("{anything}.ensembl_vep.output.tsv.gz_summary.html"), - "{anything}.ensembl_vep.output.tsv.gz", - "{anything}.ensembl_vep.output.tsv.gz_summary.html", - container: - "docker://ensemblorg/ensembl-vep:release_109.1" - threads: workflow.cores - shell: - """ - vep -i {input[0]} -o {output} --fork {threads} --cache \ - --dir_cache {input[1]} --format ensembl \ - --most_severe --compress_output gzip --tab --distance 1000 --offline - """ - - -rule process_ensembl_vep: - input: - "{anything}.parquet", - "{anything}.ensembl_vep.output.tsv.gz", - output: - "{anything}.annot.parquet", - priority: 100 - run: - V = pl.read_parquet(input[0]) - V2 = pl.read_csv( - input[1], - separator="\t", - has_header=False, - comment_prefix="#", - new_columns=["variant", "consequence"], - columns=[0, 6], - ) - V2 = V2.with_columns( - pl.col("variant").str.split("_").list.get(0).alias("chrom"), - pl.col("variant").str.split("_").list.get(1).cast(pl.Int64).alias("pos"), - pl.col("variant") - .str.split("_") - .list.get(2) - .str.split("/") - .list.get(0) - .alias("ref"), - pl.col("variant") - .str.split("_") - .list.get(2) - .str.split("/") - .list.get(1) - .alias("alt"), - ).drop("variant") - V = V.join(V2, on=COORDINATES, how="left", maintain_order="left") - V.write_parquet(output[0]) - - rule upload_features_to_hf: input: "results/features/{dataset}/{features}.parquet", diff --git a/other/workflow/rules/data/complex_traits.smk b/other/workflow/rules/data/complex_traits.smk deleted file mode 100644 index e36c6db..0000000 --- a/other/workflow/rules/data/complex_traits.smk +++ /dev/null @@ -1,455 +0,0 @@ -rule gwas_download: - output: - temp("results/gwas/UKBB_94traits_release1.1.tar.gz"), - 
"results/gwas/raw/release1.1/UKBB_94traits_release1.bed.gz", - "results/gwas/raw/release1.1/UKBB_94traits_release1.cols", - "results/gwas/raw/release1.1/UKBB_94traits_release1_regions.bed.gz", - "results/gwas/raw/release1.1/UKBB_94traits_release1_regions.cols", - params: - directory("results/gwas/raw"), - shell: - """ - wget -O {output[0]} https://www.dropbox.com/s/cdsdgwxkxkcq8cn/UKBB_94traits_release1.1.tar.gz?dl=1 && - mkdir -p {params} && - tar -xzvf {output[0]} -C {params} - """ - - -rule gwas_process_main_file: - input: - "results/gwas/raw/release1.1/UKBB_94traits_release1.bed.gz", - output: - "results/gwas/main_file.parquet", - run: - V = ( - pl.read_csv( - input[0], - separator="\t", - has_header=False, - columns=[0, 2, 5, 6, 10, 11, 14, 15, 17, 21, 22], - new_columns=[ - "chrom", - "pos", - "ref", - "alt", - "method", - "trait", - "beta_marginal", - "se_marginal", - "pip", - "LD_HWE", - "LD_SV", - ], - schema_overrides={"column_3": float}, - ) - .with_columns(pl.col("pos").cast(int)) - .filter(~pl.col("LD_HWE"), ~pl.col("LD_SV")) - .with_columns((pl.col("beta_marginal") / pl.col("se_marginal")).alias("z")) - ) - V = ( - V.with_columns( - p=2 * stats.norm.sf(abs(V["z"])), - ) - # when PIP > 0.9, manually override as 0.5 when not genome-wide significant - # (0.5 so it's excluded from both positive and negative set) - .with_columns( - pl.when(pl.col("pip") > 0.9, pl.col("p") > 5e-8) - .then(pl.lit(0.5)) - .otherwise(pl.col("pip")) - .alias("pip") - ).select(["chrom", "pos", "ref", "alt", "trait", "method", "pip"]) - ) - print(V) - V.write_parquet(output[0]) - - -rule gwas_process_secondary_file: - input: - "results/gwas/raw/release1.1/UKBB_94traits_release1_regions.bed.gz", - output: - "results/gwas/secondary_file.parquet", - run: - V = ( - pl.read_csv( - input[0], - separator="\t", - has_header=False, - columns=[4, 6, 7, 8], - new_columns=["trait", "variant", "success_finemap", "success_susie"], - ) - .filter(pl.col("success_finemap"), pl.col("success_susie")) - .drop("success_finemap", "success_susie") - .with_columns( - pl.col("variant") - .str.split_exact(":", 3) - .struct.rename_fields(COORDINATES) - ) - .with_columns( - pl.col("variant").struct.field("chrom"), - pl.col("variant").struct.field("pos").cast(int), - pl.col("variant").struct.field("ref"), - pl.col("variant").struct.field("alt"), - ) - .drop("variant") - .select(["chrom", "pos", "ref", "alt", "trait"]) - .with_columns(pl.lit(0.0).alias("pip")) - ) - print(V) - V = pl.concat( - [ - V.with_columns(pl.lit("SUSIE").alias("method")), - V.with_columns(pl.lit("FINEMAP").alias("method")), - ] - ).select(["chrom", "pos", "ref", "alt", "trait", "method", "pip"]) - print(V) - V.write_parquet(output[0]) - - -rule gwas_process: - input: - "results/gwas/main_file.parquet", - "results/gwas/secondary_file.parquet", - "results/genome.fa.gz", - output: - "results/gwas/processed.parquet", - run: - V = ( - pl.concat([pl.read_parquet(input[0]), pl.read_parquet(input[1])]) - .unique( - ["chrom", "pos", "ref", "alt", "trait", "method"], - keep="first", - maintain_order=True, - ) - .group_by(["chrom", "pos", "ref", "alt", "trait"]) - .agg( - pl.mean("pip"), - (pl.max("pip") - pl.min("pip")).alias("pip_diff"), - pl.count().alias("pip_n"), - ) - .filter(pl.col("pip_n") == 2, pl.col("pip_diff") < 0.05) - .drop("pip_n", "pip_diff") - ) - print(V) - V = ( - V.with_columns( - pl.when(pl.col("pip") > 0.9) - .then(pl.col("trait")) - .otherwise(pl.lit(None)) - .alias("trait") - ) - .group_by(COORDINATES) - .agg(pl.max("pip"), 
pl.col("trait").drop_nulls().unique()) - .with_columns(pl.col("trait").list.sort().list.join(",")) - .to_pandas() - ) - print(V) - V.chrom = V.chrom.str.replace("chr", "") - V = filter_snp(V) - print(V.shape) - V = lift_hg19_to_hg38(V) - V = V[V.pos != -1] - print(V.shape) - genome = Genome(input[2]) - V = check_ref_alt(V, genome) - print(V.shape) - V = sort_variants(V) - print(V) - V.to_parquet(output[0], index=False) - - -rule complex_traits_dataset: - input: - "results/gwas/processed.parquet", - "results/ldscore/UKBB.EUR.ldscore.annot_with_cre.parquet", - "results/tss.parquet", - output: - "results/dataset/complex_traits_matched_{k,\d+}/test.parquet", - run: - k = int(wildcards.k) - V = ( - pl.read_parquet(input[0]) - .with_columns( - pl.when(pl.col("pip") > 0.9) - .then(True) - .when(pl.col("pip") < 0.01) - .then(False) - .otherwise(None) - .alias("label") - ) - .drop_nulls() - .to_pandas() - ) - - annot = pd.read_parquet(input[1]) - V = V.merge(annot, on=COORDINATES, how="inner") - - V = V[V.consequence.isin(TARGET_CONSEQUENCES)] - - V["start"] = V.pos - 1 - V["end"] = V.pos - - tss = pd.read_parquet(input[2], columns=["chrom", "start", "end"]) - - V = ( - bf.closest(V, tss) - .rename(columns={"distance": "tss_dist"}) - .drop(columns=["start", "end", "chrom_", "start_", "end_"]) - ) - - match_features = ["maf", "ld_score", "tss_dist"] - - consequences = V[V.label].consequence.unique() - V_cs = [] - for c in consequences: - print(c) - V_c = V[V.consequence == c].copy() - for f in match_features: - V_c[f"{f}_scaled"] = RobustScaler().fit_transform( - V_c[f].values.reshape(-1, 1) - ) - print(V_c.label.value_counts()) - V_c = match_columns_k( - V_c, "label", [f"{f}_scaled" for f in match_features], k - ) - V_c["match_group"] = c + "_" + V_c.match_group.astype(str) - print(V_c.label.value_counts()) - print(V_c.groupby("label")[match_features].median()) - V_c.drop(columns=[f"{f}_scaled" for f in match_features], inplace=True) - V_cs.append(V_c) - V = pd.concat(V_cs, ignore_index=True) - V = sort_variants(V) - print(V) - V.to_parquet(output[0], index=False) - - -rule dataset_subset_trait: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/subset/{trait}.parquet", - wildcard_constraints: - trait="|".join(select_gwas_traits), - run: - V = pd.read_parquet(input[0]) - V.trait = V.trait.str.split(",") - target_size = len(V[V.match_group == V.match_group.iloc[0]]) - V = V[(~V.label) | (V.trait.apply(lambda x: wildcards.trait in x))] - match_group_size = V.match_group.value_counts() - match_groups = match_group_size[match_group_size == target_size].index - V = V[V.match_group.isin(match_groups)] - V[COORDINATES].to_parquet(output[0], index=False) - - -rule dataset_subset_disease: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/subset/disease.parquet", - run: - V = pd.read_parquet(input[0]) - V.trait = V.trait.str.split(",").apply(set) - target_size = len(V[V.match_group == V.match_group.iloc[0]]) - - y = set(config["complex_traits_disease"]) - - V = V[(~V.label) | (V.trait.apply(lambda x: len(x & y) > 0))] - match_group_size = V.match_group.value_counts() - match_groups = match_group_size[match_group_size == target_size].index - V = V[V.match_group.isin(match_groups)] - print(V) - V[COORDINATES].to_parquet(output[0], index=False) - - -rule dataset_subset_non_disease: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/subset/non_disease.parquet", - run: - V = 
pd.read_parquet(input[0]) - V.trait = V.trait.str.split(",").apply(set) - target_size = len(V[V.match_group == V.match_group.iloc[0]]) - - y = set(config["complex_traits_disease"]) - - V = V[(~V.label) | (V.trait.apply(lambda x: len(x & y) == 0))] - match_group_size = V.match_group.value_counts() - match_groups = match_group_size[match_group_size == target_size].index - V = V[V.match_group.isin(match_groups)] - print(V) - V[COORDINATES].to_parquet(output[0], index=False) - - -rule complex_traits_all_dataset: - input: - "results/gwas/processed.parquet", - "results/ldscore/UKBB.EUR.ldscore.annot_with_cre.parquet", - output: - "results/dataset/complex_traits_all/test.parquet", - run: - V = ( - pl.read_parquet(input[0]) - .with_columns( - pl.when(pl.col("pip") > 0.9) - .then(True) - .when(pl.col("pip") < 0.01) - .then(False) - .otherwise(None) - .alias("label") - ) - .drop_nulls() - .to_pandas() - ) - - annot = pd.read_parquet(input[1]) - V = V.merge(annot, on=COORDINATES, how="inner") - - V = V[V.consequence.isin(TARGET_CONSEQUENCES)] - V_pos = V[V.label] - V = V[V.consequence.isin(V_pos.consequence.unique())] - V = V[V.chrom.isin(V_pos.chrom.unique())] - V = sort_variants(V) - print(V) - V.to_parquet(output[0], index=False) - - -rule complex_traits_all_subset_maf_match: - input: - "results/dataset/complex_traits_all/test.parquet", - output: - "results/dataset/complex_traits_all/subset/maf_match.parquet", - run: - V = pd.read_parquet(input[0]) - print(V) - n_bins = 100 - bins = np.linspace(0, 0.5, n_bins + 1) - V["maf_bin"] = pd.cut(V.maf, bins=bins, labels=False) - V_pos = V.query("label") - V_neg = V.query("not label") - pos_hist = V_pos.maf_bin.value_counts().sort_index().values - neg_hist = V_neg.maf_bin.value_counts().sort_index().values - pos_dist = pos_hist / len(V_pos) - pos_dist_ratio_to_max = pos_dist / pos_dist.max() - neg_hist_max = neg_hist[pos_hist.argmax()] - upper_bound = np.floor(neg_hist_max * pos_dist_ratio_to_max) - downsample = (neg_hist / upper_bound).min() - target_n_samples = np.floor(upper_bound * downsample).astype(int) - V_neg_matched = pd.concat( - [ - V_neg[V_neg.maf_bin == i].sample( - target_n_samples[i], replace=False, random_state=42 - ) - for i in range(n_bins) - ] - ) - V = pd.concat([V_pos, V_neg_matched]) - V = sort_variants(V) - print(V) - V[COORDINATES].to_parquet(output[0], index=False) - - -# gene matched -rule complex_dataset_v22: - input: - "results/gwas/processed.parquet", - "results/ldscore/UKBB.EUR.ldscore.annot_with_cre.parquet", - "results/tss.parquet", - output: - "results/dataset/complex_traits_v22_matched_{k,\d+}/test.parquet", - run: - k = int(wildcards.k) - V = ( - pl.read_parquet(input[0]) - .with_columns( - pl.when(pl.col("pip") > 0.9) - .then(True) - .when(pl.col("pip") < 0.01) - .then(False) - .otherwise(None) - .alias("label") - ) - .drop_nulls() - .to_pandas() - ) - - annot = pd.read_parquet(input[1]) - V = V.merge(annot, on=COORDINATES, how="inner") - - V = V[V.consequence.isin(TARGET_CONSEQUENCES)] - - V["start"] = V.pos - 1 - V["end"] = V.pos - - tss = pd.read_parquet(input[2]) - - V = ( - bf.closest(V, tss) - .rename( - columns={ - "distance": "tss_dist", - "gene_id_": "gene", - } - ) - .drop(columns=["start", "end", "chrom_", "start_", "end_"]) - ) - - match_features = ["maf", "ld_score", "tss_dist"] - - consequences = V[V.label].consequence.unique() - V_cs = [] - for c in consequences: - print(c) - V_c = V[V.consequence == c].copy() - for f in match_features: - V_c[f"{f}_scaled"] = RobustScaler().fit_transform( - 
V_c[f].values.reshape(-1, 1) - ) - print(V_c.label.value_counts()) - V_c = match_columns_k_gene( - V_c, "label", [f"{f}_scaled" for f in match_features], k - ) - V_c["match_group"] = c + "_" + V_c.match_group.astype(str) - print(V_c.label.value_counts()) - print(V_c.groupby("label")[match_features].median()) - V_c.drop(columns=[f"{f}_scaled" for f in match_features], inplace=True) - V_cs.append(V_c) - V = pd.concat(V_cs, ignore_index=True) - V = sort_variants(V) - print(V) - V.to_parquet(output[0], index=False) - - -rule dataset_subset_pip: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/subset/pip_{pip}.parquet", - run: - V = pd.read_parquet(input[0]) - target_size = len(V[V.match_group == V.match_group.iloc[0]]) - V = V[(~V.label) | (V.pip > float(wildcards.pip))] - match_group_size = V.match_group.value_counts() - match_groups = match_group_size[match_group_size == target_size].index - V = V[V.match_group.isin(match_groups)] - print(V.label.value_counts()) - V[COORDINATES].to_parquet(output[0], index=False) - - -rule dataset_subset_pleiotropy: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/subset/pleiotropy_{pleiotropy,yes|no}.parquet", - run: - V = pd.read_parquet(input[0]) - V["n_traits"] = V.trait.str.split(",").apply(len) - target_size = len(V[V.match_group == V.match_group.iloc[0]]) - if wildcards.pleiotropy == "yes": - V = V[(~V.label) | (V.n_traits > 1)] - else: - V = V[(~V.label) | (V.n_traits == 1)] - match_group_size = V.match_group.value_counts() - match_groups = match_group_size[match_group_size == target_size].index - V = V[V.match_group.isin(match_groups)] - print(V.label.value_counts()) - V[COORDINATES].to_parquet(output[0], index=False) diff --git a/other/workflow/rules/features/alphagenome.smk b/other/workflow/rules/features/alphagenome.smk deleted file mode 100644 index e1386d1..0000000 --- a/other/workflow/rules/features/alphagenome.smk +++ /dev/null @@ -1,30 +0,0 @@ -rule run_vep_alphagenome: - input: - "results/dataset/{dataset}/test.parquet", - output: - "results/dataset/{dataset}/features/AlphaGenome_L2.parquet", - threads: workflow.cores - shell: - "python workflow/scripts/vep_alphagenome.py {input} {output} --num_workers {threads}" - - -rule alphagenome_aggregate_assay: - input: - "results/dataset/{dataset}/features/AlphaGenome_L2.parquet", - output: - "results/dataset/{dataset}/features/AlphaGenome_L2_{operation,max|mean|L2}.parquet", - run: - operation = wildcards.operation - df = pd.read_parquet(input[0]) - col_assay = df.columns.str.split("-").str[0] - assays = ["all"] + col_assay.unique().tolist() - result = {} - for assay in assays: - df_assay = df if assay == "all" else df.loc[:, col_assay == assay] - if operation == "max": - result[assay] = df_assay.max(axis=1) - elif operation == "mean": - result[assay] = df_assay.mean(axis=1) - elif operation == "L2": - result[assay] = np.linalg.norm(df_assay.values, axis=1, ord=2) - pd.DataFrame(result).to_parquet(output[0], index=False) diff --git a/other/workflow/rules/features/gpn_star.smk b/other/workflow/rules/features/gpn_star.smk deleted file mode 100644 index 3d80269..0000000 --- a/other/workflow/rules/features/gpn_star.smk +++ /dev/null @@ -1,66 +0,0 @@ -rule gpn_star_download_model: - output: - directory("results/gpn_star/checkpoints/{model}"), - params: - repo_id=lambda wildcards: config["gpn_star"][wildcards.model]["repo_id"], - threads: workflow.cores - shell: - "hf download {params.repo_id} --local-dir {output} 
--max-workers {threads}" - - -# intermediate output used to obtain both LLR, entropy -rule get_logits: - input: - "results/dataset/{dataset}/test.parquet", - lambda wildcards: config["gpn_star"][wildcards.model]["msa_path"], - "results/gpn_star/checkpoints/{model}", - output: - "results/gpn_star/logits/{dataset}/{model}.parquet", - params: - window_size=lambda wildcards: config["gpn_star"][wildcards.model]["window_size"], - resources: - # using resources to avoid re-runs when changing batch size - per_device_batch_size=lambda wildcards: config["gpn_star"][wildcards.model][ - "per_device_batch_size" - ], - threads: workflow.cores - shell: - """ - python \ - -m gpn.star.inference logits \ - {input[0]} {input[1]} {params.window_size} {input[2]} {output} \ - --per_device_batch_size {resources.per_device_batch_size} \ - --dataloader_num_workers {threads} \ - --is_file - """ - - -rule get_llr_calibrated: - input: - "results/dataset/{dataset}/test.parquet", - "results/genome.fa.gz", - "results/gpn_star/checkpoints/{model}", - "results/gpn_star/logits/{dataset}/{model}.parquet", - output: - "results/dataset/{dataset}/features/GPN-Star-{model}_LLR.parquet", - params: - calibration_path="results/gpn_star/checkpoints/{model}/calibration_table/llr.parquet", - run: - from gpn.star.utils import normalize_logits, get_llr - - V = pd.read_parquet(input[0]) - genome = Genome(input[1]) - V["pentanuc"] = V.apply( - lambda row: genome.get_seq( - row["chrom"], row["pos"] - 3, row["pos"] + 2 - ).upper(), - axis=1, - ) - V["pentanuc_mut"] = V["pentanuc"] + "_" + V["alt"] - df_calibration = pd.read_parquet(params.calibration_path) - logits = pd.read_parquet(input[3]) - normalized_logits = normalize_logits(logits) - V["llr"] = get_llr(normalized_logits, V["ref"], V["alt"]) - V = V.merge(df_calibration, on="pentanuc_mut", how="left") - V["llr_calibrated"] = V["llr"] - V["llr_neutral_mean"] - V[["llr", "llr_calibrated"]].to_parquet(output[0], index=False) diff --git a/other/workflow/scripts/vep_alphagenome.py b/other/workflow/scripts/vep_alphagenome.py deleted file mode 100644 index 585c672..0000000 --- a/other/workflow/scripts/vep_alphagenome.py +++ /dev/null @@ -1,117 +0,0 @@ -from alphagenome.data import genome -from alphagenome.models import dna_client, variant_scorers -import argparse -import concurrent.futures -import numpy as np -import os -import pandas as pd -from tqdm import tqdm - - -def run_vep( - V, - num_workers=0, -): - model = dna_client.create(os.environ.get("ALPHA_GENOME_API_KEY")) - metadata = model.output_metadata() - sequence_length = "1MB" - sequence_length = dna_client.SUPPORTED_SEQUENCE_LENGTHS[ - f"SEQUENCE_LENGTH_{sequence_length}" - ] - organism = dna_client.Organism.HOMO_SAPIENS - - tracks = [ - "ATAC", - "DNASE", - "CHIP_TF", - "CHIP_HISTONE", - "CAGE", - "PROCAP", - "RNA_SEQ", - ] - scorers = [ - variant_scorers.CenterMaskScorer( - requested_output=getattr(dna_client.OutputType, track), - width=None, - aggregation_type=variant_scorers.AggregationType.L2_DIFF_LOG1P, - ) - for track in tracks - ] - reverse_map = dict(zip(map(str, scorers), tracks)) - - def run_vep_variant(v): - variant = genome.Variant( - chromosome=v.chrom, - position=v.pos, - reference_bases=v.ref, - alternate_bases=v.alt, - ) - interval = variant.reference_interval.resize(sequence_length) - # need to put strand = "+" to fwd, since by default is "." 
- interval_fwd = interval.copy() - interval_fwd.strand = "+" - interval_rev = interval.copy() - interval_rev.strand = "-" - - def my_score_interval(interval): - variant_scores = model.score_variant( - interval=interval, - variant=variant, - organism=organism, - variant_scorers=scorers, - ) - return variant_scorers.tidy_scores([variant_scores]) - - res_fwd = my_score_interval(interval_fwd) - res_rev = my_score_interval(interval_rev) - assert (res_fwd.index == res_rev.index).all() and ( - res_fwd.columns == res_rev.columns - ).all() - res = res_fwd - res.raw_score = (res.raw_score + res_rev.raw_score) / 2 - res["variant_scorer"] = res.variant_scorer.map(reverse_map) - res["track"] = ( - res["variant_scorer"] - + "-" - + res.groupby("variant_scorer").cumcount().astype(str) - ) - res = res.set_index("track")[["raw_score"]].T - return res - - rows_iterable = V.itertuples(index=False) - - # simple version - # res = list(tqdm(map(run_vep_variant, rows_iterable), total=len(V))) - - # parallel version (watch out for API limits) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - res = list(tqdm(executor.map(run_vep_variant, rows_iterable), total=len(V))) - - res = pd.concat(res, axis=0, ignore_index=True) - return res - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Variant effect prediction") - parser.add_argument( - "variants_path", - type=str, - help="Variants path. Needs the following columns: chrom,pos,ref,alt. pos should be 1-based", - ) - parser.add_argument("output_path", help="Output path (parquet)", type=str) - parser.add_argument( - "--num_workers", - type=int, - default=0, - ) - args = parser.parse_args() - - V = pd.read_parquet(args.variants_path, columns=["chrom", "pos", "ref", "alt"]) - V.chrom = "chr" + V.chrom - - pred = run_vep(V, num_workers=args.num_workers) - - directory = os.path.dirname(args.output_path) - if directory != "" and not os.path.exists(directory): - os.makedirs(directory) - pred.to_parquet(args.output_path, index=False) diff --git a/pyproject.toml b/pyproject.toml index 25ed676..a525f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,12 @@ allow-direct-references = true [tool.hatch.build.targets.wheel] packages = ["src/traitgym"] +[tool.pytest.ini_options] +addopts = "-m 'not slow'" +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", +] + [tool.snakefmt] include = '\.smk$|^Snakefile|\.py$' diff --git a/src/traitgym/intervals.py b/src/traitgym/intervals.py index b13b671..322e9b1 100644 --- a/src/traitgym/intervals.py +++ b/src/traitgym/intervals.py @@ -1,7 +1,7 @@ import polars as pl import polars_bio as pb -from traitgym.variants import CHROMS, NON_EXONIC +from traitgym.variants import CHROMS, COORDINATES, NON_EXONIC GTF_COLUMNS = [ @@ -200,3 +200,42 @@ def add_tss( .otherwise(pl.col("consequence_final")) .alias("consequence_final"), ) + + +def build_dataset( + V: pl.DataFrame, + exon: pl.DataFrame, + tss: pl.DataFrame, + exclude_consequences: list[str], + exon_proximal_dist: int, + tss_proximal_dist: int, + consequence_groups: dict[str, list[str]], +) -> pl.DataFrame: + """Build a dataset with final consequence annotations and groups. + + Args: + V: Variant DataFrame with columns including 'label' and 'consequence'. + exon: Exon intervals from get_exon(). + tss: TSS intervals from get_tss(). + exclude_consequences: List of consequences to filter out. + exon_proximal_dist: Distance threshold for exon_proximal consequence. 
+ tss_proximal_dist: Distance threshold for tss_proximal consequence. + consequence_groups: Mapping from group name to list of consequences. + + Returns: + Processed DataFrame with consequence_final, consequence_group columns, + filtered and sorted by COORDINATES. + """ + V = ( + V.filter(~pl.col("consequence").is_in(exclude_consequences)) + .pipe(add_exon, exon, exon_proximal_dist) + .pipe(add_tss, tss, tss_proximal_dist) + ) + consequence_final_pos = V.filter("label")["consequence_final"].unique() + V = V.filter(pl.col("consequence_final").is_in(consequence_final_pos)) + consequence_to_group = { + c: group for group, consequences in consequence_groups.items() for c in consequences + } + return V.with_columns( + pl.col("consequence_final").replace(consequence_to_group).alias("consequence_group") + ).sort(COORDINATES) diff --git a/src/traitgym/matching.py b/src/traitgym/matching.py index 37f91a4..167c473 100644 --- a/src/traitgym/matching.py +++ b/src/traitgym/matching.py @@ -45,6 +45,9 @@ def match_features( required_cols = COORDINATES + continuous_features + categorical_features _validate_columns(pos, required_cols, "pos") _validate_columns(neg, required_cols, "neg") + all_features = continuous_features + categorical_features + _validate_no_nulls(pos, all_features, "pos") + _validate_no_nulls(neg, all_features, "neg") pos_pd = pos.to_pandas() neg_pd = neg.to_pandas() @@ -98,6 +101,27 @@ def _validate_columns( raise ValueError(f"{name} is missing columns: {sorted(missing)}") +def _validate_no_nulls( + df: pl.DataFrame, + columns: list[str], + name: str, +) -> None: + """Validate DataFrame has no null values in specified columns. + + Args: + df: DataFrame to validate. + columns: List of column names to check. + name: Name of DataFrame for error messages. + + Raises: + ValueError: If any columns contain null values. + """ + for col in columns: + null_count = df[col].null_count() + if null_count > 0: + raise ValueError(f"{name} has {null_count} null values in column '{col}'") + + def _scale_features( pos: pd.DataFrame, neg: pd.DataFrame, diff --git a/src/traitgym/variants.py b/src/traitgym/variants.py index d168a56..9012150 100644 --- a/src/traitgym/variants.py +++ b/src/traitgym/variants.py @@ -57,3 +57,41 @@ def get_new_coords(row: dict) -> dict: .unnest("_coords") .select(original_columns) ) + + +def check_ref_alt(V: pl.DataFrame, genome) -> pl.DataFrame: + """ + Check and fix ref/alt alleles against the reference genome. + + For each variant: + 1. Get the reference nucleotide from the genome at that position + 2. If ref doesn't match, swap ref and alt + 3. 
Filter out variants where neither ref nor alt matches the reference + + Args: + V: DataFrame with chrom, pos, ref, alt columns + genome: Genome object with get_nuc(chrom, pos) method + + Returns: + DataFrame with corrected ref/alt and filtered to valid variants + """ + + def get_ref_nuc(row: dict) -> str: + return genome.get_nuc(row["chrom"], row["pos"]).upper() + + original_columns = V.columns + + return ( + V.with_columns( + pl.struct(["chrom", "pos"]) + .map_elements(get_ref_nuc, return_dtype=pl.Utf8) + .alias("_ref_nuc") + ) + .with_columns(_needs_swap=(pl.col("ref") != pl.col("_ref_nuc"))) + .with_columns( + ref=pl.when(pl.col("_needs_swap")).then(pl.col("alt")).otherwise(pl.col("ref")), + alt=pl.when(pl.col("_needs_swap")).then(pl.col("ref")).otherwise(pl.col("alt")), + ) + .filter(pl.col("ref") == pl.col("_ref_nuc")) + .select(original_columns) + ) diff --git a/tests/test_intervals.py b/tests/test_intervals.py index f78fcb7..84f0ef4 100644 --- a/tests/test_intervals.py +++ b/tests/test_intervals.py @@ -3,10 +3,13 @@ from traitgym.intervals import add_exon, add_tss, get_exon, get_tss, load_annotation +ANNOTATION_PATH = "dataset/results/annotation.gtf.gz" + +@pytest.mark.slow class TestLoadAnnotation: def test_loads_gtf(self) -> None: - ann = load_annotation("other/results/annotation.gtf.gz") + ann = load_annotation(ANNOTATION_PATH) assert ann.shape[0] > 0 assert set(ann.columns) == { "chrom", @@ -21,17 +24,18 @@ def test_loads_gtf(self) -> None: } def test_converts_to_0_based(self) -> None: - ann = load_annotation("other/results/annotation.gtf.gz") + ann = load_annotation(ANNOTATION_PATH) # GTF is 1-based, BED is 0-based, so start should be decremented # First gene on chr1 starts at 11869 in GTF, should be 11868 in BED chr1 = ann.filter(pl.col("chrom") == "1") assert chr1["start"].min() == 11868 +@pytest.mark.slow class TestGetTss: @pytest.fixture def annotation(self) -> pl.DataFrame: - return load_annotation("other/results/annotation.gtf.gz") + return load_annotation(ANNOTATION_PATH) def test_output_schema(self, annotation: pl.DataFrame) -> None: tss = get_tss(annotation) @@ -56,7 +60,7 @@ def test_matches_original_implementation(self, annotation: pl.DataFrame) -> None my_tss = get_tss(annotation) # Original logic (without unique/sort) - ann_pd = load_table("other/results/annotation.gtf.gz") + ann_pd = load_table(ANNOTATION_PATH) tx = ann_pd.query('feature=="transcript"').copy() tx["gene_id"] = tx.attribute.str.extract(r'gene_id "([^;]*)";') tx["transcript_biotype"] = tx.attribute.str.extract( @@ -75,10 +79,11 @@ def test_matches_original_implementation(self, annotation: pl.DataFrame) -> None assert my_sorted.equals(orig_tss) +@pytest.mark.slow class TestGetExon: @pytest.fixture def annotation(self) -> pl.DataFrame: - return load_annotation("other/results/annotation.gtf.gz") + return load_annotation(ANNOTATION_PATH) def test_output_schema(self, annotation: pl.DataFrame) -> None: exon = get_exon(annotation) @@ -115,7 +120,7 @@ def test_matches_original_implementation(self, annotation: pl.DataFrame) -> None my_exon = get_exon(annotation) # Original logic - ann_pd = load_table("other/results/annotation.gtf.gz") + ann_pd = load_table(ANNOTATION_PATH) exon = ann_pd.query('feature=="exon"').copy() exon["gene_id"] = exon.attribute.str.extract(r'gene_id "([^;]*)";') exon["transcript_biotype"] = exon.attribute.str.extract( @@ -133,10 +138,11 @@ def test_matches_original_implementation(self, annotation: pl.DataFrame) -> None assert my_sorted.equals(orig_sorted) +@pytest.mark.slow class 
TestAddExon: @pytest.fixture def exon(self) -> pl.DataFrame: - ann = load_annotation("other/results/annotation.gtf.gz") + ann = load_annotation(ANNOTATION_PATH) return get_exon(ann) @pytest.fixture @@ -186,15 +192,16 @@ def test_distances_match_bioframe(self, variants: pl.DataFrame, exon: pl.DataFra assert result["exon_dist"].to_list() == result_bf["distance"].tolist() +@pytest.mark.slow class TestAddTss: @pytest.fixture def tss(self) -> pl.DataFrame: - ann = load_annotation("other/results/annotation.gtf.gz") + ann = load_annotation(ANNOTATION_PATH) return get_tss(ann) @pytest.fixture def exon(self) -> pl.DataFrame: - ann = load_annotation("other/results/annotation.gtf.gz") + ann = load_annotation(ANNOTATION_PATH) return get_exon(ann) @pytest.fixture diff --git a/tests/test_variants.py b/tests/test_variants.py index 53e93fe..6b36e94 100644 --- a/tests/test_variants.py +++ b/tests/test_variants.py @@ -18,7 +18,7 @@ def test_nucleotides(self) -> None: assert NUCLEOTIDES == ["A", "C", "G", "T"] def test_chroms(self) -> None: - expected = [str(i) for i in range(1, 23)] + ["X", "Y"] + expected = sorted([str(i) for i in range(1, 23)] + ["X", "Y"]) assert CHROMS == expected assert len(CHROMS) == 24
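
As a usage note for the new `check_ref_alt` in `src/traitgym/variants.py`: a self-contained sketch with a stub genome object. The stub is hypothetical and for illustration only; the workflow passes `gpn.data.Genome`, which exposes the same `get_nuc(chrom, pos)` interface.

```python
import polars as pl
from traitgym.variants import check_ref_alt


class StubGenome:
    """Hypothetical stand-in for gpn.data.Genome, used only for illustration."""

    def __init__(self, ref_by_pos):
        self.ref_by_pos = ref_by_pos

    def get_nuc(self, chrom, pos):
        return self.ref_by_pos[(chrom, pos)]


V = pl.DataFrame(
    {
        "chrom": ["1", "1", "1"],
        "pos": [100, 200, 300],
        "ref": ["A", "G", "C"],  # row 2 has ref/alt flipped, row 3 matches neither
        "alt": ["T", "C", "T"],
    }
)
genome = StubGenome({("1", 100): "A", ("1", 200): "C", ("1", 300): "G"})
print(check_ref_alt(V, genome))
# Row 1 kept unchanged; row 2 kept with ref/alt swapped to C/G; row 3 dropped.
```

Note the swap is safe because both `ref` and `alt` expressions in `with_columns` are evaluated against the input frame simultaneously, not sequentially.
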