"""Venn plotting for overlapping negative DMR genes across modifications."""
import csv
import logging
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import Circle
from modalysis.core.plots.label_format import format_modification_label
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Names of the row columns that hold comma-separated gene lists. The same list
# fixes the left-to-right order of the panel columns in the rendered figure.
REGIONS = ["PROMOTER", "BODY", "ENHANCER"]
[docs]
def _collect_negative_gene_sets(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
) -> tuple[dict[tuple[str, str, str], set[str]], list[str]]:
    """Collect per-region gene sets from negative-effect DMR rows only.

    The three input lists are walked in lockstep: entry ``i`` of each list
    describes one tab-delimited file. Rows whose ``EFFECT_SIZE`` is ``>= 0``
    are counted but contribute no genes. For every other row, each column
    named in ``REGIONS`` is split on commas and the stripped, upper-cased,
    non-empty gene names are added to the matching set.

    Args:
        annotated_dmr_paths: Paths of the tab-delimited files to read.
        manifestations: Manifestation label for each path.
        modifications: Modification label for each path.

    Returns:
        A pair ``(gene_sets, manifestation_order)``. ``gene_sets`` maps
        ``(MANIFESTATION, MODIFICATION, REGION)`` (labels stripped and
        upper-cased) to the set of gene names seen for that combination.
        ``manifestation_order`` lists the normalized manifestations in order
        of first appearance.

    Raises:
        OSError: If a file cannot be opened.
        KeyError: If a row lacks ``EFFECT_SIZE`` or one of the region columns.
        ValueError: If an ``EFFECT_SIZE`` value cannot be parsed as a float.
    """
    gene_sets: dict[tuple[str, str, str], set[str]] = {}
    manifestation_order: list[str] = []
    num_processed_rows = 0

    for dmr_path, manifestation, modification in zip(
        annotated_dmr_paths, manifestations, modifications
    ):
        normalized_manifestation = manifestation.strip().upper()
        normalized_modification = modification.strip().upper()
        if normalized_manifestation not in manifestation_order:
            manifestation_order.append(normalized_manifestation)

        # The context manager guarantees the handle is released even when a
        # malformed row raises part-way through the file.
        with open(dmr_path, newline="") as dmr_file:
            for row in csv.DictReader(dmr_file, delimiter="\t"):
                effect_size = float(row["EFFECT_SIZE"])
                num_processed_rows += 1
                if effect_size >= 0:
                    continue

                for region in REGIONS:
                    genes_field = row[region].strip()
                    if not genes_field:
                        continue
                    key = (normalized_manifestation, normalized_modification, region)
                    # Blank tokens (e.g. from a trailing comma) are dropped.
                    gene_sets.setdefault(key, set()).update(
                        gene
                        for gene in (
                            token.strip().upper() for token in genes_field.split(",")
                        )
                        if gene
                    )

    logger.info(
        "Collected negative DMR gene sets from %s rows across %s files.",
        num_processed_rows,
        len(annotated_dmr_paths),
    )
    return gene_sets, manifestation_order
[docs]
def _draw_venn_panel(
    ax: Axes,
    set_a: set[str],
    set_b: set[str],
    label_a: str,
    label_b: str,
    title: str,
) -> None:
    """Render a two-circle, Venn-style panel onto ``ax``.

    Two fixed-size translucent circles are drawn side by side. The counts of
    members exclusive to each set are written towards the outer edge of each
    circle, the shared count is written in bold between them, and each set's
    label with its total size is written underneath. Circle sizes are fixed
    and do not scale with the set sizes.

    Args:
        ax: Axes that receives the patches and text; its limits are set to
            the unit square and its frame is switched off.
        set_a: Members shown in the left circle.
        set_b: Members shown in the right circle.
        label_a: Caption for the left circle.
        label_b: Caption for the right circle.
        title: Panel title.
    """
    count_only_a = len(set_a - set_b)
    count_only_b = len(set_b - set_a)
    count_shared = len(set_a & set_b)

    # Fixed layout; the axes limits are pinned to [0, 1] x [0, 1] below.
    left_center_x = 0.38
    right_center_x = 0.62
    center_y = 0.52
    radius = 0.28
    outward_shift = radius * 0.78
    left_text_x = left_center_x - outward_shift
    right_text_x = right_center_x + outward_shift
    centered = {"ha": "center", "va": "center"}

    for circle_x, face_color in (
        (left_center_x, "#1f77b4"),
        (right_center_x, "#ff7f0e"),
    ):
        ax.add_patch(
            Circle((circle_x, center_y), radius, color=face_color, alpha=0.35, lw=1.2)
        )

    # Exclusive counts first, then the bold overlap count.
    for text_x, count in ((left_text_x, count_only_a), (right_text_x, count_only_b)):
        ax.text(text_x, center_y, str(count), fontsize=11, **centered)
    ax.text(
        0.50, center_y, str(count_shared), fontsize=11, fontweight="bold", **centered
    )

    # Captions with total set sizes, below the circles.
    for text_x, caption, members in (
        (left_text_x, label_a, set_a),
        (right_text_x, label_b, set_b),
    ):
        ax.text(
            text_x,
            0.20,
            "%s\n(n=%s)" % (caption, len(members)),
            fontsize=8,
            **centered,
        )

    ax.set_title(title, fontsize=10, pad=8)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")
[docs]
def plot_common_genes_venn(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    modification_a: str,
    modification_b: str,
    output_path: str,
    output_name: str,
) -> None:
    """Render regional Venn panels comparing two modifications per manifestation.

    The figure has one row per distinct manifestation (in order of first
    appearance) and one column per entry of ``REGIONS``. Each panel compares
    the gene sets gathered for ``modification_a`` and ``modification_b``; a
    combination with no collected genes is drawn as an empty set. The figure
    is written as a PNG and then closed.

    Args:
        annotated_dmr_paths: Paths of the annotated DMR files to read.
        manifestations: Manifestation label for each path.
        modifications: Modification label for each path.
        modification_a: Modification shown in the left circle of every panel.
        modification_b: Modification shown in the right circle of every panel.
        output_path: Directory the image is written to.
        output_name: File name of the image; its suffix is replaced by
            ``.png``.

    Raises:
        ValueError: If the three input lists differ in length, if both
            modifications normalize to the same value, or if no inputs are
            provided.
    """
    if len(annotated_dmr_paths) != len(manifestations):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of manifestations (%d)"
            % (len(annotated_dmr_paths), len(manifestations))
        )
    if len(annotated_dmr_paths) != len(modifications):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of modifications (%d)"
            % (len(annotated_dmr_paths), len(modifications))
        )

    normalized_modification_a = modification_a.strip().upper()
    normalized_modification_b = modification_b.strip().upper()
    if normalized_modification_a == normalized_modification_b:
        raise ValueError("Modification A and B must be different")

    display_modification_a = format_modification_label(normalized_modification_a)
    display_modification_b = format_modification_label(normalized_modification_b)
    output_file_path = (Path(output_path) / output_name).with_suffix(".png")
    logger.info(
        "Plotting common-genes venn diagrams. Inputs: %s, Output: %s",
        annotated_dmr_paths,
        output_file_path,
    )

    sets_by_key, manifestation_order = _collect_negative_gene_sets(
        annotated_dmr_paths,
        manifestations,
        modifications,
    )
    if not manifestation_order:
        raise ValueError("No manifestation inputs provided")

    num_rows = len(manifestation_order)
    num_cols = len(REGIONS)
    # squeeze=False always yields a 2-D grid of axes, so a single-row figure
    # needs no special-casing.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(num_cols * 4.2, max(3.5, num_rows * 3.8)),
        squeeze=False,
    )
    try:
        for row_idx, manifestation in enumerate(manifestation_order):
            for col_idx, region in enumerate(REGIONS):
                set_a = sets_by_key.get(
                    (manifestation, normalized_modification_a, region), set()
                )
                set_b = sets_by_key.get(
                    (manifestation, normalized_modification_b, region), set()
                )
                panel_title = "%s %s" % (manifestation, region.capitalize())
                _draw_venn_panel(
                    axes[row_idx][col_idx],
                    set_a,
                    set_b,
                    display_modification_a,
                    display_modification_b,
                    panel_title,
                )

        fig.suptitle(
            "Common DMR Genes Venn (Negative DMRs): %s vs %s"
            % (display_modification_a, display_modification_b),
            fontsize=12,
            y=1.01,
        )
        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails;
        # otherwise it stays registered with pyplot.
        plt.close(fig)

    logger.info("Saved common-genes venn diagram to %s", output_file_path)