"""Gene-region parsing and interval lookup helpers."""
import bisect
import csv
import logging
from typing import TypedDict
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default width of the promoter window placed immediately before a gene start
# (default for ``promoter_upstream`` in ``build_gene_regions``).
PROMOTER_UPSTREAM = 1000
# Default width of the enhancer window placed immediately after a gene end
# (default for ``enhancer_downstream`` in ``build_gene_regions``).
ENHANCER_DOWNSTREAM = 1000
# One interval as (start, end, gene_id).
GeneInterval = tuple[int, int, str]
# A list of intervals; the lookup helpers in this module bisect on the
# matching list of start coordinates, so they expect it sorted by start.
RegionList = list[GeneInterval]
# Chromosome name -> list of gene intervals on that chromosome.
GenesByChromosome = dict[str, RegionList]
[docs]
class ChromosomeRegions(TypedDict):
    """Promoter, body and enhancer regions of a single chromosome.

    As populated by ``build_gene_regions``, every region list is sorted by
    start coordinate and each ``*_starts`` list holds the start coordinate of
    every entry of the matching region list, in the same order, so it can be
    searched with ``bisect``.
    """

    # Windows that end at each gene start.
    promoter: RegionList
    # The gene intervals themselves.
    body: RegionList
    # Windows that begin at each gene end.
    enhancer: RegionList
    # Start coordinates of ``promoter``, in the same order.
    promoter_starts: list[int]
    # Start coordinates of ``body``, in the same order.
    body_starts: list[int]
    # Start coordinates of ``enhancer``, in the same order.
    enhancer_starts: list[int]
# Chromosome name -> promoter/body/enhancer regions of that chromosome.
GeneRegionsByChromosome = dict[str, ChromosomeRegions]
[docs]
def parse_gff(gff_path: str) -> GenesByChromosome:
    """Parse a formatted, tab-separated GFF ``.modalysis`` file.

    The first row is treated as a header: it is logged at DEBUG level and
    otherwise ignored. Every following non-empty row must provide, in its
    first four columns and in this order, the chromosome name, the integer
    start, the integer end and the gene ID. Additional columns are ignored
    and completely blank lines are skipped.

    Args:
        gff_path: Path of the file to read.

    Returns:
        A dict mapping each chromosome to its list of
        ``(start, end, gene_id)`` tuples, sorted by start (rows sharing a
        start keep their file order). An empty file yields an empty dict.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-empty data row has fewer than four columns.
        ValueError: If a start or end column is not an integer.
    """
    genes_by_chromosome: GenesByChromosome = {}
    # The context manager closes the file even when a malformed row raises.
    with open(gff_path, newline="") as input_file:
        reader = csv.reader(input_file, delimiter="\t")
        # Supplying a default avoids StopIteration on a completely empty file.
        header = next(reader, None)
        logger.debug("GFF header: %s", header)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            chromosome = row[0]
            start = int(row[1])
            end = int(row[2])
            gene_id = row[3]
            genes_by_chromosome.setdefault(chromosome, []).append(
                (start, end, gene_id)
            )
    for gene_list in genes_by_chromosome.values():
        gene_list.sort(key=lambda gene: gene[0])
    logger.info(
        "Parsed GFF: %s chromosomes, %s total genes.",
        len(genes_by_chromosome),
        sum(len(gene_list) for gene_list in genes_by_chromosome.values()),
    )
    return genes_by_chromosome
[docs]
def build_gene_regions(
    genes_by_chromosome: GenesByChromosome,
    promoter_upstream: int = PROMOTER_UPSTREAM,
    enhancer_downstream: int = ENHANCER_DOWNSTREAM,
) -> GeneRegionsByChromosome:
    """Derive promoter, body and enhancer intervals for every gene.

    For a gene ``(start, end, gene_id)`` the derived intervals are:

    * promoter: ``(max(0, start - promoter_upstream), start, gene_id)``
    * body: ``(start, end, gene_id)``
    * enhancer: ``(end, end + enhancer_downstream, gene_id)``

    Args:
        genes_by_chromosome: Chromosome -> list of ``(start, end, gene_id)``.
        promoter_upstream: Width of the window placed before each gene start;
            the window start is clamped at 0.
        enhancer_downstream: Width of the window placed after each gene end.

    Returns:
        Chromosome -> dict holding the three region lists, each sorted by
        start, plus a parallel ``*_starts`` list of start coordinates for
        each of them.
    """

    def by_start(region):
        return region[0]

    regions_by_chromosome = {}
    for chrom, gene_list in genes_by_chromosome.items():
        promoters = sorted(
            (
                (max(0, start - promoter_upstream), start, gene_id)
                for start, _end, gene_id in gene_list
            ),
            key=by_start,
        )
        bodies = sorted(
            ((start, end, gene_id) for start, end, gene_id in gene_list),
            key=by_start,
        )
        enhancers = sorted(
            (
                (end, end + enhancer_downstream, gene_id)
                for _start, end, gene_id in gene_list
            ),
            key=by_start,
        )
        regions_by_chromosome[chrom] = {
            "promoter": promoters,
            "body": bodies,
            "enhancer": enhancers,
            # Parallel start lists let callers bisect without rebuilding keys.
            "promoter_starts": [by_start(region) for region in promoters],
            "body_starts": [by_start(region) for region in bodies],
            "enhancer_starts": [by_start(region) for region in enhancers],
        }
    return regions_by_chromosome
[docs]
def find_genes_at_position(
    position: int,
    region_list: RegionList,
    starts_list: list[int],
) -> list[str]:
    """Return the gene IDs of all regions that contain ``position``.

    Regions are treated as half-open: a region matches when
    ``start <= position < end``.

    Args:
        position: Coordinate to look up.
        region_list: ``(start, end, gene_id)`` tuples sorted by start.
        starts_list: Start coordinate of every entry of ``region_list``, in
            the same order.

    Returns:
        Matching gene IDs in ``region_list`` order; empty if none match.
    """
    # Only regions that start at or before the position can contain it.
    upper = bisect.bisect_right(starts_list, position)
    candidates = (region_list[index] for index in range(upper))
    return [
        gene_id
        for lo, hi, gene_id in candidates
        if lo <= position < hi
    ]
[docs]
def find_genes_overlapping_interval(
    interval_start: int,
    interval_end: int,
    region_list: RegionList,
    starts_list: list[int],
) -> list[str]:
    """Return the gene IDs of all regions overlapping ``[start, end)``.

    A region overlaps when it starts before ``interval_end`` and
    ``interval_start`` lies before the region's end.

    Args:
        interval_start: Inclusive start of the query interval.
        interval_end: Exclusive end of the query interval.
        region_list: ``(start, end, gene_id)`` tuples sorted by start.
        starts_list: Start coordinate of every entry of ``region_list``, in
            the same order.

    Returns:
        Matching gene IDs in ``region_list`` order; an empty list when the
        query interval is empty or inverted, or when nothing overlaps.
    """
    if interval_start >= interval_end:
        return []
    # Regions starting after the last covered coordinate cannot overlap.
    cutoff = bisect.bisect_right(starts_list, interval_end - 1)
    candidates = (region_list[index] for index in range(cutoff))
    return [
        gene_id
        for lo, hi, gene_id in candidates
        if lo < interval_end and interval_start < hi
    ]