##Combined Figure

library(tidyverse)
library(ggfortify)
library(ggpubr)
library(vegan)
library(ggpattern)

# NOTE: this changes the working directory for the rest of the script; the
# relative paths used further down (including the output path passed to
# saveRDS at the end of the file) are resolved against it.
setwd("samples/spectrum/")

# `pattern` is a regular expression, not a shell glob, so the wildcard
# asterisks are not needed to match the substring anywhere in the file name.
files <- list.files(pattern = "spectrum_counts")
if (length(files) == 0) {
  stop("No spectrum count files found in ", getwd(), call. = FALSE)
}

# Read every table and bind them side by side. bind_cols() matches rows by
# position, so this assumes all files list the same rows in the same order.
theData_list <- lapply(files, read.csv)
dat <- bind_cols(theData_list)

# read.csv() names an unnamed leading column "X". When several inputs carry
# it, bind_cols() de-duplicates the names to "X...<position>"; with a single
# input the column keeps the name "X". Use whichever is present as row labels.
label_col <- intersect(c("X...1", "X"), names(dat))
if (length(label_col) == 0) {
  stop("Could not find the row-label column ('X...1' or 'X') in the input",
       call. = FALSE)
}
titles <- dat[[label_col[1]]]

# Drop all label columns so only count columns remain, then keep the labels
# as row names.
dat <- dat %>%
  dplyr::select(-starts_with("X."), -any_of("X"))
rownames(dat) <- titles

# Long-format counts for the two sample types of interest: one row per
# (row label, type) pair, with the row label exposed as the `name` column and
# the count in the default `value` column.
sigs_mm <- dat %>%
  data.frame() %>%
  rownames_to_column(var = "name") %>%
  select(name, ART, Natural) %>%
  pivot_longer(cols = c(ART, Natural), names_to = "type")

# Build the per-context table used for testing and plotting: one row per
# (mutation_type, 3', 5') combination, with the count and the per-type total
# spread into columns named after the sample type (ART, ART_total_mutations,
# Natural, Natural_total_mutations).
stat_df <- sigs_mm %>% 
  # mutation_type: the part of `name` before the first space, with ">"
  # inserted between adjacent base letters (two letters "XY" become "X>Y").
  mutate(mutation_type = gsub(" .*", "", name) %>% 
           gsub("([ATCG])([ATCG])", "\\1>\\2", .)) %>% 
  # Keep only the part of `name` after the last space (the flanking context).
  mutate(name = gsub(".* ", "", name)) %>%
  # Fraction of each type's total count. NOTE(review): `fract` is not a
  # grouping column in the summarise() below, so it does not reach the result.
  group_by(type) %>% 
  mutate(fract=value/sum(value)) %>%
  dplyr::select(type,
                trinucleotide=name,
                mutation_type,
                n=value,
                fract) %>%
  # Total count per sample type.
  group_by(type) %>% 
  mutate(total_mutations = sum(n)) %>% 
  # Split the context on the literal "." into the 5' and 3' flanking parts.
  separate(trinucleotide, sep="\\.", into=c("5'", "3'")) %>%
  # `n` is part of the grouping key, so `value` is the sum of `n` over rows
  # that agree on every listed column (equal to `n` when such rows are unique).
  group_by(type, mutation_type, `3'`, `5'`, n, total_mutations) %>% 
  summarise(value=sum(n)) %>%
  # Stack the columns in the range value:n into name/value pairs
  # (presumably value, total_mutations and n, given the column order after
  # summarise -- confirm), then prefix each with its type; the "_value"
  # suffix is stripped so the count column is named after the type alone.
  pivot_longer(value:n) %>% 
  mutate(name=paste(type,name,sep="_") %>% gsub("_value", "",.))%>%
  ungroup() %>% 
  select(-type) %>%
  # Back to wide form, one column per prefixed name; the "<type>_n" columns
  # are dropped from the result.
  pivot_wider(names_from=name) %>% select(-ART_n, -Natural_n)

# Compare the counts of two groups row by row with chi-squared tests.
#
# Two p-values are computed for every row of `df` with stats::chisq.test()
# on a 2x2 table:
#   * p_unordered: the row's count against the remainder of each group's total,
#       [ g1, g1_total - g1 ]
#       [ g2, g2_total - g2 ]
#   * p_ordered: rows are first sorted by ascending p_unordered; row i is then
#     tested against the pooled counts of all rows ranked after it,
#       [ g1_i, sum(g1 after i) ]
#       [ g2_i, sum(g2 after i) ]
#     The last row has nothing after it and is tested against the row ranked
#     immediately before it instead. With a single row there is nothing to
#     compare against at all, so p_ordered is NA.
#
# Arguments:
#   df      Data frame holding the columns <group1>, <group2>,
#           <group1>_total_mutations and <group2>_total_mutations.
#   group1  Name (string) of the first group's count column.
#   group2  Name (string) of the second group's count column.
#
# Returns `df` sorted by ascending p_unordered with four columns appended, in
# this order: p_unordered_<group1>_vs_<group2>, g1_cumsum, g2_cumsum (the
# reverse cumulative sums of the two count columns in sorted order) and
# p_ordered_<group1>_vs_<group2>.
compute_ordered_chisq <- function(df, group1, group2) {
  total1 <- paste0(group1, "_total_mutations")
  total2 <- paste0(group2, "_total_mutations")

  missing_cols <- setdiff(c(group1, total1, group2, total2), names(df))
  if (length(missing_cols) > 0) {
    stop("`df` is missing required column(s): ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # p-value of the 2x2 table [n11, n12; n21, n22].
  p_2x2 <- function(n11, n12, n21, n22) {
    tab <- matrix(c(n11, n12, n21, n22), nrow = 2, byrow = TRUE)
    stats::chisq.test(tab)$p.value
  }

  n_rows <- nrow(df)

  # Step 1: each row against the remainder of its group's total.
  x1 <- df[[group1]]
  x2 <- df[[group2]]
  t1 <- df[[total1]]
  t2 <- df[[total2]]
  p_unordered <- vapply(
    seq_len(n_rows),
    function(i) p_2x2(x1[i], t1[i] - x1[i], x2[i], t2[i] - x2[i]),
    numeric(1)
  )

  # Step 2: sort by ascending p-value. order() is stable and puts NA/NaN last.
  out <- df
  out[["p_unordered"]] <- p_unordered
  out <- out[order(p_unordered), , drop = FALSE]

  # Reverse cumulative sums: element i is the total of rows i..n, so element
  # i + 1 is the pooled count of everything ranked after row i.
  y1 <- out[[group1]]
  y2 <- out[[group2]]
  tail1 <- rev(cumsum(rev(y1)))
  tail2 <- rev(cumsum(rev(y2)))
  out[["g1_cumsum"]] <- tail1
  out[["g2_cumsum"]] <- tail2

  # Step 3: each row against the rows ranked after it.
  out[["p_ordered"]] <- vapply(
    seq_len(n_rows),
    function(i) {
      if (i < n_rows) {
        rest1 <- tail1[i + 1]
        rest2 <- tail2[i + 1]
      } else if (n_rows > 1) {
        # Last row: fall back to the previous row as the comparison.
        rest1 <- y1[i - 1]
        rest2 <- y2[i - 1]
      } else {
        # Single row: no comparison exists. Without this guard the table
        # would collapse to 2x1 and chisq.test() would run a goodness-of-fit
        # test, which answers a different question.
        return(NA_real_)
      }
      p_2x2(y1[i], rest1, y2[i], rest2)
    },
    numeric(1)
  )

  # Tag the p-value columns with the comparison they belong to.
  is_p_col <- names(out) %in% c("p_unordered", "p_ordered")
  names(out)[is_p_col] <- paste0(
    names(out)[is_p_col], "_", group1, "_vs_", group2
  )

  out
}

# Run the ART-vs-Natural comparison on the rows where the two counts do not
# sum to zero.
result_ART_vs_NAT <- stat_df %>%
  filter(ART + Natural > 0) %>%
  compute_ordered_chisq("ART", "Natural")


library(scales)  # for formatting p-values

# Text label for the ordered p-value: "<0.05" below the threshold, otherwise
# the value rounded to three decimals.
plot_df <- mutate(
  result_ART_vs_NAT,
  p_label = ifelse(
    p_ordered_ART_vs_Natural < 0.05,
    "<0.05",
    sprintf("%.3f", p_ordered_ART_vs_Natural)
  )
)

# Heatmap of the ART / Natural ratio of relative frequencies, one tile per
# (3', 5') pair and one facet row per mutation type. Tiles whose ordered
# p-value label is "<0.05" are marked with an asterisk.
md <- ggplot(
  plot_df %>%
    mutate(
      # Relative frequency in ART divided by relative frequency in Natural.
      ratio = (ART / ART_total_mutations) / (Natural / Natural_total_mutations),
      # A zero Natural count with a non-zero ART count gives an infinite
      # ratio; those tiles are hatched instead of being left blank.
      pattern = ifelse(is.infinite(ratio), "stripe", "none"),
      # An infinite value cannot be placed on the continuous fill scale, so
      # hatched tiles are drawn at the fixed stand-in value 3.
      ratio = ifelse(is.infinite(ratio), 3, ratio)
    ),
  aes(x = `3'`, y = `5'`, fill = ratio)
) +
  theme_bw() +
  geom_tile_pattern(aes(pattern = pattern), color = "black") +
  # Empty string (rather than NA) for unmarked tiles, so geom_text does not
  # warn about removed rows.
  geom_text(
    aes(label = ifelse(grepl("<", p_label, fixed = TRUE), "*", "")),
    vjust = 0.8,
    size = 10
  ) +
  facet_grid(mutation_type ~ ., switch = "y") +
  theme(
    strip.placement = "outside",
    strip.switch.pad.grid = unit(1, "cm"),
    axis.title.y = element_text(vjust = -15),
    legend.position = "top"
  ) +
  ggtitle("Natural vs ART") +
  scale_fill_gradient(high = "#E69F00", low = "#666699") +
  labs(fill = "") +
  # Named values are matched to the data by name. Unnamed values are assigned
  # by position to whichever levels happen to be present, which would draw
  # "stripe" tiles unhatched if no "none" tile existed.
  scale_pattern_manual(
    values = c(none = "none", stripe = "stripe"),
    guide = "none"
  )
md
# The path is relative to the working directory set near the top of the file.
saveRDS(md, "../../figures/heatmap_with_sig.Rds")
