library(readr)
library(dplyr)
library(tidyr)
library(ggplot2)
library(glue)

# Species colour scheme (one row per species abbreviation). Colours are stored
# without the leading '#' because '#' is the comment character of the CSV, so
# it is prepended after reading.
color_schema_species <- file.path(
  Sys.getenv("HOME"), "code", "malawi_transposon", "metadata", "species_colorscheme.csv"
) %>%
  readr::read_csv(comment = '#') %>%
  mutate(color = paste0('#', color))

tmp_df1 <- read_csv(snakemake@input[[1]])
tmp_df2 <- read_csv(snakemake@input[[2]])

# Combine the pangenome stats and the number of events per species pair and
# minimum variant size, then express every count per 1 kbp of covered backbone.
df_raw <- inner_join(tmp_df1, tmp_df2, by = c("species1", "species2", "min_variant_size")) %>%
  mutate(
    # order species like the colour scheme so axes and legends follow it
    species1 = forcats::fct_relevel(species1, color_schema_species$species_abrv),
    species2 = forcats::fct_relevel(species2, color_schema_species$species_abrv),
    # covered backbone length in kbp: bb_cov_perc is a percentage of bb_size
    # (bb_size assumed to be in bp); computed once instead of per density
    covered_kbp = bb_cov_perc / 100 * bb_size / 1e3,
    bubble_density = n_bubbles / covered_kbp,
    insert_density = n_ins / covered_kbp,
    delete_density = n_del / covered_kbp,
    n_bases_insert_density = n_bases_ins / covered_kbp,
    # positive, like delete_density, so that the net difference below really is
    # inserted minus deleted bases (a negative sign here would turn it into a sum)
    n_bases_delete_density = n_bases_del / covered_kbp,
    net_event_density = insert_density - delete_density,
    net_n_bases_density = n_bases_insert_density - n_bases_delete_density
  ) %>%
  select(-covered_kbp)

# Sanity check: the number of bubbles should equal the number of events for
# every species pair / variant size. A bare `all(...)` only prints TRUE/FALSE,
# which is easy to miss in pipeline logs, so emit a warning on mismatch instead.
# `%in% TRUE` treats NA comparisons as mismatches.
bubbles_match_events <- with(df_raw, n_bubbles == n_events) %in% TRUE
if (!all(bubbles_match_events)) {
  warning(
    sprintf(
      "n_bubbles != n_events (or NA) in %d of %d rows",
      sum(!bubbles_match_events), length(bubbles_match_events)
    ),
    call. = FALSE
  )
}

# Dummy rows for the self-vs-self comparisons: one row per species of the
# colour scheme with identical backbone and query, drawn as black tiles on the
# diagonal of each heatmap.
diagonal_df <- with(
  color_schema_species,
  data.frame(species1 = species_abrv, species2 = species_abrv, bb_conserv_perc = 1)
)


# Heatmap of one density column for a single minimum variant size: backbone
# genome on x, query genome on y, the rounded value printed in every tile and
# the self-vs-self diagonal blacked out.
#   df         - rows of df_raw for one min_variant_size
#   fill_col   - name (string) of the column mapped to fill and tile labels
#   digits     - number of decimals shown in the tile labels
#   fill_scale - ggplot2 fill scale (legend title and palette)
#   diagonal   - data frame with species1/species2 columns for the diagonal
plot_density_heatmap <- function(df, fill_col, digits, fill_scale, diagonal = diagonal_df) {
  ggplot(df, aes(species1, species2, fill = .data[[fill_col]])) +
    geom_tile(color = 'black') +
    geom_text(
      aes(label = format(round(.data[[fill_col]], digits), nsmall = digits)),
      color = 'black'
    ) +
    # the constant fill overrides the inherited fill mapping for this layer
    geom_tile(aes(x = species1, y = species2), diagonal, fill = 'black') +
    fill_scale +
    labs(x = 'backbone', y = 'query') +
    theme_classic() +
    theme(
      aspect.ratio = 1,
      axis.line = element_blank(),
      axis.ticks = element_blank(),
      axis.text = element_text(size = 10)
    )
}

# One entry per heatmap; names are the snakemake output keys, written in this order.
heatmap_specs <- list(
  bubble = list(
    col = "bubble_density", digits = 3,
    scale = scale_fill_gradient('bubbles\nper 1kbp\nbackbone', low = "#2171B5", high = "#EDF8E9")
  ),
  inserted = list(
    col = "insert_density", digits = 3,
    scale = scale_fill_distiller('insertion\nevents\nper 1kbp\nbackbone', palette = 'RdBu', direction = -1)
  ),
  inserted_bases = list(
    col = "n_bases_insert_density", digits = 2,
    scale = scale_fill_distiller('inserted bases\nper 1kbp\nbackbone', palette = 'RdBu', direction = -1)
  ),
  deleted = list(
    col = "delete_density", digits = 3,
    scale = scale_fill_distiller('deletion\nevents\nper 1kbp\nbackbone', palette = 'RdBu', direction = 1)
  ),
  deleted_bases = list(
    col = "n_bases_delete_density", digits = 2,
    scale = scale_fill_distiller('deleted bases\nper 1kbp\nbackbone', palette = 'RdBu', direction = 1)
  ),
  net_indel = list(
    col = "net_event_density", digits = 3,
    scale = scale_fill_distiller('insertions -\ndeletions\nper 1kbp\nbackbone', palette = 'RdBu')
  ),
  net_indel_bases = list(
    col = "net_n_bases_density", digits = 2,
    scale = scale_fill_distiller('inserted -\ndeleted bases\nper 1kbp\nbackbone', palette = 'RdBu')
  )
)

# loop over several variant sizes; output file i of each key belongs to size i.
# seq_along() yields no iterations for an empty config (seq(1, 0) would not).
for (i in seq_along(snakemake@config$min_variant_sizes)) {
  min_size <- snakemake@config$min_variant_sizes[[i]]
  df_subset <- df_raw %>% filter(min_variant_size == min_size)
  if (nrow(df_subset) == 0) {
    warning(sprintf("no rows for min_variant_size == %s", min_size), call. = FALSE)
  }

  for (key in names(heatmap_specs)) {
    spec <- heatmap_specs[[key]]
    p <- plot_density_heatmap(df_subset, spec$col, spec$digits, spec$scale)
    # pass the plot explicitly instead of relying on last_plot()
    ggsave(snakemake@output[[key]][[i]], plot = p, width = 6.2, height = 4.8)
  }
}

# Line plot of one density column against the minimum variant size (-L), one
# facet per backbone genome (species1) and one coloured line per query genome
# (species2); shows how sensitive the densities are to the -L cut-off.
#   df      - df_raw (all variant sizes)
#   y_col   - name (string) of the density column on the y axis
#   y_label - y-axis title
#   colors  - colour scheme with species_abrv and color columns
plot_robustness <- function(df, y_col, y_label, colors = color_schema_species) {
  ggplot(df, aes(x = min_variant_size, y = .data[[y_col]], color = species2)) +
    geom_point(size = 1.2) +
    geom_line(linewidth = .5) +
    facet_wrap(. ~ species1, nrow = 2) +
    scale_color_manual('query genome', values = colors$color, breaks = colors$species_abrv) +
    theme_classic() +
    xlab('minimum variant size (-L)') +
    ylab(y_label) +
    theme(
      legend.position = 'top',
      # element_line(size =) is deprecated since ggplot2 3.4.0, which this file
      # already requires for geom_line(linewidth =)
      panel.grid.major.y = element_line(colour = "grey90", linewidth = 0.5),
      legend.title = element_text(size = rel(0.75)),
      legend.text = element_text(size = rel(0.65)),
      legend.key.size = unit(0.04, "npc")
    )
}

# robustness of bubble density
p <- plot_robustness(df_raw, "bubble_density", 'detected bubbles per 1kbp of backbone sequence')
ggsave(snakemake@output$robustness_bubble, plot = p, width = 6.4, height = 4)

# robustness of insertion density
p <- plot_robustness(df_raw, "insert_density", 'number of insertions per 1kbp of backbone sequence')
ggsave(snakemake@output$robustness_ins, plot = p, width = 6.4, height = 4)

# robustness of deletion density
p <- plot_robustness(df_raw, "delete_density", 'number of deletions per 1kbp of backbone sequence')
ggsave(snakemake@output$robustness_del, plot = p, width = 6.4, height = 4)
