library(magrittr)
library(dplyr)
library(tidyr)
library(readr)
# for network analysis
library(corrr)
library(tidygraph)
library(ggraph)
# for clustering
library(ComplexHeatmap)
library(circlize)
library(RColorBrewer)
# for the plot
library(ggplot2)
library(wesanderson)
# The final cluster is not based on OG (many OGs are duplicated, i.e. one copy is in the synteny block and another is not).
# The final cluster scheme should use a new label for the cluster number (a, b, c, d, ...).

# read data ---------------------------------------------------------------
dt_directory <- "./data_jan/res_dt/"

# `pattern` is a regular expression, not a glob: anchor it on the extension.
dt_files <- list.files(dt_directory, pattern = "\\.tsv$")
# This script later writes cluster_group_df.tsv into a res_dt directory; skip
# it in case that is this directory, because it is an output of the script and
# has no og/cluster columns.
dt_files <- setdiff(dt_files, "cluster_group_df.tsv")
if (length(dt_files) == 0) {
  stop("No .tsv file found in ", dt_directory, call. = FALSE)
}

# Read the og/cluster columns of one result table and tag every row with its
# species pair, taken from the first two "_"-separated parts of the file name
# (e.g. "hsym_nvec_lim4_nmin8.tsv" -> sp_pair "hsym_nvec").
read_pair_cluster <- function(file_name) {
  name_parts <- strsplit(file_name, split = "_")[[1]]
  pair_label <- paste0(name_parts[1], "_", name_parts[2])
  read_tsv(
    paste0(dt_directory, file_name),
    col_names = TRUE,
    col_types = cols_only("og" = "c", "cluster" = "i")
  ) %>%
    mutate(
      sp_pair = pair_label,
      sp_pair_cluster = paste0(sp_pair, "_", cluster)
    ) %>%
    filter(cluster != 0) # remove noise OG
}

# combine df ---------------------------------------------------------------
# Collect the tables in a named list directly instead of looking them up with
# ls(pattern = ...), which would also pick up unrelated objects whose name
# ends in "df" (e.g. full_df or cluster_group_df when the script is re-run).
df_list <- lapply(dt_files, read_pair_cluster)
names(df_list) <- vapply(
  strsplit(dt_files, split = "_"),
  function(name_parts) paste0(name_parts[1], "_", name_parts[2], "_df"),
  character(1)
)

# Several files for the same species pair: keep the last one only, which is
# what overwriting the per-pair variable did before.
is_superseded <- duplicated(names(df_list), fromLast = TRUE)
if (any(is_superseded)) {
  warning(
    "More than one file for: ",
    paste(unique(names(df_list)[is_superseded]), collapse = ", "),
    "; keeping the last one.",
    call. = FALSE
  )
  df_list <- df_list[!is_superseded]
}

# Keep the per-pair variables (e.g. hsym_nvec_df) available, as before.
list2env(df_list, envir = environment())

df_list_name <- names(df_list)
df_list_name

# bind df
full_df <- bind_rows(df_list)
full_df <- data.table::data.table(full_df) # maybe dt is faster than df for subsetting in later steps

# pairwise comparison of all sp_pair_cluster, the output is already a correlation-like matrix ----
# Each cell is |A intersect B| / |A union B| of the OG sets of two clusters.

sp_pair_cluster_label <- unique(full_df$sp_pair_cluster)
m_dimn <- length(sp_pair_cluster_label)
cat(paste0("The matix has dimention: ",m_dimn," x ",m_dimn))

# Split the OGs by cluster once, instead of filtering full_df four times for
# every cell of the matrix. Indexing by label keeps the order of
# sp_pair_cluster_label; unique() makes each entry a set.
og_by_cluster <- lapply(
  split(full_df$og, full_df$sp_pair_cluster)[sp_pair_cluster_label],
  unique
)

# define matrix
share_og_m <- matrix(0, nrow = m_dimn, ncol = m_dimn)

for (cluster_a_i in seq_len(m_dimn)) {
  print(cluster_a_i)
  og_a <- og_by_cluster[[cluster_a_i]]
  # The measure is symmetric: compute one triangle (plus the diagonal) and
  # mirror it.
  for (cluster_b_i in seq_len(cluster_a_i)) {
    og_b <- og_by_cluster[[cluster_b_i]]
    # number of intersect normalise by the union of the 2 groups
    intersect_num <- sum(og_a %in% og_b)
    union_num <- length(og_a) + length(og_b) - intersect_num
    share_og_m[cluster_a_i, cluster_b_i] <- intersect_num / union_num
    share_og_m[cluster_b_i, cluster_a_i] <- intersect_num / union_num
  }
}

rownames(share_og_m) <- sp_pair_cluster_label
colnames(share_og_m) <- sp_pair_cluster_label

# convert the matrix to correlation matrix
diag(share_og_m) <- NA
share_og_c <- as_cordf(share_og_m)


# process the correlation matrix
# Keep each pair once (drop the upper triangle) in long format: x, y, r.
share_og_c <- share_og_c %>%
  shave(upper = TRUE) %>%
  stretch(na.rm = TRUE)

# Thresholds are fractions of the largest non-zero shared-OG ratio. Work on the
# r column directly: the maximum is computed once and does not depend on which
# package's select() is currently visible.
nonzero_r <- share_og_c$r[share_og_c$r != 0]
if (length(nonzero_r) == 0) {
  stop("No pair of clusters shares any OG; nothing to connect.", call. = FALSE)
}
max_r <- max(nonzero_r)
half_max_r <- max_r / 2
quarter_max_r <- max_r / 4 # alternative, more permissive threshold

r_threshold <- half_max_r
share_og_c_filter <- share_og_c %>% filter(r >= r_threshold)


set.seed(123)
# Undirected graph: nodes are sp_pair_cluster labels, edges are the pairs kept
# in share_og_c_filter.
cor_graph <- as_tbl_graph(share_og_c_filter, directed = FALSE)
# (optional) add node size info
# Pull the node label column by name rather than relying on it being the last
# column.
nodes_on_cor_graph <- cor_graph %>%
  activate(nodes) %>%
  as_tibble() %>%
  pull(name)
# Rows of full_df per cluster, counted once for all nodes instead of growing a
# vector with one filter per node.
nodes_on_cor_graph_size <- as.integer(
  table(full_df$sp_pair_cluster)[nodes_on_cor_graph]
)
cor_graph <- cor_graph %>%
  activate(nodes) %>%
  mutate(node_size = nodes_on_cor_graph_size)




ggraph(cor_graph) +
  geom_edge_link() +
  geom_node_point(aes(size = node_size)) +
  geom_node_text(
    aes(label = name), size = 3, repel = TRUE
  ) +
  theme_graph() +
  labs(title = paste0("corr_r >= ", r_threshold))


# to extract the cluster --------------------------------------------------

# One row per node: `group` is the connected component the node belongs to.
# rename the group (1->g1), don't mix with the original cluster index for each species pair
cluster_group_df <- cor_graph %>%
  activate(nodes) %>%
  mutate(group = group_components()) %>%
  as_tibble() %>%
  mutate(common_group_id = paste0("g", group))


# save the df
# Written next to the input tables through dt_directory instead of a
# hard-coded absolute path in one user's home directory, so the script also
# runs on other machines.
# NOTE(review): this assumes the old absolute path pointed at the same
# directory as dt_directory - confirm.
file_name_tsv <- paste0(dt_directory, "cluster_group_df", ".tsv")
write_tsv(cluster_group_df, file_name_tsv)

# colour palette for common_group_id: one colour per common group, plus a
# fixed entry for everything that is in no common group ("others")
# grDevices::rainbow is used for its high contrast (tho not so pretty :/)
cluster_n <- length(unique(cluster_group_df$common_group_id))
pal <- setNames(
  rainbow(cluster_n, start = 0.1, end = 1),
  unique(cluster_group_df$common_group_id)
)
# noise cluster should be in dull color (grey)
pal <- c(pal, others = "#333333")

# load pair res_dt and plot -----------------------------------------------
# NOTE: from here on all relative paths (res_dt, plot_param_env.R, figures)
# are resolved inside data_jan/. Re-running this section in the same session
# will fail because the working directory has already been changed.
setwd("data_jan/")
if (!exists("output_rv")) {
  # Container for the table and the plot currently being worked on.
  output_rv <- new.env(parent = emptyenv())
}

dt_directory <- "./res_dt/"

# `pattern` is a regular expression, not a glob: anchor it on the extension.
# cluster_group_df.tsv is written by this script and has no `cluster` column,
# so it must not be read back as a species-pair table.
res_dt_files <- setdiff(
  list.files(dt_directory, pattern = "\\.tsv$"),
  "cluster_group_df.tsv"
)
if (length(res_dt_files) == 0) {
  stop("No .tsv file found in ", dt_directory, call. = FALSE)
}

# Read every table with all of its columns and store it in a variable named
# <spA>_<spB>_res_dt.
for (res_dt_file in res_dt_files) {
  # construct df name
  t_ls <- strsplit(res_dt_file, split = "_")[[1]] # [1] "hsym"      "nvec"      "lim4"      "nmin8.tsv"
  tmp_label <- paste0(t_ls[1], "_", t_ls[2])
  tmp_df <- read_tsv(paste0(dt_directory, res_dt_file), col_names = TRUE) %>%
    mutate(
      sp_pair = tmp_label,
      sp_pair_cluster = paste0(sp_pair, "_", cluster)
    )
  assign(paste0(tmp_label, "_res_dt"), tmp_df)
}

rm(tmp_df)
# list the loaded per-pair tables
ls(pattern = "_res_dt$")

# plotplotplotplotplot ----------------------------------------------------
# !!!! check/change the file name here
# Take the chosen pair table and attach the common group id of each
# sp_pair_cluster; clusters that are in no common group become "others".
output_rv$res_dt_group_label <- hech_hmag_res_dt %>%
  mutate(
    common_group_id = cluster_group_df$common_group_id[
      match(sp_pair_cluster, cluster_group_df$name)
    ],
    common_group_id = replace(
      common_group_id, is.na(common_group_id), "others"
    )
  )

# axis titles (species names): the part before the first "|" of the first
# baseNameA / baseNameB value
spa <- strsplit(
  output_rv$res_dt_group_label$baseNameA[1],
  split = "\\|"
)[[1]][1]
spb <- strsplit(
  output_rv$res_dt_group_label$baseNameB[1],
  split = "\\|"
)[[1]][1]


# load plot_param_env from its script unless it is already defined
if (!exists("plot_param_env")) {
  source("plot_param_env.R")
}

#' Draw the dot plot of the table in `output_rv$res_dt_group_label`.
#'
#' Takes no arguments; reads these objects from the enclosing environment:
#' `output_rv$res_dt_group_label` (needs the columns scaffoldA, scaffoldB,
#' baseA, baseB, baseNameA, baseNameB and common_group_id), `spa` / `spb`
#' (axis titles), `plot_param_env$lim` / `plot_param_env$nmin` (plot title)
#' and `pal` (named colours per common_group_id).
#'
#' @return The ggplot object. As a side effect it is also stored in
#'   `output_rv$ggplot`.
ggplot_msynt_wrap <- function() {
  res_dt <- output_rv$res_dt_group_label

  # `text` is not drawn by geom_point; presumably kept for interactive
  # tooltips - confirm before removing.
  ga <- ggplot(
    res_dt,
    aes(
      x = baseA,
      text = paste(
        "scaffoldA:", scaffoldA,
        "scaffoldB:", scaffoldB,
        "baseNameA:", baseNameA,
        "baseNameB:", baseNameB
      ),
      color = as.factor(common_group_id),
      fill = as.factor(common_group_id)
    )
  )

  #### build scaffold axis ####
  # One break per scaffold, at the smallest base position seen on it. The
  # columns are picked by name (the lines below need exactly these names), so
  # this no longer depends on the column order of the input file.
  axisx_df <- res_dt[, c("scaffoldA", "baseA")] %>%
    arrange(baseA) %>%
    distinct(scaffoldA, .keep_all = TRUE)
  axisy_df <- res_dt[, c("scaffoldB", "baseB")] %>%
    arrange(baseB) %>%
    distinct(scaffoldB, .keep_all = TRUE)


  # ggplot ------------------------------------------------------------------
  output_rv$ggplot <- ga +
    geom_point(aes(y = baseB), size = 0.3) +
    scale_x_continuous(breaks = axisx_df$baseA, labels = axisx_df$scaffoldA) +
    scale_y_continuous(breaks = axisy_df$baseB, labels = axisy_df$scaffoldB) +
    scale_fill_manual(name = "cluster", values = pal) +
    scale_color_manual(name = "cluster", values = pal) +
    labs(
      x = spa, y = spb,
      title = paste0("lim=", plot_param_env$lim, ";nmin=", plot_param_env$nmin)
    ) +
    # axis text is hidden: only the tick marks show the scaffold boundaries
    theme(
      panel.background = element_blank(),
      plot.margin = unit(c(1, 1, 3, 3), "line"),
      axis.text = element_blank(),
      axis.ticks.length = unit(.25, "cm"),
      legend.position = "bottom"
    )

  output_rv$ggplot
}

# Build the plot; it is also stored in output_rv$ggplot for saving below.
ggplot_msynt_wrap()


# Output folder for the figures (relative to the current working directory).
if (!dir.exists("figs_common_color")) {
  dir.create("./figs_common_color")
}

# Same file stem for both formats: <spa>_<spb>_lim<lim>_nmin<nmin>
fig_stem <- paste0(
  "./figs_common_color/", spa, "_", spb,
  "_lim", plot_param_env$lim, "_nmin", plot_param_env$nmin
)
file_name_svg <- paste0(fig_stem, ".svg")
cat(file_name_svg, "\n", sep = "")
file_name_png <- paste0(fig_stem, ".png")
cat(file_name_png, "\n", sep = "")

ggsave(filename = file_name_svg, plot = output_rv$ggplot, device = "svg", width = 207, height = 212, units = "mm")
ggsave(filename = file_name_png, plot = output_rv$ggplot, device = "png", width = 207, height = 212, units = "mm")

# Clear the per-pair objects so the next pair starts from a clean container.
rm(res_dt_group_label, envir = output_rv)
rm(ggplot, envir = output_rv)
# Close the RStudio plot device only if there is one: outside RStudio
# dev.list() has no "RStudioGD" entry and dev.off() would fail.
rstudio_dev <- dev.list()["RStudioGD"]
if (length(rstudio_dev) == 1 && !is.na(rstudio_dev)) {
  dev.off(rstudio_dev)
}
