source(paste0(Sys.getenv("HOME"), "/code/malawi_transposon/script/R_SCRIPT/modules/frequently_used_items.R"))

#' Compare TE annotation overlap between two annotation sets
#'
#' For every repeat class in `color_schema_te$te_family`, the subject ranges
#' of that class are merged (strand ignored) and intersected with the query
#' ranges. The number of bases covered by each query repeat class is then
#' tallied, together with an "Not Annotated" remainder. Prints the plot
#' title, a wide table of percentages (one row per subject class) and a
#' dodged bar plot.
#'
#' Percentages are each query class's share of the summed widths
#' (matched + "Not Annotated") within one subject class. Query ranges are
#' not merged, so overlapping query features can make the matched total
#' exceed the subject width; the "Not Annotated" remainder is then clamped
#' to 0.
#'
#' @param gr_query_list Named list holding exactly one GRanges (the query
#'   annotation); the list name is used as the query label. The ranges need
#'   `score` and `repeat_class` metadata columns.
#' @param gr_subject_list Named list holding exactly one GRanges (the subject
#'   annotation), with the same requirements as `gr_query_list`.
#' @param div_threshold Only ranges whose `score` (treated as sequence
#'   divergence) is strictly below this value are kept; ranges with an `NA`
#'   score are dropped.
#' @return The ggplot object, invisibly (the value of `print()` on it).
compare_annotation_overlap <- function(gr_query_list, gr_subject_list,
                                       div_threshold = 50) {
  # Validate a named single-element list and split it into (GRanges, label).
  unpack_named <- function(x, arg) {
    label <- names(x)
    if (length(x) != 1L || length(label) != 1L ||
        is.na(label) || !nzchar(label)) {
      stop("`", arg, "` must be a named list holding exactly one GRanges ",
           "object", call. = FALSE)
    }
    list(gr = x[[1]], name = label)
  }

  query_in <- unpack_named(gr_query_list, "gr_query_list")
  subject_in <- unpack_named(gr_subject_list, "gr_subject_list")
  query_name <- query_in$name
  subject_name <- subject_in$name

  # Keep only TEs whose divergence (score) is below the threshold; which()
  # drops ranges with an NA score instead of passing NAs into the subscript.
  gr_query <- query_in$gr[which(query_in$gr$score < div_threshold)]
  gr_subject <- subject_in$gr[which(subject_in$gr$score < div_threshold)]

  class_levels <- c(color_schema_te$te_family, "Not Annotated")

  # Overlap table for one subject repeat class: bases of the merged subject
  # class covered by each query repeat class, plus the uncovered remainder.
  overlap_for_class <- function(cls) {
    gr_subject_subset <- gr_subject[which(gr_subject$repeat_class == cls)] %>%
      GenomicRanges::reduce(ignore.strand = TRUE)
    # as.numeric() so genome-scale totals cannot overflow integer sums
    subject_width <- sum(as.numeric(width(gr_subject_subset)))

    # query fragments clipped to the subject ranges they overlap, keeping the
    # query's metadata (incl. repeat_class)
    gr_intersect <- findOverlapPairs(gr_query, gr_subject_subset,
                                     ignore.strand = TRUE) %>%
      pintersect(ignore.strand = TRUE)

    matched <- as.data.frame(gr_intersect) %>%
      as_tibble() %>%
      group_by(repeat_class) %>%
      summarise(intersect_width = sum(as.numeric(width)), .groups = "drop")

    # overlapping query features can push the matched total above the
    # subject width, so the remainder is clamped at 0
    unannotated <- max(subject_width - sum(matched$intersect_width), 0)

    class_stats <- matched %>%
      bind_rows(tibble(repeat_class = "Not Annotated",
                       intersect_width = unannotated)) %>%
      # every class appears (as 0) so it still shows up in the bar plot
      full_join(tibble(repeat_class = class_levels), by = "repeat_class") %>%
      mutate(intersect_width = coalesce(intersect_width, 0))

    total <- sum(class_stats$intersect_width)
    class_stats %>%
      mutate(
        # an empty subject class would give 0/0 = NaN; report 0% instead
        perc_width = if (total > 0) intersect_width / total * 100 else 0,
        sum_width = subject_width,
        subject_class = cls
      )
  }

  # as.character() mirrors how a for loop iterates over a factor
  overall_stats <- bind_rows(
    lapply(as.character(color_schema_te$te_family), overlap_for_class)
  )

  overall_stats <- overall_stats %>%
    mutate(
      query_class = factor(repeat_class, levels = class_levels),
      subject_class = factor(subject_class, levels = class_levels),
      query = query_name,
      subject = subject_name
    ) %>%
    select(-repeat_class) %>%
    relocate(subject, query)

  plot_title <- glue("subject: {subject_name}\nquery: {query_name}")
  print(plot_title)

  # wide summary: one row per subject class, one column per query class
  overall_stats %>%
    arrange(query_class) %>%
    select(subject_class, sum_width, query_class, perc_width) %>%
    mutate(perc_width = round(perc_width, 3)) %>%
    pivot_wider(
      id_cols = c("subject_class", "sum_width"),
      names_from = "query_class", values_from = "perc_width"
    ) %>%
    as.data.frame() %>%
    print()

  # name the colours by class so the mapping never depends on level order
  fill_values <- stats::setNames(c(color_schema_te$color, "gray20"),
                                 class_levels)

  p <- ggplot(overall_stats,
              aes(x = subject_class, y = perc_width, fill = query_class)) +
    geom_col(color = "black", size = 0.4, width = 0.8,
             position = position_dodge()) +
    scale_fill_manual("Matching\nTE class\nin query",
                      values = fill_values,
                      breaks = class_levels) +
    xlab(NULL) +
    ylab("Percentage sequence of subject matched") +
    ylim(c(0, 100)) +
    ggtitle(plot_title) +
    theme(
      aspect.ratio = 1,
      plot.title = element_text(size = rel(0.8)),
      legend.title = element_text(size = rel(0.7)),
      legend.text = element_text(size = rel(0.65)),
      legend.key.size = unit(0.03, "npc"),
      legend.position = "right",
      axis.text.x = element_text(angle = 30, hjust = 1),
      axis.title.y = element_text(size = rel(0.85))
    )
  print(p)
}