"""Mean methylation line-plot generation across regions and chromosomes."""
import csv
import logging
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from modalysis.core.gene_regions import (
GeneRegionsByChromosome,
RegionList,
build_gene_regions,
find_genes_at_position,
parse_gff,
)
from modalysis.core.plots.label_format import format_modification_label
# Logger named after this module, so its output can be filtered per module.
logger = logging.getLogger(__name__)

# Key of a per-region accumulator entry: (chromosome, region_name).
_RegionKey = tuple[str, str]
# Running totals per (chromosome, region_name): [sum_n_valid_cov, sum_n_mod],
# as filled in by _accumulate_pileup.
RegionAccum = dict[_RegionKey, list[int]]
[docs]
def _find_overlapping_regions(
    position: int,
    region_list: RegionList,
    starts_list: list[int],
) -> bool:
    """Report whether ``position`` falls inside any region of ``region_list``.

    This is a thin predicate over ``find_genes_at_position``. The matching
    regions it returns are discarded and only their truthiness is kept.
    ``starts_list`` is passed through unchanged. Presumably it holds the sorted
    region start coordinates that the helper binary-searches, with half-open
    ``region_start <= position < region_end`` intervals — confirm against
    ``modalysis.core.gene_regions``.

    Returns:
        True if at least one region contains ``position``, otherwise False.
    """
    matches = find_genes_at_position(position, region_list, starts_list)
    return True if matches else False
[docs]
def _accumulate_pileup(
    merged_pileup_path: str,
    regions: GeneRegionsByChromosome,
) -> RegionAccum:
    """Sum valid coverage and modified-call counts per (chromosome, region).

    The merged pileup is read as a tab-separated file whose first row is a
    header. In each data row:
        - column 0 is the chromosome;
        - column 1 is the position;
        - column 4 is ``n_valid_cov``;
        - column 5 is ``n_mod``.

    A row is added to every region type ("promoter", "body", "enhancer") whose
    intervals on that chromosome contain the position. Rows on chromosomes
    absent from ``regions`` are counted as read but otherwise ignored. Blank
    lines are skipped.

    Args:
        merged_pileup_path: Path to the tab-separated merged pileup file.
        regions: Region intervals per chromosome (as built by
            ``build_gene_regions``). Each chromosome entry must provide a
            ``<region>`` and a ``<region>_starts`` key for every region type.

    Returns:
        dict: (chromosome, region_name) -> [sum_n_valid_cov, sum_n_mod].
        Only keys with at least one overlapping row are present. A file with
        no rows at all (not even a header) yields an empty mapping.

    Raises:
        ValueError: If a data row has fewer than six columns or a
            non-integer position/count field.
    """
    region_names = ("promoter", "body", "enhancer")
    accum: RegionAccum = {}
    num_rows = 0
    num_assigned = 0
    # The context manager closes the file even when a row fails to parse.
    with open(merged_pileup_path, newline="", encoding="utf-8") as input_file:
        reader = csv.reader(input_file, delimiter="\t")
        header = next(reader, None)
        if header is None:
            logger.warning(
                "Merged pileup %s is empty; nothing to accumulate.",
                merged_pileup_path,
            )
            return accum
        logger.debug("Merged pileup header: %s", header)
        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            if not row:
                continue
            try:
                chromosome = row[0]
                start = int(row[1])
                n_valid_cov = int(row[4])
                n_mod = int(row[5])
            except (IndexError, ValueError) as err:
                raise ValueError(
                    f"{merged_pileup_path}:{reader.line_num}: "
                    f"malformed pileup row {row!r}"
                ) from err
            num_rows += 1
            if chromosome not in regions:
                continue
            chrom_regions = regions[chromosome]
            for region_name in region_names:
                if not _find_overlapping_regions(
                    start,
                    chrom_regions[region_name],
                    chrom_regions[f"{region_name}_starts"],
                ):
                    continue
                totals = accum.setdefault((chromosome, region_name), [0, 0])
                totals[0] += n_valid_cov
                totals[1] += n_mod
                # One row can count multiple times if it hits several region types.
                num_assigned += 1
    logger.info(
        "Accumulated pileup %s: %s rows read, %s region assignments.",
        merged_pileup_path,
        num_rows,
        num_assigned,
    )
    return accum
[docs]
def _order_chromosomes(
    chromosomes: list[str],
    chromosome_order: list[str] | None,
) -> list[str]:
    """Return ``chromosomes`` sorted, with ``chromosome_order`` entries moved first.

    Entries of ``chromosome_order`` are matched case-insensitively, after
    stripping whitespace, against ``chromosomes``. Unknown or repeated entries
    are ignored. Chromosomes not named in ``chromosome_order`` follow in
    sorted order.
    """
    sorted_chromosomes = sorted(chromosomes)
    if not chromosome_order:
        return sorted_chromosomes
    by_upper = {chrom.upper(): chrom for chrom in sorted_chromosomes}
    ordered: list[str] = []
    seen: set[str] = set()
    for requested in chromosome_order:
        canonical = by_upper.get(requested.strip().upper())
        if canonical is not None and canonical not in seen:
            ordered.append(canonical)
            seen.add(canonical)
    ordered.extend(chrom for chrom in sorted_chromosomes if chrom not in seen)
    return ordered


def _mean_methylation_series(
    accum: RegionAccum,
    region_names: list[str],
    chromosomes: list[str],
) -> list[float]:
    """Return sum(n_mod) / sum(n_valid_cov) for each (region, chromosome) slot.

    Slots are ordered region-major: every chromosome for the first region,
    then every chromosome for the next, and so on. A slot with no accumulated
    coverage yields 0.0. In the plot, that value cannot be told apart from a
    truly unmethylated region.
    """
    values: list[float] = []
    for region_name in region_names:
        for chrom in chromosomes:
            valid_cov, n_mod = accum.get((chrom, region_name), (0, 0))
            values.append(n_mod / valid_cov if valid_cov > 0 else 0.0)
    return values


def plot_mean_methylation(
    gff_path: str,
    merged_pileup_paths: list[str],
    labels: list[str],
    output_path: str,
    output_name: str,
    y_min: float = 0.0,
    y_max: float = 0.1,
    chromosome_order: list[str] | None = None,
    plot_title: str | None = None,
) -> None:
    """Generate region-grouped chromosome methylation line plots.

    The GFF is parsed into promoter/body/enhancer regions. Each merged pileup
    becomes one line over a categorical x-axis with one point per
    (region, chromosome). The y value is sum(n_mod) / sum(n_valid_cov) over
    the pileup rows that fall in that region on that chromosome, or 0.0 when
    there is no coverage.

    Args:
        gff_path: GFF annotation used to derive chromosomes and regions.
        merged_pileup_paths: Merged pileup files, one plotted line each.
        labels: Legend label for each pileup, in the same order. Each label
            is passed through ``format_modification_label``.
        output_path: Directory the PNG is written into.
        output_name: File name. Any existing extension is replaced by ".png".
        y_min: Lower y-axis limit.
        y_max: Upper y-axis limit.
        chromosome_order: Optional preferred chromosome order, matched
            case-insensitively. Unlisted chromosomes follow in sorted order.
        plot_title: Title override. Defaults to
            "Mean Methylation by Region and Chromosome".

    Raises:
        ValueError: If ``merged_pileup_paths`` and ``labels`` differ in
            length, or a pileup row is malformed.

    Side effects:
        Calls ``sns.set_theme(style="white")``, which changes global
        matplotlib styling.
    """
    # zip() would otherwise silently drop the surplus pileups or labels.
    if len(merged_pileup_paths) != len(labels):
        raise ValueError(
            "Expected one label per merged pileup, got "
            f"{len(merged_pileup_paths)} pileup path(s) and {len(labels)} label(s)."
        )

    # NOTE: with_suffix replaces an existing extension, so an output_name
    # such as "run.v2" is written as "run.png".
    output_file_path = (Path(output_path) / output_name).with_suffix(".png")
    logger.info(
        "Plotting mean methylation. GFF: %s, Pileups: %s, Output: %s",
        gff_path,
        merged_pileup_paths,
        output_file_path,
    )

    # Step 1: Parse the GFF, build region boundaries and fix the chromosome order.
    genes_by_chromosome = parse_gff(gff_path)
    regions = build_gene_regions(genes_by_chromosome)
    all_chromosomes = _order_chromosomes(
        list(genes_by_chromosome.keys()), chromosome_order
    )
    logger.info("Chromosomes found: %s", all_chromosomes)
    if not all_chromosomes:
        logger.warning(
            "No chromosomes found in GFF %s; the plot will be empty.", gff_path
        )

    # Step 2: Lay out the x-axis with one slot per (region, chromosome), region-major.
    region_names = ["promoter", "body", "enhancer"]
    num_chromosomes = len(all_chromosomes)
    x_labels = [chrom for _region in region_names for chrom in all_chromosomes]
    x_positions = list(range(len(x_labels)))

    # Step 3: Compute every series before creating the figure, so that an
    # unreadable or malformed pileup does not leave an open figure behind.
    series: list[tuple[str, list[float]]] = []
    for pileup_path, label in zip(merged_pileup_paths, labels):
        accum = _accumulate_pileup(pileup_path, regions)
        display_label = format_modification_label(label)
        series.append(
            (
                display_label,
                _mean_methylation_series(accum, region_names, all_chromosomes),
            )
        )

    # Step 4: Draw and save. The figure is always closed, even if drawing or saving fails.
    sns.set_theme(style="white")
    fig, ax = plt.subplots(figsize=(max(20, num_chromosomes * 3), 8))
    try:
        for display_label, y_values in series:
            ax.plot(
                x_positions,
                y_values,
                marker="o",
                markersize=4,
                linewidth=1.5,
                label=display_label,
            )

        # Dashed separators sit halfway between the last slot of one region
        # group and the first slot of the next.
        for i in range(1, len(region_names)):
            separator_x = i * num_chromosomes - 0.5
            ax.axvline(x=separator_x, color="gray", linestyle="--", linewidth=1.0)

        # Region titles are centred over each group. In the x-axis transform,
        # x is in data units and y in axes units, so 1.02 sits just above the axes.
        for i, region_name in enumerate(region_names):
            first_x = i * num_chromosomes
            last_x = (i + 1) * num_chromosomes - 1
            ax.text(
                (first_x + last_x) / 2,
                1.02,
                region_name.capitalize(),
                transform=ax.get_xaxis_transform(),
                ha="center",
                va="bottom",
                fontsize=14,
                fontweight="bold",
            )

        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=8)
        ax.set_ylabel("Mean Methylation", fontsize=12)
        ax.set_xlabel("Chromosome", fontsize=12)
        ax.set_title(
            plot_title or "Mean Methylation by Region and Chromosome",
            fontsize=16,
            pad=30,
        )
        ax.legend(title="Modification", fontsize=10)
        ax.set_xlim(-0.5, len(x_positions) - 0.5)
        ax.set_ylim(y_min, y_max)
        ax.grid(False)
        fig.tight_layout()
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    logger.info("Saved plot to %s", output_file_path)