#!/usr/bin/env Rscript

# Script to create a heatmap visualization for cytosine mutations within genomic ranges
# Combines F1R2 and F2R1 data in a single heatmap

# Load required libraries
suppressPackageStartupMessages({
  library(optparse)
  library(VariantAnnotation)
  library(GenomicRanges)
  library(Biostrings)
  library(BSgenome)
  library(rtracklayer)
  library(dplyr)
  library(readr)
  library(ggplot2)
  library(reshape2)
  library(grid)
  library(gridExtra)
  library(cowplot)
})

# Parse command line arguments
# Command line interface.
#   --vcf1   VCF passed to process_read_direction() as the F1R2 input
#   --vcf2   VCF passed to process_read_direction() as the F2R1 input
#   --bed    tab-separated ranges table (read with read_tsv)
#   --output PDF written at the end of the script
#   --fin    fin.merge.txt table read inside process_read_direction()
option_list <- list(
  make_option("--vcf1", type = "character",
              help = "Path to the VCF file for the F1R2 read direction"),
  make_option("--vcf2", type = "character",
              help = "Path to the VCF file for the F2R1 read direction"),
  make_option("--bed", type = "character",
              help = paste("Path to the tab-separated ranges file, including",
                           "columns CHR1, POS1, POS2, POS3_2, POS4_2 and dir")),
  make_option("--output", type = "character", default = "heatmap_output.pdf",
              help = "Path to the output PDF file"),
  make_option("--fin", type = "character", help = "provide fin.merge.txt")
)

opt_parser <- OptionParser(option_list = option_list)
args <- parse_args(opt_parser)

# Fail early with the usage text instead of a cryptic error deep in the run.
required_opts <- c("vcf1", "vcf2", "bed")
missing_opts <- required_opts[
  vapply(required_opts, function(opt) is.null(args[[opt]]), logical(1))
]
if (length(missing_opts) > 0) {
  print_help(opt_parser)
  stop(
    "Missing required option(s): ",
    paste0("--", missing_opts, collapse = ", "),
    call. = FALSE
  )
}

# Sample identifier: file name of the first VCF up to its first dot.
# `args$vcf` must not be used here: `$` partial matching is ambiguous between
# vcf1 and vcf2 and therefore yields NULL.
id <- gsub("\\..*", "", basename(args[["vcf1"]]))

# Read the reference FASTA into memory and reduce every record name to the
# text before its first space, so sequences can be looked up by that token.
reference_fasta_path <- "/home/users/data/01_reference/human_g1k_v37/human_g1k_v37.fasta"
reference_genome <- readDNAStringSet(reference_fasta_path)
names(reference_genome) <- sub(" .*", "", names(reference_genome))

# Process one read direction (F1R2 or F2R1) end to end.
#
# Steps:
#   1. read the VCF and keep one point range per record;
#   2. read the ranges table, keep rows whose `dir` equals `read_dir`, and
#      build the two windows per row (POS1..POS2 and POS3_2..POS4_2);
#   3. overlap variants with the windows and count hits per RangeID;
#   4. count non-CpG cytosines per window from the reference
#      (C on F1R2, G on any other direction) and derive NCH_rate;
#   5. write <id>.fin.merge.txt; if --fin was given, use that table
#      (restricted to `read_dir`) for the heatmap instead;
#   6. build one 302-column heatmap row per range.
#
# Arguments:
#   vcf_file  path of the VCF for this read direction
#   bed_file  path of the tab-separated ranges table
#   read_dir  "F1R2" or "F2R1"; any value other than "F1R2" takes the G branch
#
# Returns a list with
#   heatmap_data  long data frame (RangeID, Position, GenomicPos, Base,
#                 Status, Direction), or NULL when there are no ranges
#   fin_df        per-range table with NCH counts, mut_count and NCH_rate
#
# Relies on the script-level objects `reference_genome`, `args` and `id`.
process_read_direction <- function(vcf_file, bed_file, read_dir) {
  cat(sprintf("Processing %s data...\n", read_dir))

  # Every heatmap row is padded or truncated to this many columns.
  max_length <- 302

  # ---- Variants -----------------------------------------------------------
  cat(sprintf("Reading VCF file: %s\n", vcf_file))
  vcf <- readVcf(vcf_file)
  variants <- rowRanges(vcf)

  # ALT is a list column; collapse each record's alleles with "," so that a
  # multi-allelic record still yields exactly one value per variant.
  alt_alleles <- unname(unstrsplit(as(alt(vcf), "CharacterList"), sep = ","))

  variant_granges <- GRanges(
    seqnames = seqnames(variants),
    ranges = IRanges(start = start(variants), width = 1),
    REF = as.character(ref(vcf)),
    ALT = alt_alleles,
    VariantID = seq_along(variants)
  )

  # ---- Ranges -------------------------------------------------------------
  cat(sprintf("Reading BED file: %s\n", bed_file))
  ranges_df <- read_tsv(bed_file) %>%
    filter(dir == read_dir) %>%
    # Chromosome names must be character: a numeric CHR1 would index the
    # reference by position rather than by name.
    mutate(CHR1 = as.character(CHR1), RangeID = seq_len(n()))
  n_ranges <- nrow(ranges_df)

  # NOTE(review): the first window is extended one base upstream (POS1 - 1)
  # for the overlap only; NCH counting and the heatmap start at POS1, so a
  # variant at POS1 - 1 is counted but never drawn. Confirm this is intended.
  range1_granges <- GRanges(
    seqnames = ranges_df$CHR1,
    ranges = IRanges(start = ranges_df$POS1 - 1, end = ranges_df$POS2),
    RangeType = rep("Range1", n_ranges),
    RangeID = ranges_df$RangeID
  )
  range2_granges <- GRanges(
    seqnames = ranges_df$CHR1,
    ranges = IRanges(start = ranges_df$POS3_2, end = ranges_df$POS4_2),
    RangeType = rep("Range2", n_ranges),
    RangeID = ranges_df$RangeID
  )
  all_ranges <- c(range1_granges, range2_granges)

  # ---- Variants inside ranges ---------------------------------------------
  cat("Finding variants in ranges...\n")
  overlaps <- findOverlaps(variant_granges, all_ranges)
  hit_variants <- variant_granges[queryHits(overlaps)]
  hit_ranges <- all_ranges[subjectHits(overlaps)]

  results <- data.frame(
    VariantID = hit_variants$VariantID,
    CHROM = as.character(seqnames(hit_variants)),
    POS = start(hit_variants),
    REF = hit_variants$REF,
    ALT = hit_variants$ALT,
    RangeType = hit_ranges$RangeType,
    RangeID = hit_ranges$RangeID,
    stringsAsFactors = FALSE
  )

  # RangeID is the row index of ranges_df, so it can be used to look up the
  # window coordinates directly.
  results <- results %>%
    mutate(
      POS1 = ranges_df$POS1[RangeID],
      POS2 = ranges_df$POS2[RangeID],
      POS3 = ranges_df$POS3_2[RangeID],
      POS4 = ranges_df$POS4_2[RangeID]
    )

  # ---- Non-CpG cytosine counts --------------------------------------------
  # Count target bases in [start1, end1] that are not in a CG dinucleotide.
  # F1R2 counts C and looks one base downstream for the G; any other
  # direction counts G and looks one base upstream for the C. The flanking
  # base is clamped to the chromosome so windows touching either end work.
  # Returns NA (with a warning) when the chromosome is not in the reference.
  count_non_CG_cytosines <- function(chr, start1, end1, read_direction) {
    chr <- as.character(chr)
    if (!(chr %in% names(reference_genome))) {
      warning(sprintf("Chromosome %s not found in reference genome", chr))
      return(NA_integer_)
    }
    chrom_seq <- reference_genome[[chr]]
    seq1 <- as.character(subseq(chrom_seq, start1, end1))

    if (read_direction == "F1R2") {
      flank_end <- min(end1 + 1, length(chrom_seq))
      adj_seq1 <- as.character(subseq(chrom_seq, start1, flank_end))
      total_cytosines1 <- stringr::str_count(seq1, "C")
    } else {
      flank_start <- max(start1 - 1, 1)
      adj_seq1 <- as.character(subseq(chrom_seq, flank_start, end1))
      total_cytosines1 <- stringr::str_count(seq1, "G")
    }
    cg_contexts1 <- stringr::str_count(adj_seq1, "CG")
    total_cytosines1 - cg_contexts1
  }

  # Apply the counter to every row of ranges_df for one pair of coordinates.
  count_per_range <- function(starts, ends) {
    vapply(
      seq_len(n_ranges),
      function(i) {
        as.integer(
          count_non_CG_cytosines(ranges_df$CHR1[i], starts[i], ends[i], read_dir)
        )
      },
      integer(1)
    )
  }
  ranges_df$NCH_range1 <- count_per_range(ranges_df$POS1, ranges_df$POS2)
  ranges_df$NCH_range2 <- count_per_range(ranges_df$POS3_2, ranges_df$POS4_2)
  ranges_df <- ranges_df %>% mutate(NCH_tot = NCH_range1 + NCH_range2)

  # ---- Per-range mutation rate --------------------------------------------
  mutation_counts <- results %>%
    group_by(RangeID) %>%
    dplyr::summarise(mut_count = n())

  fin_df <- ranges_df %>%
    left_join(mutation_counts, by = "RangeID") %>%
    dplyr::select(
      CHR1, POS1, POS2, POS3_2, POS4_2, dir, RangeID,
      NCH_range1, NCH_range2, NCH_tot, mut_count
    ) %>%
    mutate(NCH_rate = mut_count / NCH_tot)

  # Ranges without any variant have no row in mutation_counts; report 0.
  fin_df[is.na(fin_df)] <- 0

  # NOTE(review): the file name depends only on `id`, so the second call
  # (F2R1) overwrites the table written by the first (F1R2).
  write.table(
    fin_df,
    paste0(id, ".fin.merge.txt"),
    sep = "\t",
    quote = FALSE,
    row.names = FALSE
  )

  # Optional override: when --fin is supplied, the heatmap is driven by that
  # table. Only rows of the current read direction are kept; otherwise rows
  # of the other direction would be drawn with the wrong base logic and
  # produce duplicated direction/RangeID labels downstream.
  if (!is.null(args$fin)) {
    fin_df <- read_tsv(args$fin)
    if ("dir" %in% names(fin_df)) {
      fin_df <- fin_df %>% filter(dir == read_dir)
    }
    fin_df <- fin_df %>% mutate(CHR1 = as.character(CHR1))
  }

  # ---- Heatmap rows -------------------------------------------------------
  # Describe one contiguous window base by base: genomic position, base and
  # status ("Non-C", "C-nonmutated" or "CpG"). CpG context is decided from
  # the true genomic neighbour (including the base just outside the window),
  # which keeps the heatmap consistent with count_non_CG_cytosines().
  # An unknown chromosome yields an empty description plus a warning.
  describe_range <- function(chr, start, end) {
    chr <- as.character(chr)
    if (!(chr %in% names(reference_genome))) {
      warning(sprintf("Chromosome %s not found in reference genome", chr))
      return(data.frame(
        GenomicPos = numeric(0),
        Base = character(0),
        Status = character(0),
        stringsAsFactors = FALSE
      ))
    }
    chrom_seq <- reference_genome[[chr]]
    bases <- strsplit(as.character(subseq(chrom_seq, start, end)), "")[[1]]
    n_bases <- length(bases)

    if (read_dir == "F1R2") {
      flank_pos <- end + 1
      flank <- if (flank_pos <= length(chrom_seq)) {
        as.character(subseq(chrom_seq, flank_pos, flank_pos))
      } else {
        ""
      }
      neighbour <- c(bases, flank)[-1]            # base downstream of each
      is_target <- bases == "C"
      in_cpg <- is_target & neighbour == "G"
    } else {
      flank_pos <- start - 1
      flank <- if (flank_pos >= 1) {
        as.character(subseq(chrom_seq, flank_pos, flank_pos))
      } else {
        ""
      }
      neighbour <- head(c(flank, bases), -1)      # base upstream of each
      is_target <- bases == "G"
      in_cpg <- is_target & neighbour == "C"
    }

    status <- rep("Non-C", n_bases)
    status[is_target] <- "C-nonmutated"
    status[in_cpg] <- "CpG"

    data.frame(
      GenomicPos = start + seq_len(n_bases) - 1,
      Base = bases,
      Status = status,
      stringsAsFactors = FALSE
    )
  }

  # Build the heatmap row for one range: both windows concatenated, padded
  # with empty cells (or truncated) to max_length, mutated sites overlaid.
  create_heatmap_data <- function(range_row, results_for_range) {
    pos1 <- range_row$POS1
    pos2 <- range_row$POS2
    pos3 <- range_row$POS3_2
    pos4 <- range_row$POS4_2
    range1_length <- pos2 - pos1 + 1

    merged <- rbind(
      describe_range(range_row$CHR1, pos1, pos2),
      describe_range(range_row$CHR1, pos3, pos4)
    )
    keep <- seq_len(min(nrow(merged), max_length))
    n_pad <- max_length - length(keep)

    df <- data.frame(
      RangeID = range_row$RangeID,
      Position = seq_len(max_length),
      GenomicPos = c(merged$GenomicPos[keep], rep(NA, n_pad)),
      Base = c(merged$Base[keep], rep("", n_pad)),
      Status = c(merged$Status[keep], rep("Non-C", n_pad)),
      Direction = read_dir,
      stringsAsFactors = FALSE
    )

    # Overlay mutated sites. The column is derived from the offset inside
    # the window the variant falls in; window 1 takes precedence.
    if (!is.null(results_for_range) && nrow(results_for_range) > 0) {
      mut_pos <- results_for_range$POS
      in_range1 <- mut_pos >= pos1 & mut_pos <= pos2
      in_range2 <- !in_range1 & mut_pos >= pos3 & mut_pos <= pos4
      array_pos <- c(
        mut_pos[in_range1] - pos1 + 1,
        mut_pos[in_range2] - pos3 + 1 + range1_length
      )
      array_pos <- array_pos[array_pos <= max_length]
      df$Status[array_pos] <- "C-mutated"
    }

    df
  }

  cat("Generating heatmap data for", read_dir, "...\n")
  # Group the hits once instead of filtering the full table for every range.
  results_by_range <- split(results, results$RangeID)
  heatmap_data_list <- lapply(seq_len(nrow(fin_df)), function(i) {
    range_row <- fin_df[i, ]
    create_heatmap_data(
      range_row,
      results_by_range[[as.character(range_row$RangeID)]]
    )
  })

  # NULL when there are no ranges for this direction.
  heatmap_data <- do.call(rbind, heatmap_data_list)

  list(
    heatmap_data = heatmap_data,
    fin_df = fin_df
  )
}

# Process both F1R2 and F2R1 data
#f1r2_results <- process_read_direction("/home/users/ayh/Projects/27_A3B/07_revision/EM_seq/vcf/10pg_A3B_60s.bismark.snp.fin.only_SS_lib_mut.F1R2.CtoT.NCH.rh.sample.rh.vcf", "/home/users/ayh/Projects/27_A3B/07_revision/EM_seq/sampling/10pg_A3B_60s.sample_100.txt", "F1R2")
#f2r1_results <- process_read_direction("/home/users/ayh/Projects/27_A3B/07_revision/EM_seq/vcf/10pg_A3B_60s.bismark.snp.fin.only_SS_lib_mut.F2R1.CtoT.NCH.rh.sample.rh.vcf", "/home/users/ayh/Projects/27_A3B/07_revision/EM_seq/sampling/10pg_A3B_60s.sample_100.txt", "F2R1")

f1r2_results <- process_read_direction(args$vcf1, args$bed, "F1R2")
f2r1_results <- process_read_direction(args$vcf2, args$bed, "F2R1")

# Stack the per-direction outputs.
heatmap_data <- rbind(f1r2_results$heatmap_data, f2r1_results$heatmap_data)
fin_df <- rbind(
  f1r2_results$fin_df %>% mutate(ReadDirection = "F1R2"),
  f2r1_results$fin_df %>% mutate(ReadDirection = "F2R1")
)

if (is.null(heatmap_data) || nrow(heatmap_data) == 0 || nrow(fin_df) == 0) {
  stop("No ranges available to plot for either read direction", call. = FALSE)
}

# One label per (direction, range); rows sorted by ascending NCH_rate with
# F1R2 before F2R1 on ties. This order defines the y axis of both plots.
fin_df <- fin_df %>%
  mutate(RangeIDwithDir = paste(ReadDirection, RangeID, sep = "_")) %>%
  arrange(NCH_rate, ReadDirection)
range_order <- unique(fin_df$RangeIDwithDir)

# Attach NCH_rate to every heatmap cell. The lookup is keyed on the plain
# character label so the join does not mix factor and character keys and
# does not create suffixed duplicate columns.
rate_lookup <- fin_df %>%
  dplyr::select(RangeIDwithDir, NCH_rate) %>%
  distinct(RangeIDwithDir, .keep_all = TRUE)

fin_df$RangeIDwithDir <- factor(fin_df$RangeIDwithDir, levels = range_order)

heatmap_data <- heatmap_data %>%
  mutate(RangeIDwithDir = paste(Direction, RangeID, sep = "_")) %>%
  left_join(rate_lookup, by = "RangeIDwithDir") %>%
  mutate(RangeIDwithDir = factor(RangeIDwithDir, levels = range_order)) %>%
  arrange(RangeIDwithDir)

# Discrete x axis covering every column present in the heatmap rows.
heatmap_data$Position <- factor(
  heatmap_data$Position,
  levels = seq_len(max(heatmap_data$Position))
)
# Fill colour for every cell status.
color_palette <- c(
  "C-mutated" = "red",
  "C-nonmutated" = "blue",
  "Non-C" = "grey90",
  "CpG" = "darkgrey",
  "Empty" = "white"
)

# Legend text keyed by status. Keeping the mapping by name (instead of a bare
# positional vector) prevents mislabelling, or a breaks/labels length error,
# when one of the statuses does not occur in the data.
status_labels <- c(
  "C-mutated" = "Mutated cytosine",
  "C-nonmutated" = "Non-mutated cytosine",
  "CpG" = "CpG",
  "Empty" = "Empty",
  "Non-C" = "Non-cytosine base"
)

# Padding cells carry an empty base; draw them in white.
heatmap_data$Status[heatmap_data$Base == ""] <- "Empty"

# Heatmap: one row per range/direction, one column per merged position.
p_heatmap <- ggplot(
  heatmap_data,
  aes(x = Position, y = RangeIDwithDir, fill = Status)
) +
  geom_tile() +
  scale_fill_manual(
    values = color_palette,
    breaks = names(status_labels),
    labels = unname(status_labels),
    limits = names(status_labels)
  ) +
  theme_minimal() +
  theme(
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    panel.grid = element_blank(),
    legend.title = element_text(size = 10),
    legend.text = element_text(size = 8),
    axis.text.y = element_text(size = 8),
    axis.title.y = element_text(size = 10),
    plot.title = element_text(size = 12, hjust = 0.5)
  ) +
  labs(
    x = "Position in merged range",
    y = "Range ID",
    fill = "Base Status",
    title = "Cytosine mutation status in genomic ranges"
  )
# Side strip showing NCH_rate per range as a grey gradient. It is drawn from
# fin_df so that its row order (the RangeIDwithDir factor levels) matches the
# heatmap. The plot object is only built here; it is rendered when the
# combined figure is printed, so no stray default-device output is produced.
p_rate <- ggplot(fin_df, aes(x = RangeIDwithDir, y = "1", fill = NCH_rate)) +
  geom_col() +
  theme_minimal() +
  theme(
    panel.grid.major.y = element_blank(),
    axis.text.y = element_blank(),
    axis.title.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "left"
  ) +
  coord_flip() +
  scale_fill_gradient(low = "#d0d0d0", high = "#646464")

# Assemble the figure and write it to the output PDF.
# Page height grows with the number of ranges (0.2 in per row plus margins).
cat("Saving heatmap to:", args$output, "\n")
pdf(args$output, width = 14, height = nrow(fin_df) * 0.2 + 3)

# Rate strip on the left, heatmap on the right, aligned horizontally.
combined_plot <- plot_grid(
  p_rate, p_heatmap,
  ncol = 2,
  align = "h",
  rel_widths = c(1, 5)
)

# Title row carrying the sample identifier.
title <- ggdraw() +
  draw_label(
    paste0("Cytosine Mutation Status in Genomic Ranges - ", id),
    fontface = "bold",
    size = 14,
    x = 0.5,
    hjust = 0.5
  )

final_plot <- plot_grid(
  title, combined_plot,
  ncol = 1,
  rel_heights = c(0.1, 1)
)

# Only the final figure is printed: a bare plot object at top level is
# auto-printed by Rscript and would add extra pages to the open PDF.
print(final_plot)

dev.off()

cat("Heatmap generation complete!\n")

# Write the combined per-range rate table next to the PDF: the output path
# with a trailing ".pdf" removed and "_rates.tsv" appended.
output_stem <- sub("\\.pdf$", "", args$output)
rate_output_file <- paste0(output_stem, "_rates.tsv")
write.table(
  fin_df,
  rate_output_file,
  sep = "\t",
  quote = FALSE,
  row.names = FALSE
)

cat("Rate data saved to:", rate_output_file, "\n")
