Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@ import org.apache.spark.sql.{Column, DataFrame}

object ACMGImplicits {

/** Names of the columns that identify a variant; listed as required columns when validating inputs. */
val variantColumns: Array[String] = Array(
  "chromosome",
  "start",
  "end",
  "reference",
  "alternate"
)

/**
 * Checks that every DataFrame in `map` exposes the columns it is expected to have.
 *
 * @param map          each DataFrame mapped to a (display name, required column names) pair
 * @param criteriaName name of the criteria reported in the error message
 * @throws IllegalArgumentException (through `require`) on the first required column that is missing
 */
private def validateRequiredColumns(map: Map[DataFrame, (String, Array[String])], criteriaName: String = "criteria"): Unit = {
  for {
    (df, (dfName, columns)) <- map
    col <- columns
  } require(
    df.columns.contains(col),
    s"Column `$col` is required in DataFrame $dfName for $criteriaName.")
}

/**
 * inColArray
 * Helper function generating a boolean column for an array-typed column.
 *
 * Designed to be used in a withColumn statement. Checks whether an entry in the column contains
 * at least one value from the `values` argument.
 *
 * An empty `values` list yields a constant `false` column instead of throwing: with nothing to
 * look for, nothing can be matched.
 *
 * @param colName name of the array column to inspect
 * @param values  values to look for in the array
 * @return A Column of boolean
 */
val inColArray: (String, List[String]) => Column = (colName, values) =>
  values
    .map(value => array_contains(col(colName), value))
    // reduceOption instead of reduce: reduce throws on an empty list.
    .reduceOption(_ || _)
    .getOrElse(org.apache.spark.sql.functions.lit(false))

implicit class ACMGOperations(df: DataFrame) {

/**
Expand Down Expand Up @@ -46,6 +69,69 @@ object ACMGImplicits {
}
}

/**
 * Adds a `BS2` struct column to the wrapped DataFrame.
 *
 * The struct holds `gnomad_ac`, `gnomad_hom`, `is_dominant` and a boolean `score`. `score` is true
 * when `gnomad_hom` is not null and either `gnomad_hom` reaches the threshold, or the gene is
 * flagged dominant and `gnomad_ac` is not null and reaches the threshold.
 *
 * Only Orphanet entries whose `average_age_of_onset` contains none of the adult / elderly /
 * all-ages / no-data labels are used to flag a gene as dominant; symbols without such an entry
 * get `is_dominant = false`.
 *
 * @param orphanet    DataFrame with `gene_symbol`, `average_age_of_onset` and `type_of_inheritance`
 * @param frequencies DataFrame with the variant columns, `genes_symbol` and `external_frequencies`
 * @return the wrapped DataFrame with an additional `BS2` column
 * @throws IllegalArgumentException if a required column is missing from one of the three DataFrames
 */
def getBS2(orphanet: DataFrame, frequencies: DataFrame): DataFrame = {

  val requiredColumns = Map(
    df -> ("df", Array("symbol") ++ variantColumns),
    orphanet -> ("orphanet", Array("gene_symbol", "average_age_of_onset", "type_of_inheritance")),
    frequencies -> ("frequencies", Array("external_frequencies", "genes_symbol") ++ variantColumns)
  )
  // Report missing columns under this criterion's own name.
  validateRequiredColumns(requiredColumns, "BS2")

  // Minimum count (inclusive) applied to both gnomad_hom and gnomad_ac.
  val threshold = 4

  // Onset labels that exclude an Orphanet entry from the dominance lookup.
  val adultOnsets = List(
    "Adult",
    "Elderly",
    "All ages",
    "No data available")

  // Inheritance labels that flag an entry as dominant.
  val dominantInheritances = List(
    "Autosomal dominant",
    "X-linked dominant",
    "Y-linked",
    "Mitochondrial inheritance")

  // NOTE(review): distinct() is applied on (symbol, is_dominant), so a symbol with both a dominant
  // and a non-dominant entry keeps two rows and duplicates the matching input rows in the join
  // below. Confirm whether that is intended.
  val orphanetDF = orphanet.select("gene_symbol", "average_age_of_onset", "type_of_inheritance")
    .withColumn("is_adult_onset", inColArray("average_age_of_onset", adultOnsets))
    .filter(col("is_adult_onset") === false)
    .withColumn("is_dominant", inColArray("type_of_inheritance", dominantInheritances))
    .select(
      col("gene_symbol").as("symbol"),
      col("is_dominant"))
    .distinct()

  // One row per (variant, gene symbol) with the gnomAD genomes 3.1.1 counts.
  val frequencyColumns = variantColumns.toSeq.map(name => col(name)) ++ Seq(
    explode(col("genes_symbol")).as("symbol"),
    col("external_frequencies.gnomad_genomes_3_1_1.ac").as("gnomad_ac"),
    col("external_frequencies.gnomad_genomes_3_1_1.hom").as("gnomad_hom"))
  val freqDF = frequencies.select(frequencyColumns: _*)

  val variantAndSymbol = variantColumns.toSeq :+ "symbol"

  val score =
    col("gnomad_hom").isNotNull &&
      (
        col("gnomad_hom") >= threshold ||
          (col("is_dominant") && col("gnomad_ac").isNotNull && col("gnomad_ac") >= threshold)
      )

  df
    .join(orphanetDF, Seq("symbol"), "leftouter")
    // Symbols with no retained Orphanet entry are treated as not dominant.
    .na.fill(false, Seq("is_dominant"))
    .join(freqDF, variantAndSymbol, "leftouter")
    .withColumn("BS2", struct(
      col("gnomad_ac"),
      col("gnomad_hom"),
      col("is_dominant"),
      score.as("score")
    ))
    // The intermediate columns now live inside the BS2 struct.
    .drop("gnomad_ac", "gnomad_hom", "is_dominant")

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers

// Restrict Spark logging to ERROR level for this spec.
spark.sparkContext.setLogLevel("ERROR")

// Schema of the five columns identifying a variant; every field is nullable.
val variantSchema = new StructType()
  .add("chromosome", StringType, nullable = true)
  .add("start", IntegerType, nullable = true)
  .add("end", IntegerType, nullable = true)
  .add("reference", StringType, nullable = true)
  .add("alternate", StringType, nullable = true)

def ba1Fixture = {
new {
val querySchema = new StructType()
Expand Down Expand Up @@ -65,4 +72,94 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
f.result.collect() should contain theSameElementsAs f.resultData
}

// Default Orphanet, frequency and query DataFrames (with their schemas) shared by the BS2 specs.
def bs2Fixture = {
  // Turns in-memory rows into a DataFrame with the given schema.
  def toDataFrame(rows: Seq[Row], schema: StructType) =
    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

  new {
    val orphanetSchema = new StructType()
      .add("gene_symbol", StringType, nullable = false)
      .add("average_age_of_onset", ArrayType(StringType, containsNull = true), nullable = true)
      .add("type_of_inheritance", ArrayType(StringType, containsNull = true), nullable = true)

    // gene1: early onset, recessive; gene2: early onset, dominant; gene3: "All ages", dominant.
    val orphanetData = Seq(
      Row("gene1", Array("Neonatal", "Antenatal"), Array("Autosomal recessive")),
      Row("gene2", Array("Neonatal"), Array("Autosomal dominant")),
      Row("gene3", Array("All ages"), Array("Autosomal dominant"))
    )

    val orphanetDF = toDataFrame(orphanetData, orphanetSchema)

    val freqSchema = variantSchema
      .add("genes_symbol", ArrayType(StringType, containsNull = true), nullable = true)
      .add(
        "external_frequencies",
        new StructType()
          .add(
            "gnomad_genomes_3_1_1",
            new StructType()
              .add("ac", IntegerType, nullable = true)
              .add("hom", IntegerType, nullable = true)))

    val freqData = Seq(
      Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(10, 0)))
    )

    val freqDF = toDataFrame(freqData, freqSchema)

    val querySchema = variantSchema.add("symbol", StringType, nullable = true)

    val queryData = Seq(
      Row("1", 1, 2, "A", "C", "gene1")
    )

    val queryDF = toDataFrame(queryData, querySchema)
  }
}

"get_BS2" should "throw IllegalArgumentException if `average_age_of_onset` column is absent from the Orphanet DataFrame" in {
val f = bs2Fixture

an[IllegalArgumentException] should be thrownBy f.queryDF.getBS2(f.orphanetDF.drop("average_age_of_onset"), f.freqDF)
}

it should "return observed homozygote alleles as BS2 true" in {
  val fixture = bs2Fixture

  // First variant has hom = 4, second has hom = 0; the third queried variant has no frequency row.
  val frequencies = spark.createDataFrame(
    spark.sparkContext.parallelize(Seq(
      Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(25, 4))),
      Row("1", 3, 4, "T", "C", Array("gene1"), Row(Row(10, 0))))),
    fixture.freqSchema)

  val variants = spark.createDataFrame(
    spark.sparkContext.parallelize(Seq(
      Row("1", 1, 2, "A", "C", "gene1"),
      Row("1", 3, 4, "T", "C", "gene1"),
      Row("1", 5, 6, "G", "C", "gene1"))),
    fixture.querySchema)

  // BS2 struct layout: (gnomad_ac, gnomad_hom, is_dominant, score).
  val expected = Seq(
    Row("1", 1, 2, "A", "C", "gene1", Row(25, 4, false, true)),
    Row("1", 3, 4, "T", "C", "gene1", Row(10, 0, false, false)),
    Row("1", 5, 6, "G", "C", "gene1", Row(null, null, false, false))
  )

  variants.getBS2(fixture.orphanetDF, frequencies).collect() should contain theSameElementsAs expected
}

// The description says "dominant": the only row expected to score true is gene2, whose fixture
// entry is autosomal dominant with a neonatal onset. gene1 is recessive and gene3 is "All ages".
it should "return observed heterozygote allele in dominant non-adult onset diseases as BS2 true" in {
  val f = bs2Fixture

  // Same variant and counts (ac = 15, hom = 0) for each gene, so only the Orphanet data differs.
  val freqData = Seq(
    Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(15, 0))),
    Row("1", 1, 2, "A", "C", Array("gene2"), Row(Row(15, 0))),
    Row("1", 1, 2, "A", "C", Array("gene3"), Row(Row(15, 0)))
  )
  val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), f.freqSchema)

  val queryData = Seq(
    Row("1", 1, 2, "A", "C", "gene1"),
    Row("1", 1, 2, "A", "C", "gene2"),
    Row("1", 1, 2, "A", "C", "gene3"))
  val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), f.querySchema)

  // BS2 struct layout: (gnomad_ac, gnomad_hom, is_dominant, score).
  val resultData = Seq(
    Row("1", 1, 2, "A", "C", "gene1", Row(15, 0, false, false)),
    Row("1", 1, 2, "A", "C", "gene2", Row(15, 0, true, true)),
    Row("1", 1, 2, "A", "C", "gene3", Row(15, 0, false, false))
  )

  val result = queryDF.getBS2(f.orphanetDF, freqDF)
  result.collect() should contain theSameElementsAs resultData
}
}