library(AnnotationDbi)
library(GenomicRanges)
library(ggplot2)
library(tibble)
library(dplyr)
library(tidyr)
library(readr)
library(glue)
library(regioneR)
library(ggsignif)

# Global ggplot2 styling: classic theme with the panel, plot and legend
# backgrounds all set to the same transparent rectangle.
transparent_background <- element_rect(fill = "transparent", colour = NA)
theme_set(theme_classic())
theme_update(
  panel.background = transparent_background,
  plot.background = transparent_background,
  legend.background = transparent_background
)

# Work from the project's R script directory; the relative paths used below
# (modules/, ./local/, ../te_permutation_shuffle/) are resolved against it.
setwd(file.path(Sys.getenv("HOME"), "code/malawi_transposon/script/R_SCRIPT/"))

# Project modules, sourced in this order into the global environment.
invisible(lapply(
  c("modules/frequently_used_items.R", "modules/misannotation_thresholds.R"),
  source
))

# Analysis parameters
updown_distance <- 2000  # width (bp) of the flank treated as upstream / downstream of a gene
n_iter <- 100            # number of iterations of each resampling loop

# Dataset / reference genome identifiers (bb_name and bb_name_sqlite are
# interpolated into the file paths further down)
subdir <- "malawi_haplochromines"
bb_name <- "astCal1.2_ensembl"
bb_name_sqlite <- "astCal_v1.2"
bb_name_abrv <- "astCal"

################################################################################
################################################################################
# DATA IMPORT

# IMPORT STRUCTURAL VARIANT DATAFRAMES (gr_a: item that will be shuffled / bootstrapped)
# Column types: 20 fixed leading columns followed by one six-column group
# ("ciicic") per genome. `num_of_genomes` is not defined in this file
# (presumably it comes from one of the sourced modules).
df_sv <- read_csv(
  "./local/malawi_haplochromines-summary-preprocessed.csv",
  col_types = paste0("cciiiiicciiicilliidc", strrep("ciicic", num_of_genomes)),
  na = c("", "NA", ".")
)

# Chromosome names and lengths, turned into the Seqinfo shared by the GRanges below
chrom_sizes <- read_delim(
  glue(Sys.getenv("HOME"), "/igv/genomes/{bb_name}/chrom.sizes"),
  col_names = c('chr', 'len')
) %>%
  with(Seqinfo(seqnames = chr, seqlengths = len))

# One range per row of the SV table, carrying the SV id as a metadata column
gr_sv <- with(
  df_sv,
  GRanges(
    seqnames = chr,
    ranges = IRanges(start = start, end = end),
    id = id,
    seqinfo = chrom_sizes
  )
)

##########

# IMPORT REGIONS WITH COVERAGE
# Covered intervals are read from a headerless BED file (columns: chr, start, end).
df_covered <- readr::read_delim(
  "../te_permutation_shuffle/input/astCal-covered_ranges.bed",
  col_names = c('chr', 'start', 'end')
)
# BED coordinates are 0-based with an exclusive end, whereas IRanges/GRanges are
# 1-based with an inclusive end, so only the start is shifted by +1. Shifting
# the end as well made every covered interval 1 bp too long.
# NOTE(review): assumes the file follows the standard BED convention -- confirm.
gr_genome <- GRanges(
  seqnames = df_covered$chr,
  ranges = IRanges(start = df_covered$start + 1, end = df_covered$end),
  seqinfo = chrom_sizes
) %>%
  GenomicRanges::trim()

# if want to use whole genome, comment that out and use this:
#gr_genome <- GRanges(seqnames = seqnames(chrom_sizes), ranges = IRanges(start=1, end=seqlengths(chrom_sizes)))

##########

# OBTAIN GENE ANNOTATIONS (gr_b: this will NOT be shuffled)
path_to_annotation_sqlite <- glue(
  Sys.getenv("HOME"),
  "/igv/genomes/{bb_name}/{bb_name_sqlite}.gff.sqlite"
)
txdb <- AnnotationDbi::loadDb(file = path_to_annotation_sqlite)

# filter genes (comment out these lines if you don't want, leave only the bottom line)
# A gene is flagged when any one of the three criteria holds. The two thresh*
# values are not defined in this file (presumably set by a sourced module).
df_gene_trueness <- readr::read_csv(
  glue(Sys.getenv("HOME"), "/code/malawi_transposon/cloud/genome/{bb_name}/gene_trueness.csv")
)
list_of_dodgy_genes <- df_gene_trueness %>%
  filter(
    maxOverlapSingleTE > threshGeneBodyOverlapSingleTE |
      totalOverlapMultipleTE > threshGeneBodyOverlapMultipleTE |
      hasTranspositionTerm == TRUE
  ) %>%
  pull(gene_id)
glue("{length(x)} genes are potential TEs to be removed", x = list_of_dodgy_genes)

#list_of_dodgy_genes <- c()

##########

# OBTAIN GRANGES FOR THE INDIVIDUAL GENE FEATURES
# Every feature set below is extracted from `txdb`; genes listed in
# `list_of_dodgy_genes` are removed from the gene, exon and CDS sets.
# NOTE(review): genes(), exons(), cds(), fiveUTRsByTranscript() and
# threeUTRsByTranscript() do not appear to come from any package attached by the
# library() calls at the top of this file; presumably GenomicFeatures is attached
# by one of the sourced modules -- confirm.
gr_gene <- genes(txdb, columns = c('gene_id', 'tx_name'))
gr_gene <- gr_gene[!gr_gene$gene_id %in% list_of_dodgy_genes]
glue("number of genes: {length(gr_gene)}")

# exons in all coding and non-coding genes
gr_exon <- exons(txdb, columns = c('exon_name', "GENEID"))
# Flatten the per-exon GENEID entries so they can be matched with %in%.
# NOTE(review): if an exon maps to zero or to several genes, sapply() would
# return a list instead of a character vector -- verify against the annotation.
gr_exon$GENEID <- sapply(gr_exon$GENEID, unlist)
gr_exon <- gr_exon[!gr_exon$GENEID %in% list_of_dodgy_genes]

# CDS: translated regions of the exons in protein-coding genes
# (gr_cds is not referenced again in the visible part of this script)
gr_cds <- cds(txdb, columns = c('cds_name', "GENEID"))
gr_cds$GENEID <- sapply(gr_cds$GENEID, unlist)
gr_cds <- gr_cds[!gr_cds$GENEID %in% list_of_dodgy_genes]

# introns: the part of each retained gene range not covered by a retained exon
# (ignore.strand is left at its default here, unlike the setdiff further down)
gr_intron <- GenomicRanges::setdiff(gr_gene, gr_exon)

# 5-UTRs: keep only pieces whose exon_name is among the retained exons
gr_5utr <- unlist(fiveUTRsByTranscript(txdb, use.names=TRUE))
gr_5utr <- gr_5utr[gr_5utr$exon_name %in% gr_exon$exon_name]

# 3-UTRs: same filtering as for the 5-UTRs
gr_3utr <- unlist(threeUTRsByTranscript(txdb, use.names=TRUE))
gr_3utr <- gr_3utr[gr_3utr$exon_name %in% gr_exon$exon_name]

# modified version of exon GRanges
# contains the CDS of coding genes and retains the exons of non coding genes
# (computed as: retained exons minus every 5'/3' UTR position, ignoring strand)
gr_5and3_utr <- GenomicRanges::union(gr_5utr, gr_3utr)
gr_exon_2 <- GenomicRanges::setdiff(gr_exon, gr_5and3_utr, ignore.strand=TRUE)

# upstream and downstream of genes: flanks of `updown_distance` bp before the
# start / after the end of each gene, trimmed to the chromosome bounds.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable variables.
gr_upstream <- GenomicRanges::trim(
  GenomicRanges::flank(gr_gene, start = TRUE, width = updown_distance)
)
gr_downstream <- GenomicRanges::trim(
  GenomicRanges::flank(gr_gene, start = FALSE, width = updown_distance)
)

# intergenic regions (where strand information is dropped)
gr_gene_unstranded <- gr_gene
strand(gr_gene_unstranded) <- '*'
# gaps() is computed once and only its unstranded ('*') ranges are kept
gr_intergenic <- gaps(gr_gene_unstranded)
gr_intergenic <- gr_intergenic[strand(gr_intergenic) == '*']

##########

# CALCULATE FEATURE LENGTHS
# Feature GRanges and their labels, kept as two parallel vectors in the same order.
# The order is a priority: the hierarchical functions below walk the list front
# to back and subtract whatever earlier features already claimed, so a position
# covered by several features is attributed to the first one listed.
gr_list_ordered <- list(gr_exon_2, gr_5utr, gr_3utr, gr_intron, gr_upstream, gr_downstream, gr_intergenic)
feature_names_ordered <- c("exon", "5UTR", "3UTR", "intron", "upstream", "downstream", "intergenic")

# Total number of bases attributed to each feature, in hierarchical order.
#
# Walks `gr_list_ordered` front to back; for each feature only the positions
# not already claimed by an earlier feature are counted (strand is ignored).
# Reads the globals `gr_list_ordered`, `feature_names_ordered` and `chrom_sizes`.
#
# Returns a numeric vector of base counts named by `feature_names_ordered`.
get_feature_lengths <- function() {
  gr_explored <- GRanges(seqinfo = chrom_sizes)
  feature_lengths <- numeric(length(feature_names_ordered))
  names(feature_lengths) <- feature_names_ordered

  for (i in seq_along(gr_list_ordered)) {
    # part of this feature that no higher-priority feature has claimed yet
    gr_subject <- GenomicRanges::setdiff(
      gr_list_ordered[[i]], gr_explored,
      ignore.strand = TRUE
    )
    # width() gives the size of each range in bases; lengths() would only
    # yield one entry per range, i.e. the number of ranges. Summing as double
    # avoids integer overflow on large genomes.
    feature_lengths[i] <- sum(as.numeric(width(gr_subject)))
    gr_explored <- GenomicRanges::union(gr_explored, gr_subject)
  }

  feature_lengths
}

# Per-feature totals from get_feature_lengths() as a two-column tibble
# (feature label, value); the vector names become the Feature column.
feature_lengths_ordered <- get_feature_lengths()
feature_lengths_ordered <- tibble(
  Feature = names(feature_lengths_ordered),
  PercSeqGenomeWide = unname(feature_lengths_ordered)
)

################################################################################
################################################################################
# FUNCTIONS FOR COMPUTING OVERLAP

# Hierarchical overlap COUNT of a GRanges with the gene features defined above.
#
# Each range of `gr_a` is labelled with the first feature (in the order of
# `gr_list_ordered`) that it overlaps; a range overlapping no feature keeps the
# empty label "" and still contributes to the total.
# Reads the globals `gr_list_ordered`, `feature_names_ordered` and `chrom_sizes`.
#
# gr_a: GRanges whose ranges are to be classified.
#
# Prints the number of classified ranges and returns a tibble with columns
#   Feature: label (only labels that received at least one range appear)
#   Count:   percentage of ranges carrying that label, rounded to 4 decimals
hierarchical_overlap_count_with_genes <- function(gr_a) {
  gr_a$feature <- ""
  gr_explored <- GRanges(seqinfo = chrom_sizes)

  for (i in seq_along(gr_list_ordered)) {
    # part of this feature not already claimed by a higher-priority feature
    gr_subject <- GenomicRanges::setdiff(
      gr_list_ordered[[i]], gr_explored,
      ignore.strand = TRUE
    )

    # ranges of gr_a hitting this feature (accessor instead of the @from slot),
    # restricted to those that have not been labelled yet
    hit_idx <- unique(queryHits(findOverlaps(gr_a, gr_subject)))
    unlabelled_idx <- hit_idx[gr_a$feature[hit_idx] == ""]
    gr_a$feature[unlabelled_idx] <- feature_names_ordered[i]

    gr_explored <- GenomicRanges::union(gr_explored, gr_subject)
  }

  feature_counts <- table(gr_a$feature)
  print(sum(feature_counts))
  tibble(
    Feature = names(feature_counts),
    Count = round(as.vector(feature_counts) / sum(feature_counts) * 100, 4)
  )
}

# Hierarchical overlap PERCENTAGE OF SEQUENCE of a GRanges with the gene
# features defined above.
#
# For each feature (in the order of `gr_list_ordered`) the number of bases of
# `gr_a` falling in the not-yet-claimed part of that feature is counted, strand
# ignored. The per-feature subjects are mutually disjoint, so no base of `gr_a`
# is counted twice.
# Reads the globals `gr_list_ordered`, `feature_names_ordered` and `chrom_sizes`.
#
# gr_a:                 GRanges to be apportioned among the features.
# return_featurelength: accepted for backward compatibility; currently unused.
#
# Returns a tibble with columns
#   Feature: every label of `feature_names_ordered`, in order
#   PercSeq: overlapping bases as a percentage of all overlapping bases,
#            rounded to 4 decimals (NaN when nothing overlaps)
hierarchical_overlap_percent_with_genes <- function(gr_a, return_featurelength = FALSE) {
  gr_b_explored <- GRanges(seqinfo = chrom_sizes) # genomic features seen so far

  # Normalise gr_a once. The original did this inside the loop against an
  # "explored" set that was never filled, so the result was the same every time.
  gr_a <- GenomicRanges::setdiff(gr_a, GRanges(seqinfo = chrom_sizes))

  overlap_bases <- numeric(length(feature_names_ordered))
  names(overlap_bases) <- feature_names_ordered

  for (i in seq_along(gr_list_ordered)) {
    # part of the current feature that has not been explored yet
    gr_subject <- GenomicRanges::setdiff(
      gr_list_ordered[[i]], gr_b_explored,
      ignore.strand = TRUE
    )

    # bases of gr_a inside that part (summed as double to avoid integer overflow)
    gr_intersect <- GenomicRanges::intersect(gr_a, gr_subject, ignore.strand = TRUE)
    overlap_bases[i] <- sum(as.numeric(width(GenomicRanges::reduce(gr_intersect))))

    # mark this part of the genome as explored
    gr_b_explored <- GenomicRanges::union(gr_b_explored, gr_subject)
  }

  tibble(
    Feature = names(overlap_bases),
    PercSeq = round(overlap_bases / sum(overlap_bases) * 100, 4)
  )
}

################################################################################
################################################################################
# BOOTSTRAPPING RELATED FUNCTIONS

# Randomly subsample the elements of a GRanges object (without replacement).
#
# gr_a:      object to subsample.
# perc_kept: FRACTION (0-1) of elements to keep; the number kept is
#            round(perc_kept * n).
# drop_zero: when TRUE, elements of width <= 1 are discarded before sampling.
#
# Returns the sampled subset, in sampled order.
remove_some_ranges <- function(gr_a, perc_kept = 0.1, drop_zero = FALSE) {
  candidates <- if (drop_zero == TRUE) gr_a[width(gr_a) > 1] else gr_a
  n_candidates <- length(candidates)
  n_kept <- round(perc_kept * n_candidates)
  candidates[sample(n_candidates, n_kept)]
}

# Shift every range of a GRanges object by an independent random offset.
#
# gr_a:         GRanges to jitter.
# jitter_range: maximum absolute offset in bp; offsets are drawn uniformly
#               from the integers -jitter_range..jitter_range.
#
# Each offset is limited so that the shifted range stays within
# [1, chromosome length]. The upper limit is applied only where the sequence
# length is known. Widths, metadata columns and seqinfo of `gr_a` are kept;
# the previous data-frame round trip dropped the seqinfo and could produce
# coordinates below 1 or beyond the chromosome end.
#
# Returns a GRanges of the same length as `gr_a`.
jitter_ranges <- function(gr_a, jitter_range = 5000) {
  offsets <- sample(seq(-jitter_range, jitter_range), length(gr_a), replace = TRUE)

  # largest admissible move to the left / to the right for every range
  min_offset <- 1L - start(gr_a)
  chrom_len <- unname(seqlengths(gr_a)[as.character(seqnames(gr_a))])
  max_offset <- chrom_len - end(gr_a)

  offsets <- pmax(offsets, min_offset)
  has_len <- !is.na(max_offset)
  offsets[has_len] <- pmin(offsets[has_len], max_offset[has_len])

  GenomicRanges::shift(gr_a, offsets)
}

# Generate completely random ranges over the whole genome.
#
# gr_a:     GRanges used only to measure the fraction of ranges of width <= 1.
# n_ranges: total number of random ranges to create.
#
# The same fraction of the output is created with length 1 (sd 0); the rest is
# created with mean length 100 and sd 20. The genome is described by the
# global `chrom_sizes` (every sequence from position 1 to its length).
#
# Returns the length-1 regions followed by the longer regions.
generate_random_ranges <- function(gr_a, n_ranges = 5000) {
  short_fraction <- sum(width(gr_a) <= 1) / length(gr_a)
  n_short <- round(short_fraction * n_ranges)

  n_chrom <- length(chrom_sizes)
  genome_df <- data.frame(
    chr = seqnames(chrom_sizes),
    start = rep(1, n_chrom),
    end = seqlengths(chrom_sizes)
  )

  short_regions <- createRandomRegions(
    nregions = n_short, length.mean = 1, length.sd = 0, genome = genome_df
  )
  long_regions <- createRandomRegions(
    nregions = n_ranges - n_short, length.mean = 100, length.sd = 20, genome = genome_df
  )
  c(short_regions, long_regions)
}


################################################################################
################################################################################

# genome baseline (this is more accurate than feature_lengths_ordered, as it can exclude uncovered regions)
# i.e. how the covered genome itself is split among the features
genomeBaseline <- rename(
  hierarchical_overlap_percent_with_genes(gr_genome),
  PercSeqBaseline = PercSeq
)
genomeBaseline

# quick comparisons: genome-wide feature values rescaled to percentages
with(feature_lengths_ordered, PercSeqGenomeWide / sum(PercSeqGenomeWide) * 100)

# total length of the covered regions (after merging overlapping ranges)
genomeSizeReduced <- gr_genome %>%
  GenomicRanges::reduce() %>%
  width() %>%
  sum()
genomeSizeReduced

##############
# COUNTS
# observed: feature composition (percentage of ranges per feature) of the full SV set
x1 <- hierarchical_overlap_count_with_genes(gr_sv)

# bootstrapped: every iteration classifies a random subset of gr_sv
# (remove_some_ranges() with its default perc_kept) and appends the result
# as a new column CountObs_<idx>
dfBootstrapCount <- left_join(genomeBaseline, rename(x1, CountObs = Count))
for (idx in seq(n_iter)) {
  dfBootstrapCount <- gr_sv %>%
    remove_some_ranges() %>%
    hierarchical_overlap_count_with_genes() %>%
    rename("CountObs_{idx}" := Count) %>%
    left_join(dfBootstrapCount, .)
}

# bootstrapped expected baseline: same idea, but the SVs are first jittered by
# up to 500 bp and 75% of them are kept; columns are named CountExp_<idx>
dfBootstrapCountExpected <- left_join(genomeBaseline, rename(x1, CountObs = Count))
for (idx in seq(n_iter)) {
  dfBootstrapCountExpected <- gr_sv %>%
    jitter_ranges(jitter_range = 500) %>%
    remove_some_ranges(perc_kept = 0.75) %>%
    hierarchical_overlap_count_with_genes() %>%
    rename("CountExp_{idx}" := Count) %>%
    left_join(dfBootstrapCountExpected, .)
}

dfBootstrapCount
dfBootstrapCountExpected

##############
# PERCENTAGE SEQUENCE
# observed: only SVs wider than 1 bp contribute sequence
x2 <- hierarchical_overlap_percent_with_genes(gr_sv[width(gr_sv) > 1])

# bootstrapped: random subset of the SVs wider than 1 bp (default perc_kept),
# one PercSeqObs_<idx> column per iteration
dfBootstrapPerc <- left_join(genomeBaseline, rename(x2, PercSeqObs = PercSeq))
for (idx in seq(n_iter)) {
  dfBootstrapPerc <- gr_sv %>%
    remove_some_ranges(drop_zero = TRUE) %>%
    hierarchical_overlap_percent_with_genes() %>%
    rename("PercSeqObs_{idx}" := PercSeq) %>%
    left_join(dfBootstrapPerc, .)
}

# bootstrapped expected baseline: SVs jittered by up to 1000 bp, then 75% of
# those wider than 1 bp are kept; columns are named PercSeqExp_<idx>
dfBootstrapPercExpected <- left_join(genomeBaseline, rename(x2, PercSeqObs = PercSeq))
for (idx in seq(n_iter)) {
  dfBootstrapPercExpected <- gr_sv %>%
    jitter_ranges(jitter_range = 1000) %>%
    remove_some_ranges(perc_kept = 0.75, drop_zero = TRUE) %>%
    hierarchical_overlap_percent_with_genes() %>%
    rename("PercSeqExp_{idx}" := PercSeq) %>%
    left_join(dfBootstrapPercExpected, .)
}

dfBootstrapPerc
dfBootstrapPercExpected

##########
# convert the counts to a bubble density

# Shared conversion for both bootstrap tables:
#  1. feature length in bp = baseline percentage of the covered genome size
#  2. every Count* column (a percentage) -> count out of 187552 -> count per kbp
#     of feature. NOTE(review): 187552 is presumably the total number of SVs;
#     confirm that it matches length(gr_sv).
#  3. row-wise mean and sd over the replicate columns starting with `replicate_prefix`
summarise_bubble_density <- function(df_bootstrap, replicate_prefix) {
  df_bootstrap %>%
    mutate(FeatureLengthReduced = PercSeqBaseline / 100 * genomeSizeReduced) %>%
    mutate(across(starts_with('Count'), ~ round(.x/100 * 187552) / FeatureLengthReduced * 1000)) %>%
    rowwise() %>%
    mutate(
      meanVal = mean(c_across(starts_with(replicate_prefix)), na.rm = TRUE),
      sdVal = sd(c_across(starts_with(replicate_prefix)), na.rm = TRUE)
    )
}

tmp1 <- summarise_bubble_density(dfBootstrapCount, "CountObs_")
tmp2 <- summarise_bubble_density(dfBootstrapCountExpected, "CountExp_")

# Long table for plotting: one row per feature and value type. For the observed
# rows meanValTrue is the density of the full SV set; for the expected rows it
# is simply the replicate mean.
#tmp1 %>% select(Feature, mean = CountObs) %>% add_column(valueType='true'),
observed_rows <- tmp1 %>%
  select(Feature, meanVal, sdVal, meanValTrue = CountObs) %>%
  add_column(valueType = 'observed')
expected_rows <- tmp2 %>%
  select(Feature, meanVal, sdVal) %>%
  mutate(meanValTrue = meanVal) %>%
  add_column(valueType = 'expected')

tmp_df <- bind_rows(list(observed_rows, expected_rows)) %>%
  mutate(
    Feature = factor(Feature, levels = feature_names_ordered),
    valueType = factor(valueType, levels = c('observed', 'expected'))
  ) %>%
  ungroup()


# p-value calculation
# Normal upper-tail probability of the distance between the expected mean
# (joined in under the column name tStatistic) and the observed mean, measured
# in units of the observed sd. pval_one is that tail; pval_two is twice it.
upper_tail_prob <- function(expected_mean, observed_mean, observed_sd) {
  pnorm(abs(expected_mean - observed_mean) / observed_sd, lower.tail = FALSE)
}

observed_stats <- tmp_df %>%
  filter(valueType=='observed') %>%
  select(Feature, meanVal, sdVal)
expected_stats <- tmp_df %>%
  filter(valueType=='expected') %>%
  select(Feature, tStatistic=meanVal)

left_join(observed_stats, expected_stats) %>%
  mutate(
    pval_two = 2 * upper_tail_prob(tStatistic, meanVal, sdVal),
    pval_one = upper_tail_prob(tStatistic, meanVal, sdVal)
  )

# Bar plot of observed vs expected bubble density per feature, with 95% error
# bars (mean +/- 1.96 sd), the full-data value printed above each bar and one
# significance bracket per feature.

# y position of each bracket: top of the taller of the two error bars of the feature
errorbar_location <- tmp_df %>%
  ungroup() %>%
  group_by(Feature) %>%
  summarise(y_pos = max(meanVal + 1.96*sdVal))

tmp_df %>%
  ggplot(aes(x = Feature, y = meanVal, fill = valueType)) +
  geom_col(alpha=0.9, width=0.6, position='dodge') +
  # The brackets and their labels are hard-coded for the 7 features, in the
  # order of feature_names_ordered; the labels are not derived from the p-value
  # table and must be updated by hand.
  geom_signif(
    xmin = seq(7)-0.2, xmax = seq(7)+0.2,
    y_position = errorbar_location$y_pos + 0.1,
    annotation = c('***', '*', 'ns', '***', 'ns', 'ns', 'ns'),
    textsize = 3, size=0.25
  ) +
  geom_errorbar(aes(ymin=meanVal-1.96*sdVal, ymax=meanVal+1.96*sdVal),
                width=0.25, alpha=0.7, linewidth=0.5, position = position_dodge(0.6)) +
  # `nsmall` is the format() argument that forces a minimum number of decimals;
  # the previous spelling `n_small` was silently ignored.
  geom_text(aes(y=meanVal+1.96*sdVal, label = format(round(meanValTrue, 4), nsmall = 4)),
            vjust = 0.5, hjust=-0.1, size=2.7, angle=90, position = position_dodge(0.6)) +
  xlab(NULL) +
  scale_y_continuous("Bubble density per kbp of feature", expand = expansion(mult = c(.05, .1))) +
  scale_fill_manual(NULL, values = c('#3289a8', '#ff9b30')) +
  theme(
    legend.position = 'top', axis.text.x = element_text(angle = 20, hjust = 1),
    legend.text = element_text(size = rel(0.9)),
    legend.key.size=unit(0.03, "npc")
  )
#ggsave("feature_bubble_density_bootstrap.pdf", width=3.9, height=2.9)

