library(tidyverse)
library(circlize)
library(ComplexHeatmap)

# Use the black-and-white ggplot2 theme for every plot created below.
theme_set(theme_bw())

# Graph node table: node id and node type (space-delimited, no header).
nodes <- read_delim(
  "gfa/graph_nodes.csv",
  delim = " ",
  col_names = c("node", "type")
)

# Node ordering: stack the `order` and `node` columns of every file in
# gfa/chroms/ whose name matches "topo".
node_order <- list.files("gfa/chroms/", pattern = "topo", full.names = TRUE) %>%
  lapply(function(topo_path) {
    topo <- read_delim(topo_path, col_names = c("order", "X1", "X2", "node"))
    select(topo, order, node)
  }) %>%
  bind_rows()


# Node / sample pairs from the joined calls.
node_genotypes <- read_tsv(
  "data/calls/joined/samples.nodes.gz",
  col_names = c("node", "Sample")
)

# CpG index: one row per (node, pos).
cpgs <- read_delim(
  "gfa/cpg_index_xaf.csv.gz",
  delim = " ",
  col_names = c("node", "pos")
)

# Node -> gene and node -> bubble annotation tables.
node_annot <- read_delim(
  "node_annotations.csv",
  delim = " ",
  col_names = c("node", "gene")
)
node_bubbles_annot <- read_delim(
  "node_bubble_annotations.csv",
  delim = " ",
  col_names = c("node", "bubble")
)

# Load per-sample CpG methylation scores and attach node types.
#
# Reads every file in `path` whose name matches the regular expression
# `pattern`. Each file is space-delimited without a header and holds the
# columns node, pos, strand, mc, avg_score. The sample name is the part of the
# file name before the first ".".
#
# Args:
#   path:        directory containing the per-sample files.
#   pattern:     regular expression selecting files in `path`.
#   nodes:       table with a `node` and a `type` column.
#   cpgs:        CpG index; only records that also occur in it are kept.
#   average_cpg: if TRUE, average the two strands of each CpG per sample.
#   ref_cpgs:    if FALSE, keep only CpGs on nodes with type "alt".
#
# Returns:
#   An ungrouped tibble with columns node, pos, base, type, avg_score, Sample
#   (plus strand before base when `average_cpg` is FALSE). `base` is the CpG
#   id: node_pos when averaged, node_pos_strand otherwise.
#
# Stops with an error when no file matches `pattern`.
load_cpg <- function(path, pattern, nodes, cpgs, average_cpg = TRUE, ref_cpgs = FALSE) {
  cpg_files <- list.files(path, pattern = pattern, full.names = TRUE)
  if (length(cpg_files) == 0) {
    # Without this check the failure only shows up as an obscure join error.
    stop(
      "No files matching '", pattern, "' found in '", path, "'",
      call. = FALSE
    )
  }

  # Positions are written without scientific notation so that e.g. 100000
  # does not turn into "1e+05" inside the CpG ids.
  pos_label <- function(p) format(p, scientific = FALSE, trim = TRUE)

  cpg <- bind_rows(lapply(
    cpg_files,
    function(x) {
      read_delim(x,
        col_names = c("node", "pos", "strand", "mc", "avg_score"),
        col_types = "cicid",
        delim = " "
      ) %>%
        mutate(
          Sample = str_split(basename(x), "\\.", simplify = TRUE)[, 1],
          # NOTE(review): presumably maps an 8-bit score into (0, 1] - confirm.
          avg_score = (avg_score + 1) / 256
        )
    }
  ))

  # Keep only indexed CpGs, then attach the node type (joins use the columns
  # the tables have in common).
  index_cpgs <- inner_join(cpg, cpgs)
  cpg_alt <- left_join(index_cpgs, nodes)

  if (!ref_cpgs) {
    cpg_alt <- cpg_alt %>%
      filter(type == "alt")
  }

  cpg_alt <- cpg_alt %>%
    arrange(node, pos, strand) %>%
    mutate(base = paste(node, pos_label(pos), strand, sep = "_"))

  if (average_cpg) {
    cpg_alt %>%
      # Shift "+" strand records by one so both strands share a position.
      mutate(pos = case_when(strand == "+" ~ pos + 1, strand == "-" ~ pos)) %>%
      group_by(node, pos, type, Sample) %>%
      # Drop the grouping so callers receive a plain tibble.
      summarise(avg_score = mean(avg_score), .groups = "drop") %>%
      mutate(base = paste(node, pos_label(pos), sep = "_")) %>%
      select(node, pos, base, type, avg_score, Sample)
  } else {
    cpg_alt %>%
      select(node, pos, strand, base, type, avg_score, Sample)
  }
}

# Draw a CpG-by-sample methylation heatmap to a PNG and report what was drawn.
#
# Args:
#   cpg_df:           table with node, pos, base, avg_score and Sample columns
#                     (as returned by load_cpg()).
#   plot_path:        path of the PNG file to write.
#   node_annotations: node -> gene table (columns node, gene).
#   node_order:       node -> order table used to sort the heatmap rows.
#   node_types:       node -> type table used for the left row annotation.
#   plot_title:       column title of the heatmap.
#
# Returns:
#   A list with
#     gene_annotation: one row per CpG with a 0/1 column per annotated gene,
#     sample_order:    sample names in the column order drawn in the heatmap,
#     cpg_order:       CpG ids in the row order of the heatmap.
make_cpg_heatmap <- function(cpg_df, plot_path, node_annotations, node_order, node_types, plot_title) {
  # Wide table: one row per sample, one column per CpG id, columns sorted by
  # node order. Sample/CpG combinations without a score are filled with 0.
  cpg_alt_matrix_df <- left_join(cpg_df, node_order) %>%
    arrange(order, node, pos) %>%
    pivot_wider(
      id_cols = c(Sample), names_from = base,
      values_from = avg_score, values_fill = 0
    )

  # Transpose so that CpGs are rows and samples are columns.
  cpg_alt_matrix <- cpg_alt_matrix_df %>%
    select(-c(Sample)) %>%
    as.matrix() %>%
    t()

  colnames(cpg_alt_matrix) <- cpg_alt_matrix_df$Sample

  png(plot_path, width = 8, height = 8, units = "in", res = 600)
  # Close the device even when an error occurs below.
  png_device <- dev.cur()
  on.exit(dev.off(png_device), add = TRUE)

  node_vec <- str_split(rownames(cpg_alt_matrix), "_", simplify = TRUE)[, 1]

  # One row per CpG, one 0/1 indicator column per gene annotated on its node.
  cpg_gene_annot_matrix <- tibble(cpg = rownames(cpg_alt_matrix)) %>%
    mutate(node = str_extract(cpg, "s[0-9]+")) %>%
    left_join(node_annotations, relationship = "many-to-many") %>%
    pivot_wider(
      id_cols = c(cpg, node),
      names_from = "gene", values_from = "gene",
      values_fn = ~1, values_fill = 0
    )

  cpg_node_type_annot <- tibble(cpg = rownames(cpg_alt_matrix)) %>%
    mutate(node = str_extract(cpg, "s[0-9]+")) %>%
    left_join(node_types)

  lha1 <- rowAnnotation(
    type = as.matrix(cpg_node_type_annot %>% select(type)),
    node = node_vec,
    show_legend = c(TRUE, FALSE),
    col = list(type = c(ref = "grey60", alt = "grey30"))
  )
  # Gene annotation; currently not attached to the heatmap (see below).
  # Only the id columns and the "NA" column produced by unannotated nodes are
  # dropped, so gene names that merely contain "NA", "cpg" or "node" are kept.
  rha1 <- rowAnnotation(
    gene = as.matrix(cpg_gene_annot_matrix %>%
      select(-any_of(c("cpg", "node", "NA")))),
    show_legend = FALSE,
    simple_anno_size = unit(0.2, "in"),
    col = list(gene = colorRamp2(c(0, 1), c("white", "grey40")))
  )

  cpg_hm <- Heatmap(cpg_alt_matrix,
    column_labels = NULL, row_labels = NULL,
    show_column_names = FALSE, show_row_names = FALSE,
    show_row_dend = FALSE, show_column_dend = FALSE,
    name = "mC level",
    column_title = plot_title,
    row_title = "Methylated nucleotide",
    cluster_rows = FALSE,
    left_annotation = lha1
    ## , right_annotation = rha1
  )

  # Columns are clustered when drawn, so the plotted sample order has to be
  # read back from the drawn object rather than from the input matrix.
  drawn_hm <- draw(cpg_hm)
  plotted_cols <- unlist(column_order(drawn_hm), use.names = FALSE)

  list(
    gene_annotation = cpg_gene_annot_matrix,
    sample_order = colnames(cpg_alt_matrix)[plotted_cols],
    cpg_order = rownames(cpg_alt_matrix)
  )
}


# Write the sample and CpG order of a heatmap to TSV files under plots/.
#
# Args:
#   d:            list with `sample_order` and `cpg_order` character vectors
#                 (as returned by make_cpg_heatmap()).
#   node_bubbles: node -> bubble table joined onto the CpG rows.
#   prefix:       file name prefix; output goes to plots/<prefix>_samples.tsv
#                 and plots/<prefix>_cpgs.tsv.
write_heatmap_order <- function(d, node_bubbles, prefix) {
  write_tsv(
    tibble(Sample = d$sample_order),
    paste0("plots/", prefix, "_samples.tsv")
  )
  tibble(cpg = d$cpg_order) %>%
    # Strand-averaged ids are only node_pos; fill = "right" leaves strand as
    # NA for them without emitting a warning on every call.
    separate(cpg,
      sep = "_", into = c("node", "pos", "strand"),
      fill = "right"
    ) %>%
    left_join(node_bubbles) %>%
    write_tsv(paste0("plots/", prefix, "_cpgs.tsv"))
}

# Stand-alone load of the strand-averaged KIR CpGs, reference nodes included.
load_test <- load_cpg(
  "data/bubbles/", "KIR", nodes, cpgs,
  average_cpg = TRUE, ref_cpgs = TRUE
)


# KIR, per-strand scores, reference nodes included.
kir_ref <- load_cpg(
  "data/bubbles/", "KIR", nodes, cpgs,
  average_cpg = FALSE, ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/kir_mc_full_ref.png", node_annot, node_order, nodes, "KIR"
  )

write_heatmap_order(kir_ref, node_bubbles_annot, "KIR_ref")

# KIR, strand-averaged scores, reference nodes included.
kir_avg_ref <- load_cpg(
  "data/bubbles/", "KIR", nodes, cpgs,
  average_cpg = TRUE, ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/kir_mc_avg_ref.png", node_annot, node_order, nodes, "KIR"
  )
write_heatmap_order(kir_avg_ref, node_bubbles_annot, "KIR_avg_ref")

# KIR, per-strand scores, load_cpg() default node filter.
kir_alt <- load_cpg(
  "data/bubbles/", "KIR", nodes, cpgs,
  average_cpg = FALSE
) %>%
  make_cpg_heatmap(
    "plots/kir_alt_mc_full.png", node_annot, node_order, nodes, "KIR"
  )
write_heatmap_order(kir_alt, node_bubbles_annot, "KIR")

# KIR with all load_cpg() defaults.
kir_avg_alt <- load_cpg("data/bubbles/", "KIR", nodes, cpgs) %>%
  make_cpg_heatmap(
    "plots/kir_alt_mc_avg.png", node_annot, node_order, nodes, "KIR"
  )
write_heatmap_order(kir_avg_alt, node_bubbles_annot, "KIR_avg")

# HLA-DRB1-DQB1, per-strand scores, reference nodes included.
drqb1_ref <- make_cpg_heatmap(
  load_cpg("data/bubbles/", "HLA-DRB1-DQB1", nodes, cpgs, average_cpg = FALSE, ref_cpgs = TRUE),
  "plots/hla_drb1_dqb1_mc_full_ref.png", node_annot, node_order, nodes, "HLA-DRB1-DQB1"
)
write_heatmap_order(drqb1_ref, node_bubbles_annot, "HLA-DRB1-DQB1_ref")

# HLA-DRB1-DQB1, strand-averaged scores, reference nodes included.
drqb1_avg_ref <- make_cpg_heatmap(
  load_cpg("data/bubbles/", "HLA-DRB1-DQB1", nodes, cpgs, average_cpg = TRUE, ref_cpgs = TRUE),
  "plots/hla_drb1_dqb1_mc_avg_ref.png", node_annot, node_order, nodes, "HLA-DRB1-DQB1"
)
write_heatmap_order(drqb1_avg_ref, node_bubbles_annot, "HLA-DRB1-DQB1_avg_ref")

# HLA-DRB1-DQB1, per-strand scores, load_cpg() default node filter.
drqb1_alt <- make_cpg_heatmap(
  load_cpg("data/bubbles/", "HLA-DRB1-DQB1", nodes, cpgs, average_cpg = FALSE),
  "plots/hla_drb1_dqb1_alt_mc_full.png", node_annot, node_order, nodes, "HLA-DRB1-DQB1"
)
# Pass the result assigned just above; `drqb1` is never defined in this script.
write_heatmap_order(drqb1_alt, node_bubbles_annot, "HLA-DRB1-DQB1")

# HLA-DRB1-DQB1 with all load_cpg() defaults.
drqb1_avg_alt <- make_cpg_heatmap(
  load_cpg("data/bubbles/", "HLA-DRB1-DQB1", nodes, cpgs),
  "plots/hla_drb1_dqb1_alt_mc_avg.png", node_annot, node_order, nodes, "HLA-DRB1-DQB1"
)
write_heatmap_order(drqb1_avg_alt, node_bubbles_annot, "HLA-DRB1-DQB1_avg")

# HLA-C-B, per-strand scores, reference nodes included.
hlacb_ref <- load_cpg(
  "data/bubbles/", "HLA-C-B.csv", nodes, cpgs,
  average_cpg = FALSE, ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_full_ref.png", node_annot, node_order, nodes, "HLA-C-B"
  )
write_heatmap_order(hlacb_ref, node_bubbles_annot, "HLA-C-B_ref")

# HLA-C-B, default averaging, reference nodes included.
hlacb_avg_ref <- load_cpg(
  "data/bubbles/", "HLA-C-B.csv", nodes, cpgs,
  ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_avg_ref.png", node_annot, node_order, nodes, "HLA-C-B"
  )
write_heatmap_order(hlacb_avg_ref, node_bubbles_annot, "HLA-C-B_avg_ref")


# HLA-C-B, per-strand scores, load_cpg() default node filter.
hlacb_full_alt <- load_cpg(
  "data/bubbles/", "HLA-C-B.csv", nodes, cpgs,
  average_cpg = FALSE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_full_alt.png", node_annot, node_order, nodes, "HLA-C-B"
  )
write_heatmap_order(hlacb_full_alt, node_bubbles_annot, "HLA-C-B")

# HLA-C-B with all load_cpg() defaults.
hlacb_avg_alt <- load_cpg("data/bubbles/", "HLA-C-B.csv", nodes, cpgs) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_avg_alt.png", node_annot, node_order, nodes, "HLA-C-B"
  )
write_heatmap_order(hlacb_avg_alt, node_bubbles_annot, "HLA-C-B_avg")

# Heatmaps for the files matching "HLA-C-B_full". No order files are written
# for these. Note that hlacb_avg_ref, hlacb_full_alt and hlacb_avg_alt are
# reassigned here and replace the values assigned earlier in the script.

# Per-strand scores, reference nodes included.
hlacb_full_ref <- load_cpg(
  "data/bubbles/", "HLA-C-B_full", nodes, cpgs,
  average_cpg = FALSE, ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_wide_full_ref.png", node_annot, node_order, nodes, "HLA-C-B"
  )

# Strand-averaged scores, reference nodes included.
hlacb_avg_ref <- load_cpg(
  "data/bubbles/", "HLA-C-B_full", nodes, cpgs,
  average_cpg = TRUE, ref_cpgs = TRUE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_wide_avg_ref.png", node_annot, node_order, nodes, "HLA-C-B"
  )

# Per-strand scores, "alt" nodes only.
hlacb_full_alt <- load_cpg(
  "data/bubbles/", "HLA-C-B_full", nodes, cpgs,
  average_cpg = FALSE, ref_cpgs = FALSE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_wide_full_alt.png", node_annot, node_order, nodes, "HLA-C-B"
  )

# Strand-averaged scores, "alt" nodes only.
hlacb_avg_alt <- load_cpg(
  "data/bubbles/", "HLA-C-B_full", nodes, cpgs,
  average_cpg = TRUE, ref_cpgs = FALSE
) %>%
  make_cpg_heatmap(
    "plots/hla_c_b_wide_avg_alt.png", node_annot, node_order, nodes, "HLA-C-B"
  )

# Per-sample, per-node CpG counts for the node lists found in data/gafs/.
# Each input file holds one node id per line; the sample name is the file name
# with the ".gaf.gz.nodes" suffix removed.
sample_reads_cpgs <- lapply(
  list.files("data/gafs/", pattern = "*.nodes", full.names = TRUE),
  function(x) {
    snodes <- read_csv(x, col_names = c("node"))
    left_join(snodes, cpgs) %>%
      group_by(node) %>%
      # The left join keeps nodes without any CpG as a single row with
      # pos = NA; count matched rows only so that such nodes contribute 0
      # instead of 1.
      summarise(ncpg = sum(!is.na(pos))) %>%
      mutate(sample = str_remove(basename(x), fixed(".gaf.gz.nodes")))
  }
)

sample_reads_cpgs <- bind_rows(sample_reads_cpgs)
write_tsv(sample_reads_cpgs, "data/mg_ncpg_reads_sample.tsv.gz")
## sample_reads_cpgs <- read_tsv("data/mg_ncpg_reads_sample.tsv.gz")

# Total per sample; the summed per-node counts are halved.
ncpgs_reads_sample <- summarise(
  group_by(sample_reads_cpgs, sample),
  ncpg_sample = sum(ncpg) / 2
)

# Histogram of the per-sample totals, in millions.
ncpgs_sample_reads_p <- ggplot(ncpgs_reads_sample) +
  labs(
    title = "Number of CpGs in GA4K assemblies (reads)",
    x = "Number of CpGs (millions)"
  ) +
  geom_histogram(aes(x = ncpg_sample / 1e6))
ggsave("plots/ncpgs_per_sample_reads.pdf", ncpgs_sample_reads_p)

# Same total restricted to nodes whose type is not "ref".
ncpgs_sample_reads_alt <- sample_reads_cpgs %>%
  left_join(nodes) %>%
  filter(type != "ref") %>%
  group_by(sample) %>%
  summarise(ncpg_sample = sum(ncpg) / 2)

# Histogram of the non-reference totals, in thousands.
ncpgs_sample_reads_alt_p <- ggplot(ncpgs_sample_reads_alt) +
  labs(
    title = "Number of CpGs in GA4K assemblies",
    subtitle = "In non-reference sequences (reads)",
    x = "Number of CpGs (thousands)"
  ) +
  geom_histogram(aes(x = ncpg_sample / 1e3))
ggsave("plots/ncpgs_per_sample_reads_alt.pdf", ncpgs_sample_reads_alt_p)


# Per-sample, per-node CpG counts for the node lists found in data/calls/.
# The table is cached on disk: it is read back when the cache file exists and
# only recomputed (and written) when it does not, instead of being computed
# and then discarded. Delete the cache file to force a recount.
sample_call_cpgs_path <- "data/mg_ncpg_sample.tsv.gz"

if (file.exists(sample_call_cpgs_path)) {
  sample_call_cpgs <- read_tsv(sample_call_cpgs_path)
} else {
  sample_call_cpgs <- lapply(
    list.files("data/calls/", pattern = "*.nodes", full.names = TRUE),
    function(x) {
      snodes <- read_csv(x, col_names = c("node"))
      left_join(snodes, cpgs) %>%
        group_by(node) %>%
        # The left join keeps nodes without any CpG as a single row with
        # pos = NA; count matched rows only so that such nodes contribute 0
        # instead of 1.
        summarise(ncpg = sum(!is.na(pos))) %>%
        mutate(
          sample = str_remove(basename(x), fixed(".fasta.gz.calls.nodes"))
        )
    }
  ) %>%
    bind_rows()
  write_tsv(sample_call_cpgs, sample_call_cpgs_path)
}

# Total per sample; the summed per-node counts are halved.
ncpgs_sample <- summarise(
  group_by(sample_call_cpgs, sample),
  ncpg_sample = sum(ncpg) / 2
)

# Histogram of the per-sample totals, in millions.
ncpgs_sample_p <- ggplot(ncpgs_sample) +
  labs(
    title = "Number of CpGs in GA4K assemblies",
    x = "Number of CpGs (millions)"
  ) +
  geom_histogram(aes(x = ncpg_sample / 1e6))
ggsave("plots/ncpgs_per_sample.pdf", ncpgs_sample_p)

# Same total restricted to nodes whose type is not "ref".
ncpgs_sample_alt <- sample_call_cpgs %>%
  left_join(nodes) %>%
  filter(type != "ref") %>%
  group_by(sample) %>%
  summarise(ncpg_sample = sum(ncpg) / 2)

# Histogram of the non-reference totals, in thousands.
ncpgs_sample_alt_p <- ggplot(ncpgs_sample_alt) +
  labs(
    title = "Number of CpGs in GA4K assemblies",
    subtitle = "In non-reference sequences",
    x = "Number of CpGs (thousands)"
  ) +
  geom_histogram(aes(x = ncpg_sample / 1e3))
ggsave("plots/ncpgs_per_sample_alt.pdf", ncpgs_sample_alt_p)

