#####################################################################################################
############################################ Settings ###############################################
#####################################################################################################

library(tidyverse)
library(magrittr)
library(rtracklayer)
library(ggrastr)
library(GenomicRanges)
library(seqLogo)
library(PWMEnrich)
library(RColorBrewer)
library(UpSetR)
library(gridExtra)
library(ggpubr)
library(pheatmap)
library(data.table)
library(universalmotif)
library(ggseqlogo)
library(nat.utils)
library(Biostrings)
library(BSgenome)
library(BSgenome.Dmelanogaster.UCSC.dm6)
library(DNAshapeR)
library(memes)
library(org.Dm.eg.db)

library(ChIPseeker)
library(topGO)
library(TxDb.Dmelanogaster.UCSC.dm6.ensGene)
# TxDb handle for the dm6 Ensembl gene models (presumably consumed by ChIPseeker-based
# annotation elsewhere in the project — verify)
txdb <- TxDb.Dmelanogaster.UCSC.dm6.ensGene

# Session-wide options: keep strings as character when building data frames, and render
# bitmap devices through cairo (presumably so figures can be produced without an X11 display)
options(stringsAsFactors = F, bitmapType = 'cairo')

#####################################################################################################
############################################ Parameters #############################################
#####################################################################################################

# Nucleotide alphabet and per-base colours used for sequence logos.
# NOTE: `letters` masks base::letters (the lowercase alphabet) for the rest of the session.
letters = c("A", "C", "G", "T")
letter_colors = setNames(c("red", "darkblue", "darkorange", "darkgreen"), letters)
custom_col_scheme = make_col_scheme(chars = names(letter_colors), cols = letter_colors)

# Complement lookup: indexing by a base returns its complement, e.g. rev_compliment["G"] is "C".
# (The name, including its spelling, is kept because other code refers to it.)
rev_compliment = setNames(letters, c("T", "G", "C", "A"))


# Per-condition metadata. A condition ID is "<antibody>/<timepoint>" exactly as it appears in
# file paths; every lookup vector below is named by these condition IDs.
ab_tp_list = c("twi/24", "ctcf/68", "mef2/68", "mef2/1012", "bin/68", "bin/1012")

# transcription factor profiled in each condition
TFs = setNames(c("Twi", "CTCF", "Mef2", "Mef2", "Bin", "Bin"), ab_tp_list)

# inverse view: transcription factor -> its condition IDs
TF2cond = setNames(list("twi/24", "ctcf/68", c("mef2/68", "mef2/1012"), c("bin/68", "bin/1012")),
                   c("Twi", "CTCF", "Mef2", "Bin"))

# developmental time window of each condition
timepoints = setNames(c("2-4h", "6-8h", "6-8h", "10-12h", "6-8h", "10-12h"), ab_tp_list)

# human-readable labels such as "Mef2, 10-12h"
ab_tp_labels = setNames(paste(TFs, timepoints, sep = ", "), ab_tp_list)

# Motif IDs of the known PWMs used in the analysis
#motifs = c("berkeley_bdtnp_twi", "M4670_1.02", "flyfactor_mef2_optimized_furlong", "flyfactor_mef2_optimized_furlong", "flyfactor_bin_optimized_furlong", "flyfactor_bin_optimized_furlong")
motifs_list = setNames(c("Twi_bdtnp", "CTCF_cisbp.M08413.2.00",
                         "Mef2_flyfactorOptimizedFurlong", "Mef2_flyfactorOptimizedFurlong",
                         "Bin_flyfactorOptimizedFurlong", "Bin_flyfactorOptimizedFurlong"),
                       ab_tp_list)

# IDs of the de novo motifs, one per condition
motifs_denovo_list = setNames(c("Twi_denovo", "CTCF_denovo", "Mef2_denovo.68h", "Mef2_denovo.1012h",
                                "Bin_denovo.68h", "Bin_denovo.1012h"),
                              ab_tp_list)

# Bind the intended versions of generics whose names are exported by several attached packages
# (e.g. dplyr vs. Bioconductor / data.table / purrr), so later unqualified calls are unambiguous
select = dplyr::select
rename = dplyr::rename
reduce = GenomicRanges::reduce
union = GenomicRanges::union
shift = GenomicRanges::shift

# general-purpose palettes (cbPalette: presumably colour-blind friendly, as the name suggests)
cols9 = c("#999999", "#F781BF", "#A65628", "#C78080", "#FF7F00", "#984EA3", "#4DAF4A", "#377EB8", "#E41A1C")
cbPalette = c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

# TF color codes: Set1 colours 1-3, Set2 colour 1, Set1 colour 4, Set2 colour 4
# (presumably one colour per entry of ab_tp_list)
c1 = brewer.pal(4, name = "Set1")
c2 = brewer.pal(4, name = "Set2")[c(1, 4)]
TFcols = c(c1, c2)[c(1, 2, 3, 5, 4, 6)]


#####################################################################################################
############################################ General ################################################
#####################################################################################################


load_config = function(config_path = "../../config/config.yml") {
  
  # Read the project YAML configuration (data paths, output directories) into a nested list.
  #
  # config_path: path to the YAML file. The default is relative to the working directory, so it
  #              only resolves when R runs two levels below the directory that contains config/.
  # Returns the parsed YAML as a nested list.
  # Fails with an explicit error when the file is missing, and library() errors when the yaml
  # package is not installed (require() would only warn and fail later with an obscure message).
  
  if (!file.exists(config_path)) {
    stop("Config file doesn't exist: ", normalizePath(config_path, mustWork = FALSE), call. = FALSE)
  }
  
  # attached (not just namespaced) to keep the original side effect of making yaml available
  suppressPackageStartupMessages(library(yaml))
  yaml.load_file(config_path)
  
}

make_recursive_dir = function(new_dir){
  
  # Create `new_dir`, including any missing parent directories, unless it already exists.
  # Called for its side effect; the (invisible) value is dir.create()'s result, or NULL when
  # the directory was already there.
  
  if (dir.exists(new_dir)) {
    return(invisible(NULL))
  }
  dir.create(new_dir, recursive = TRUE)
  
}

get_job_stat = function(jobid) {
  
  # Run SLURM's `seff <jobid>` and capture its standard output as a character vector
  # (one element per line), to be parsed by parse_job_stat().
  # NOTE(review): earlier notes say running/pending jobs yield an empty vector; that depends on
  # seff itself and is not enforced here.
  # NOTE(review): jobid is interpolated into a shell command unquoted — pass trusted IDs only.
  
  seff_cmd = sprintf("seff %s", jobid)
  system(seff_cmd, intern = TRUE)
  
}


parse_job_stat = function(job_stat) {
  
  # Parse `seff` output (as returned by get_job_stat()) into a one-row data.frame.
  #
  # job_stat: character vector, one element per line of seff output.
  # Returns a data.frame with character columns job_id, state, exit_code, nodes, cores_per_node,
  # cpu_utilized, cpu_efficiency, wall_clock_time, memory_utilized, memory_efficiency.
  # A field is NA when its line is absent or does not have the expected format (the unparsed line
  # is no longer returned verbatim). Exit codes of any length are accepted, and a state without
  # an exit code (e.g. "State: RUNNING") yields just the state.
  # Percentages are returned without "%"; memory_utilized is returned without its unit, so values
  # from different jobs may be in different units (KB/MB/GB).
  
  # first line matching `line_pattern`, or NA when there is none
  first_line = function(line_pattern) {
    hits = grep(line_pattern, job_stat, value = TRUE)
    if (length(hits) == 0) NA_character_ else hits[1]
  }
  
  # first capture group of `value_pattern` in `line`; NA when the line is missing or doesn't match
  capture = function(line, value_pattern) {
    if (is.na(line)) return(NA_character_)
    groups = regmatches(line, regexec(value_pattern, line, perl = TRUE))[[1]]
    if (length(groups) < 2) NA_character_ else groups[2]
  }
  
  state_line = first_line("^State")
  
  data.frame(
    job_id            = capture(first_line("^Job ID"), "^Job ID: ([0-9]+)"),
    state             = capture(state_line, "^State: (.*?)(?: \\(exit code [0-9]+\\))?$"),
    exit_code         = capture(state_line, "\\(exit code ([0-9]+)\\)"),
    nodes             = capture(first_line("^Nodes"), "^Nodes: ([0-9]+)"),
    cores_per_node    = capture(first_line("^Cores"), "^Cores.*: ([0-9]+)"),
    cpu_utilized      = capture(first_line("^CPU Utilized"), "^CPU Utilized: (.*)$"),
    cpu_efficiency    = capture(first_line("^CPU Efficiency"), "^CPU Efficiency: (.*)% of"),
    wall_clock_time   = capture(first_line("^Job Wall-clock time"), "^Job Wall-clock time: (.*)$"),
    # unit may be KB/MB/GB/..., optionally followed by "(estimated maximum)"
    memory_utilized   = capture(first_line("^Memory Utilized"), "^Memory Utilized: ([0-9.]+) ?[KMGTPE]?B"),
    memory_efficiency = capture(first_line("^Memory Efficiency"), "^Memory Efficiency: (.*)% of"),
    stringsAsFactors = FALSE
  )
  
}


#####################################################################################################
############################################ Setting up #############################################
#####################################################################################################

# Load the project configuration once, when this file is sourced. The loader functions below look
# for this global `config` and only call load_config() themselves when no `config` object exists.
config = load_config()

# output directories for main and supplementary figures (from config$output_dirs);
# they are created here if they do not exist yet
outdir_fig_main = config$output_dirs$figures_main
outdir_fig_suppl = config$output_dirs$figures_suppl

make_recursive_dir(outdir_fig_main)
make_recursive_dir(outdir_fig_suppl)

#####################################################################################################
############################################ Loading data ###########################################
#####################################################################################################



get_path_consensus_peakset = function(ab_tp, dir = NA, filename = "consensus_peaks_idr0.01_ind3.bed") {
  
  # Build <dir>/<ab_tp>/<filename>, the consensus peak set of one condition, and stop with
  # "File doesn't exist" if it is missing. With dir = NA the base directory is
  # config$data$peaks$consensus_peaksets (the config is loaded if no `config` object is visible).
  
  suppressPackageStartupMessages(require(yaml))
  
  peaks_dir = dir
  if (is.na(peaks_dir)) {
    settings = if (exists("config")) config else load_config()
    peaks_dir = settings$data$peaks$consensus_peaksets
  }
  
  path = file.path(peaks_dir, ab_tp, filename)
  if (file.exists(path)) {
    return(path)
  }
  stop(paste("File doesn't exist:\n", path))
  
}

get_path_cht = function(ab_tp, dir = NA, filename = "cht_results.txt") {
  
  # Return <dir>/<ab_tp>/<filename>, the CHT results file of one condition, or stop with
  # "File doesn't exist" if it is missing. dir = NA means config$data$CHT$CHT_dir
  # (the config is loaded when no `config` object is visible).
  
  suppressPackageStartupMessages(require(yaml))
  
  results_dir = dir
  if (is.na(results_dir)) {
    settings = if (exists("config")) config else load_config()
    results_dir = settings$data$CHT$CHT_dir
  }
  
  path = file.path(results_dir, ab_tp, filename)
  if (!file.exists(path)) {
    stop(paste("File doesn't exist:\n", path))
  }
  path
  
}

load_consensus_peak_set = function(f, filter_chr = T, chr_list = c("chr2L", "chr2R", "chr3L", "chr3R", "chr4")) {
  
  # Import a consensus peak set (BED) as GRanges and annotate every peak with its condition.
  # The condition ID is parsed from the path, which must look like
  # ".../Consensus_peaksets/<ab>/<tp>/consensus_peaks..."; if it doesn't, the lookups give NA.
  # Adds metadata columns ab, tp, condition and peak_id ("<name>_<condition>") and, when
  # filter_chr is TRUE, keeps only peaks on chromosomes in chr_list.
  
  peaks = import(f, format = "bed")
  condition_id = gsub(".*Consensus_peaksets\\/(.*)\\/consensus_peaks.*", "\\1", f)
  
  peaks$ab = TFs[condition_id]
  peaks$tp = timepoints[condition_id]
  peaks$condition = condition_id
  peaks$peak_id = paste(peaks$name, peaks$condition, sep = "_")
  
  if (!filter_chr) {
    return(peaks)
  }
  peaks[seqnames(peaks) %in% chr_list]
}



load_cht_results = function(ab_tp,
                            remove_chr = TRUE, remove_chr_list = c("chrX", "chrY", "chrM"),
                            subset_ids = NA,
                            adjust_pval = TRUE,
                            padj_thres = 0.01,
                            AI_thres = 0.1,
                            ab_list = TFs,
                            tp_list = timepoints,
                            labels = ab_tp_labels,
                            file_name = "cht_results.txt") {
  
  # Load the CHT (combined haplotype test) results of one condition and annotate every test SNP.
  #
  # ab_tp:           condition ID, e.g. "mef2/68" (locates the file and labels the rows)
  # remove_chr:      drop test SNPs on the chromosomes in remove_chr_list
  # subset_ids:      optional character vector of SNP IDs ("<chr>_<pos>") to keep, e.g. to drop
  #                  low-coverage SNPs; NA or NULL keeps all rows
  # adjust_pval:     add BH-adjusted p-values as `padjust`
  # padj_thres, AI_thres: thresholds defining `signif` and `strongAI`
  # ab_list, tp_list, labels: lookups from condition ID to antibody, time point and label
  # file_name:       results file name inside the condition directory
  #
  # Returns a data.frame with the CHT columns plus condition, ab, tp, ab_tp_label, padjust,
  # snp_id, peak_id, peak_summit, dist2summit, signif, AI, AI_abs, strongAI, signif_strongAI,
  # is_indel and indel_length.
  # NOTE: `signif` uses the padjust column, so adjust_pval = FALSE only works if the file
  # already contains one.
  # NOTE(review): BH adjustment runs before chromosomes are removed, i.e. over all tests in the file.
  
  path = get_path_cht(ab_tp, filename = file_name)
  
  df = read.delim2(path)
  # read.delim2 assumes decimal commas, so columns 6-18 may be read as character; coerce them.
  # lapply (not sapply) keeps this working for 0- and 1-row files.
  df[6:18] = lapply(df[6:18], as.numeric)
  
  # add condition labels
  df$condition = ab_tp
  df$ab = ab_list[ab_tp]
  df$tp = tp_list[ab_tp]
  df$ab_tp_label = labels[ab_tp]
  
  if (adjust_pval) {
    df$padjust = p.adjust(df$P.VALUE, method = "BH")
  }
  
  if (remove_chr) {
    df %<>% filter(!TEST.SNP.CHROM %in% remove_chr_list)
  }
  
  # peak summit is taken as REGION.START + 250 (test regions assumed to span summit +/- 250 bp)
  df %<>% mutate(snp_id = paste(TEST.SNP.CHROM, TEST.SNP.POS, sep = "_"),
                 peak_id = paste(TEST.SNP.CHROM, REGION.START + 250, condition, sep = "_"),
                 peak_summit = REGION.START + 250,
                 dist2summit = abs(TEST.SNP.POS - peak_summit),
                 signif = padjust < padj_thres,
                 AI = ALPHA / (ALPHA + BETA),
                 AI_abs = abs(AI - 0.5),
                 strongAI = AI_abs >= AI_thres,
                 signif_strongAI = signif & strongAI,
                 is_indel = nchar(TEST.SNP.REF.ALLELE) > 1 | nchar(TEST.SNP.ALT.ALLELE) > 1,
                 indel_length = pmax(nchar(TEST.SNP.REF.ALLELE), nchar(TEST.SNP.ALT.ALLELE)))
  
  # subset to the provided SNP IDs; the old `if (!is.na(subset_ids))` failed for more than one ID
  if (!is.null(subset_ids) && !all(is.na(subset_ids))) {
    stopifnot(any(subset_ids %in% df$snp_id))
    df %<>% filter(snp_id %in% subset_ids)
  }
  
  df

}



load_dhs = function(path = "config$data$DHS$dm6_liftover_annot"){
  
  # Load the annotated DHS table (data.table via fread) and label each site by its distance to
  # the nearest gene: TSS = "distal" if |gene_nearest_dist| > 500, else "proximal" (NA stays NA).
  # `path` is a string holding an R expression, evaluated inside this function after `config`
  # has been loaded (if no `config` object existed). A literal file path must therefore be
  # passed quoted inside the string.
  
  suppressPackageStartupMessages(require(yaml))
  suppressPackageStartupMessages(require(GenomicRanges))
  suppressPackageStartupMessages(require(rtracklayer))
  
  if (!exists("config")) {
    config = load_config()
  }
  
  dhs_file = eval(parse(text = path))
  dhs = fread(dhs_file)
  
  dist_to_gene = abs(as.numeric(dhs$gene_nearest_dist))
  dhs$TSS = ifelse(dist_to_gene > 500, "distal", "proximal")
  dhs
  
}


load_genes = function(path = "config$data$genome$dm6$gff_genes") {
  
  # Import the dm6 gene models as a GRanges object.
  # `path` is a string holding an R expression (by default a config entry) that is evaluated in
  # this function after `config` has been loaded, if it did not exist yet.
  # The Name and ID metadata columns are replaced by gene_name and gene_id (appended at the end)
  # to agree with other datasets.
  
  suppressPackageStartupMessages(require(yaml))
  suppressPackageStartupMessages(require(rtracklayer))
  suppressPackageStartupMessages(require(GenomicRanges))
  
  if (!exists("config")) {
    config = load_config()
  }
  genes_path = eval(parse(text = path))
  gene_models = import(genes_path)
  
  # copy to the conventional names, then drop the originals
  gene_models$gene_name = gene_models$Name
  gene_models$gene_id = gene_models$ID
  mcols(gene_models)[c("Name", "ID")] = NULL
  
  gene_models
  
}

load_genes_annot = function(path = "config$data$genome$dm6$genes_annot") {
  
  # Read the gene annotation table with fread and convert it to GRanges. The table presumably
  # has columns GRanges() recognises (seqnames/chr, start, end) — confirm against the file.
  # `path` is a string holding an R expression, evaluated after `config` is available.
  # Unlike load_genes(), no metadata columns are renamed.
  
  suppressPackageStartupMessages(require(yaml))
  suppressPackageStartupMessages(require(data.table))
  
  if (!exists("config")) {
    config = load_config()
  }
  annot_file = eval(parse(text = path))
  
  GRanges(fread(annot_file))
  
}



get_gene_promoter_shape = function(path = "config$data$external_data$Promoter_shape") {
  
  # Assign a promoter shape to each gene from CAGE clusters (GFF3; default path from the config).
  # For every gene only the TSS cluster(s) with the highest total expression are kept, and the
  # gene's shape / shape.ind come from that cluster. (Per the original notes, a shape of
  # "minor broad" marks genes with lower-expressed broad promoters; that value comes from the input.)
  # Ties: when several clusters share the maximum expression and at least one of them is not
  # "narrow", the narrow ones are dropped (prefer the broad promoter). This used to be
  # `gene_name == dupl`, which errors when no gene is tied and mis-recycles when several are.
  # `path` is a string holding an R expression, evaluated after `config` is available.
  # Returns a tibble with columns gene_name, shape.ind, shape.
  
  suppressPackageStartupMessages(require(yaml))
  suppressPackageStartupMessages(require(GenomicRanges))
  
  if (!exists("config")) {config = load_config()}
  path = eval(parse(text = path))
  
  CAGE_clusters = import(path, format = "gff3")
  
  data.frame(CAGE_clusters) %>%
    select(gene_name, peakid, shape.ind, shape, totalexpr) %>%
    mutate(totalexpr = as.numeric(totalexpr)) %>%
    group_by(gene_name) %>%
    # most expressed TSS cluster(s) per gene
    filter(totalexpr == max(totalexpr)) %>%
    # tie-break: drop narrow clusters when the gene also has a non-narrow one
    filter(n() == 1 | shape != "narrow" | all(shape == "narrow")) %>%
    ungroup() %>%
    select(gene_name, shape.ind, shape)

}


# Split FIMO's sequence_name ("<name>::<chr>:<start>-<end>") into name / seqnames / reg_start /
# reg_end, keep matches on the standard dm6 chromosomes, and convert the motif start/stop
# (relative to the scanned sequence) into genomic start/end, shifted by coord_shift.
parse_fimo_sequence_coordinates = function(fimo, coord_shift = 0) {
  fimo %>%
    separate(sequence_name, into = c("name", "coord"), sep = "::") %>%
    separate(coord, into = c("chr", "coord"), sep = ":") %>%
    separate(coord, into = c("reg_start", "reg_end"), sep = "-") %>%
    filter(chr %in% c("chr3L", "chr2R", "chrX", "chr3R", "chr2L", "chr4", "chrY")) %>%
    rename(motif_start = start, motif_end = stop, seqnames = chr) %>%
    mutate(start = as.numeric(reg_start) + motif_start - 1 + coord_shift,
           end = as.numeric(reg_start) + motif_end - 1 + coord_shift)
}

load_fimo_motifs = function(ab_tp, motif_ids = motifs_list, pwm = NA, pval_thres = 0.0001, by_peak = TRUE,
                            path1 = "config$data$scanned_motifs$dir", path2 = "/FIMO/combined_motifs/fimo.tsv",
                            TF_specific = TRUE, parse_motif_coordinates = TRUE, coord_shift = 0,
                            subset_by_tf_motif = TRUE) {
  
  # Load FIMO motif matches (fimo.tsv) for one condition and optionally place them in the genome.
  #
  # ab_tp:              condition ID, e.g. "mef2/68"
  # motif_ids, pwm:     motif to keep; pwm = NA means motif_ids[ab_tp]
  # pval_thres:         keep matches with p.value <= pval_thres
  # by_peak:            sequence names are "<chr>_<summit>::<chr>:<start>-<end>"; adds chr_summit,
  #                     summit, center, dist2summit and peak_id
  # path1, path2:       path1 is a string holding an R expression (evaluated after `config` is
  #                     available); the file is path1/ab_tp/path2 (TF_specific) or path1/path2
  # parse_motif_coordinates: convert motif positions to genomic start/end (needs sequence names
  #                     of the form "<name>::<chr>:<start>-<end>")
  # coord_shift:        added to the genomic start/end; per the original notes use 1 when the
  #                     sequences were extracted with bedtools getfasta (coordinates shifted by 1)
  # subset_by_tf_motif: keep only matches of the selected motif
  #
  # NOTE: coord_shift is now applied for by_peak = FALSE as well (it was silently ignored there);
  # with the default coord_shift = 0 results are unchanged.
  
  if (!exists("config")) {config = load_config()}
  scan_dir = eval(parse(text = path1))
  
  path = if (TF_specific) file.path(scan_dir, ab_tp, path2) else file.path(scan_dir, path2)
  fimo = read.delim2(path, sep = "\t")
  
  if (subset_by_tf_motif) {
    if (is.na(pwm)) {
      pwm = motif_ids[ab_tp]
    }
    # select motif matches
    fimo %<>% filter(motif_id == pwm)
  }
  
  # p-value threshold (1e-4 by default); p.value may have been read as character
  fimo %<>% filter(as.numeric(p.value) <= pval_thres)
  
  if (!parse_motif_coordinates) {
    return(fimo)
  }
  
  fimo = parse_fimo_sequence_coordinates(fimo, coord_shift)
  
  if (by_peak) {
    # the sequence name part is "<chr>_<summit>"
    fimo %<>%
      separate(name, into = c("chr_summit", "summit"), sep = "_", remove = FALSE) %>%
      # NOTE(review): (start + end - 1) / 2 lies half a base left of the midpoint of [start, end];
      # kept as-is for backward compatibility
      mutate(center = (start + end - 1) / 2,
             dist2summit = center - as.numeric(summit),
             # NOTE(review): reg_start + 251 here vs REGION.START + 250 in load_cht_results;
             # presumably reg_start is a 0-based start — confirm the peak IDs line up
             peak_id = paste(seqnames, as.numeric(reg_start) + 251, ab_tp, sep = "_"))
  }
  
  fimo
}

# combine motifs called in peaks and in alleles
get_full_motif_sets = function(cht, ab_tp_list) {
  
  # Union of motif matches across all conditions in ab_tp_list, from two sources:
  #  (a) motifs called in the ~30 bp sequences around variants, in either allele (kept only if
  #      inside a peak), and
  #  (b) motifs called in the reference peak sequences (load_fimo_motifs with coord_shift = 1).
  # Returns the distinct rows with columns condition, motif_id, motif_alt_id, seqnames, start,
  # end, peak_id and motif_loc ("<chr>_<start>").
  
  # (a) motifs in alleles
  allele_motifs = bind_rows(lapply(ab_tp_list, parse_motifs_in_two_alleles, cht)) %>%
    filter(in_peak) %>%
    mutate(motif_loc = paste(seqnames, start, sep = "_")) %>%
    select(condition, motif_id, motif_alt_id, seqnames, start, end, peak_id, motif_loc)
  
  # (b) motifs in peaks (reference genome)
  peak_motifs = bind_rows(lapply(ab_tp_list, function(ab_tp) {
      load_fimo_motifs(ab_tp, coord_shift = 1) %>% mutate(condition = ab_tp)
    })) %>%
    select(condition, motif_id, motif_alt_id, seqnames, start, end, peak_id) %>%
    mutate(motif_loc = paste(seqnames, start, sep = "_"))
  
  unique(rbind.data.frame(allele_motifs, peak_motifs))
  
}


# Shared loader behind load_AI_peaks_ATACseq / load_AI_peaks_k4me3 / load_AI_peaks_k27ac.
# Reads feature coordinates (BED; the `name` column is the feature ID) and an allelic-imbalance
# table (columns feature, time, padj, meanprop; presumably one row per heterozygous F1 line and
# time point — confirm). It summarises imbalance per feature and time point, joins both tables by
# feature ID and drops features on chrX/chrY.
# A row counts as imbalanced (AI) when padj < padj_thres and |meanprop - 0.5| > AI_thres.
# Summary columns: n_AI, AI_peak (any AI), mean_AI (|0.5 - mean(meanprop)|), max_AI
# (max |0.5 - meanprop|), n_het_lines (number of rows).
# Returns GRanges (return_granges = TRUE) or the merged data.frame.
load_AI_peaks_multiom = function(regions_file, ai_file, return_granges = TRUE,
                                 data_dir = "/g/furlong/project/68_F1_cisreg_ichip/data/F1_paper_multiom",
                                 padj_thres = 0.01, AI_thres = 0.1) {
  
  regions = import(file.path(data_dir, regions_file), format = "bed")
  ai = read.delim(file.path(data_dir, ai_file))
  
  ai_sum = ai %>%
    mutate(AI = padj < padj_thres & abs(0.5 - meanprop) > AI_thres) %>%
    group_by(feature, time) %>%
    summarize(n_AI = sum(AI),
              AI_peak = any(AI),
              mean_AI = abs(0.5 - mean(meanprop)),
              max_AI = max(abs(0.5 - meanprop)),
              n_het_lines = n(),
              .groups = "drop")
  
  # combine AI summary with peak coordinates
  merged = merge(data.frame(regions), ai_sum, by.x = "name", by.y = "feature") %>%
    filter(!seqnames %in% c("chrX", "chrY"))
  
  if (return_granges) GRanges(merged) else merged
}

load_AI_peaks_ATACseq = function(return_granges = TRUE) {
  
  # Allelic imbalance of ATAC-seq peaks (dm6 coordinates, per the file name).
  load_AI_peaks_multiom("ATAC_feature_location_dm6.bed", "ATAC_all_peaks_atac_X.txt", return_granges)
  
}

load_AI_peaks_k4me3 = function(return_granges = TRUE) {
  
  # Allelic imbalance of H3K4me3 peaks (dm6 coordinates, per the file name).
  load_AI_peaks_multiom("K4me3_feature_location_dm6.bed", "K4_all_peaks_k4_X.txt", return_granges)
  
}

load_AI_peaks_k27ac = function(return_granges = TRUE) {
  
  # Allelic imbalance of H3K27ac peaks.
  # NOTE(review): unlike ATAC/K4me3 the coordinate file is the dm3 version (..._dm3.bed), while the
  # rest of the analysis uses dm6 — confirm whether these regions need a liftover before use.
  load_AI_peaks_multiom("K27ac_feature_location_dm3.bed", "K27_all_peaks_k27_X.txt", return_granges)
  
}


#####################################################################################################
############################################ peaks ##################################################
#####################################################################################################

construct_combined_consensus_set = function(peaks_grlist) {
  
  # Merge the peak sets of several conditions into one set of regions and record, per region,
  # which sets have a peak there.
  #
  # peaks_grlist: named GRangesList (or list of GRanges), one consensus peak set per condition.
  # Returns a data.frame of the union regions with one 0/1 column per element of peaks_grlist
  # (named after it): 1 when that set has a peak overlapping the region.
  
  # union of all ChIP peaks; `union` is GenomicRanges::union as bound at the top of this file
  combined_peaks = Reduce(union, peaks_grlist)
  
  # 0/1 presence of each individual set in every combined region
  # (seq_along, unlike 1:length(), stays empty when the union is empty)
  presence = lapply(peaks_grlist, function(gr) {
    hits = findOverlaps(gr, combined_peaks)
    as.numeric(seq_along(combined_peaks) %in% subjectHits(hits))
  })
  
  elementMetadata(combined_peaks) = presence
  names(elementMetadata(combined_peaks)) = names(peaks_grlist)
  
  as.data.frame(combined_peaks)
  
}



get_peaks_from_cht = function(ab_tp, cht, as_granges = T) {
  
  # Rebuild the tested peak regions of one condition from CHT results: one row per peak
  # (peak_summit +/- 250 bp) with AI_peak = TRUE when any of its variants has signif_strongAI.
  # Returns GRanges (as_granges = TRUE) or the summarised tibble, which is still grouped by
  # peak_id, condition, chr and start, as summarise() leaves it.
  
  peaks_sum = cht %>%
    filter(condition == ab_tp) %>%
    mutate(chr = TEST.SNP.CHROM,
           start = peak_summit - 250,
           end = peak_summit + 250) %>%
    group_by(peak_id, condition, chr, start, end) %>%
    summarise(AI_peak = any(signif_strongAI))
  
  if (!as_granges) {
    return(peaks_sum)
  }
  GRanges(peaks_sum)
  
}

# Assign to each peak the AI_abs of its variant with the lowest p-value
# (get_AIabs_per_peak — placeholder; no implementation here)

#####################################################################################################
############################################ CHT ####################################################
#####################################################################################################

get_counts_per_line = function(ab_tp, path = "/g//furlong/project/68_F1_cisreg_ichip/data/ChIPseq/AS_counts/with_indels/2_target_regions/by_peaks_width5000/merged_input_idr0.01_ind3/") {
  
  # Read the per-sample haplotype read-count tables of one condition from
  # <path>/<ab_tp>/adjust_hap_read_counts_hetp_adj/haplotype_read_counts.<sample>.txt.gz.
  # Prints each sample ID while reading. Returns a list with one data.frame per file: the
  # columns come from the file's space-separated header line, plus a sample_id column.
  
  counts_dir = file.path(path, ab_tp, "adjust_hap_read_counts_hetp_adj")
  count_files = file.path(counts_dir, list.files(counts_dir, pattern = "^haplotype_read_counts"))
  
  read_sample = function(f) {
    sample = gsub("haplotype_read_counts.(.+).txt.gz", "\\1", basename(f))
    print(sample)
    # the header is space-separated while the body is tab-separated, so read them separately
    col_names = strsplit(readLines(f, 1), split = " ")[[1]]
    counts = read.delim2(f, sep = "\t", header = FALSE, skip = 1)
    names(counts) = col_names
    counts %>% mutate(sample_id = sample)
  }
  
  lapply(count_files, read_sample)
  
}


get_top_variants = function(cht, ab_tp = NA, pthres = 0.01, filter = T, AIfilt = F) {
  
  # Flag, per peak, the variant(s) with the smallest CHT p-value and optionally keep only them.
  # ab_tp:  restrict to one condition first (NA = all conditions)
  # filter: keep only the top variant(s); ties in P.VALUE are all kept
  # AIfilt: with filter, additionally keep only the tied top variant(s) with the largest AI_abs
  # pthres: currently unused
  # Adds min_pval and top_var (plus max_ai when AIfilt applies); the result is ungrouped.
  
  if (!is.na(ab_tp)) {
    cht = cht %>% filter(condition == ab_tp)
  }
  
  cht = cht %>%
    group_by(peak_id) %>%
    mutate(min_pval = min(P.VALUE), top_var = P.VALUE == min_pval) %>%
    ungroup()
  
  if (!filter) {
    return(cht)
  }
  
  cht = cht %>% filter(P.VALUE == min_pval)
  if (AIfilt) {
    cht = cht %>%
      group_by(peak_id) %>%
      mutate(max_ai = max(AI_abs)) %>%
      filter(AI_abs == max_ai) %>%
      ungroup()
  }
  cht
}
    
    
  


get_peaks_with_variants_ids = function(cht) {
  
  # Distinct (peak_id, condition) pairs that have at least one tested variant in `cht`.
  
  unique(select(cht, peak_id, condition))
  
}



get_number_of_variants_per_peak = function(cht) {
  
  # Number of tested variants (n_var) per (peak_id, condition); the result stays grouped by
  # peak_id, as summarize() leaves it.
  
  by_peak = group_by(cht, peak_id, condition)
  summarize(by_peak, n_var = n())
  
}




get_AI_peak_ids = function(cht, padj_thres = 0.01, AI_thres = 0.1, reverse = F) {
  
  # Distinct (peak_id, condition) pairs of peaks WITH allelic imbalance (at least one variant
  # with signif_strongAI) or, with reverse = TRUE, of tested peaks without any such variant.
  # NOTE: padj_thres and AI_thres are ignored — AI calls come from the precomputed
  # signif_strongAI column, so thresholds must be chosen when the CHT results are loaded.
  
  if (reverse) {
    peaks_any_ai = cht %>%
      group_by(peak_id, condition) %>%
      mutate(signif_strongAI = any(signif_strongAI)) %>%
      ungroup()
    return(peaks_any_ai %>% filter(!signif_strongAI) %>% select(peak_id, condition) %>% unique())
  }
  
  cht %>%
    filter(signif_strongAI) %>%
    select(peak_id, condition) %>%
    unique()
}

get_consensus_peaksets_with_AI = function(ab_tp_list, cht, padj_thres = 0.01, AI_thres = 0.1, labels = ab_tp_labels, filter = F, rev = F) {
  
  # Load the consensus peak set of every condition and flag peaks by allelic imbalance.
  # Returns a named list of GRanges (one per condition) with a logical isAI column.
  # rev = TRUE flags tested peaks WITHOUT an AI variant instead (see get_AI_peak_ids);
  # filter = TRUE keeps only the flagged peaks. `labels` is unused.
  
  AI_peak_ids = get_AI_peak_ids(cht, padj_thres, AI_thres, reverse = rev)
  
  peak_files = lapply(ab_tp_list, get_path_consensus_peakset)
  peaks = GRangesList(lapply(peak_files, load_consensus_peak_set))
  names(peaks) = ab_tp_list
  
  flag_ai = function(gr) {
    gr$isAI = gr$peak_id %in% AI_peak_ids$peak_id
    gr
  }
  peaks = lapply(peaks, flag_ai)
  
  if (filter) {
    peaks = lapply(peaks, function(gr) gr[gr$isAI])
  }
  
  peaks
  
}


select_top_cht_variants = function(ab_tp, cht){
  
  # For one condition, keep per peak the variant(s) with the lowest adjusted p-value and, among
  # those, the largest AI_abs (ties kept). Non-significant peaks are included too.
  # Adds the helper columns min_padj and max_ai; the result is ungrouped.
  
  per_peak = cht %>%
    filter(condition == ab_tp) %>%
    group_by(peak_id)
  
  # step 1: most significant variant(s) per peak
  best_p = per_peak %>%
    mutate(min_padj = min(padjust)) %>%
    filter(padjust == min_padj)
  
  # step 2: among those, the most imbalanced
  best_p %>%
    mutate(max_ai = max(AI_abs)) %>%
    filter(AI_abs == max_ai) %>%
    ungroup()
  
}

get_as_counts = function(dat) {
  
  # Sum allele-specific read counts over the SNPs of each region and compute the allelic ratio.
  #
  # dat: data.frame with character columns REGION.SNP.REF.HAP.COUNT and REGION.SNP.ALT.HAP.COUNT,
  #      each holding ";"-separated per-SNP counts.
  # Returns `dat` with added numeric columns ref_counts, alt_counts, total_as_count and
  # allele_ratio = ref / (ref + alt), which is NA when the region has no allele-specific reads.
  # vapply keeps the count columns numeric for 0-row input; the old unlist(lapply(...)) returned
  # NULL there, dropping the columns and then failing.
  
  sum_counts = function(x) {
    vapply(strsplit(x, split = ";"), function(v) sum(as.numeric(v)), numeric(1))
  }
  
  dat$ref_counts = sum_counts(dat$REGION.SNP.REF.HAP.COUNT)
  dat$alt_counts = sum_counts(dat$REGION.SNP.ALT.HAP.COUNT)
  dat$total_as_count = dat$ref_counts + dat$alt_counts
  
  ratio = dat$ref_counts / dat$total_as_count
  ratio[which(dat$total_as_count == 0)] = NA
  dat$allele_ratio = ratio
  
  dat
  
}

get_qq_values = function(cht, data_label = "CHT", min.p = 1e-20) {
  
  # Coordinates for a QQ plot of observed CHT p-values against uniform expected p-values.
  #
  # cht:        data.frame with a P.VALUE column (one row per test)
  # data_label: stored as `label` in the result, e.g. to distinguish datasets in one plot
  # min.p:      p-values below this are capped at min.p for drawing purposes
  # Returns the list from qqplot(plot.it = FALSE): x = sorted -log10(expected p),
  # y = sorted -log10(observed p), plus `label`.
  
  n_tests = nrow(cht)
  # expected p-values i/n; seq_len() gives an empty vector (not c(1, 0)) when there are no tests
  expected_p = pmax(seq_len(n_tests) / n_tests, min.p)
  observed_p = pmax(cht$P.VALUE, min.p)
  
  qq = qqplot(-log10(expected_p), -log10(observed_p), plot.it = FALSE)
  qq$label = data_label
  qq
  
}




plot_ai_and_read_depth_for_variant = function(cht, ll, id, peak = NA) {
  
  # Plot allele-specific read shares and region read depth for one test variant.
  #
  # cht:  CHT results containing the variant (columns snp_id, peak_id, AI_abs, TEST.SNP.CHROM, TEST.SNP.POS)
  # ll:   list of per-sample haplotype count tables (presumably from get_counts_per_line(); needs
  #       CHROM, TEST.SNP.POS, REGION.START, REGION.READ.COUNT, GENOMEWIDE.READ.COUNT, sample_id, ...)
  # id:   variant ID "<chr>_<pos>" (snp_id)
  # peak: optional peak_id, to pick one peak when the variant was tested in several
  # Returns list(p2, p1, df_f1): p2 = boxplot of relative region counts per TEST.SNP.HAPLOTYPE,
  # p1 = stacked bars of ref/alt allele shares per sample, df_f1 = the per-sample count table
  # with allele counts added by get_as_counts().
  
  snp = cht %>% filter(snp_id == id)
  if(!is.na(peak)) {
    snp %<>% filter(peak_id == peak)
    #peak_start = as.numeric(strsplit(peak, "_")[[1]][2]) - 250
  }
  # NOTE(review): if cht is a tibble this is a 1-column tibble rather than a number; it is pasted into the title below
  ai = round(snp[snp$snp_id == id, "AI_abs"], 2)
  # NOTE(review): assumes `snp` has exactly one row; with several (variant in several peaks and peak = NA) the == comparisons recycle
  ll_snp = lapply(ll, function(df) df %>% filter(CHROM == snp$TEST.SNP.CHROM & TEST.SNP.POS == snp$TEST.SNP.POS))
  
  df = Reduce(rbind.data.frame, ll_snp) %>% unique()
  if(!is.na(peak)) {
    # peak_id is "<chr>_<summit>_<condition>"; the tested region starts 250 bp before the summit
    peak_start = as.numeric(strsplit(peak, "_")[[1]][2]) - 250
    df %<>% filter(REGION.START == peak_start)
  }
  df_f1 = get_as_counts(df)
  
  # long format: one row per sample and allele, with that allele's share of the sample's allele-specific reads
  tmp = df_f1 %>% 
    gather(allele, counts, ref_counts:alt_counts) %>%
    group_by(sample_id) %>%
    mutate(N = sum(counts),
           allele_share = counts / N) %>%
    na.omit()
  
  # fill levels sort alphabetically (alt_counts, ref_counts): alt is drawn red and labelled "1", ref dark blue and labelled "0"
  p1 = ggplot(tmp, aes(fill = allele, y = allele_share, x = sample_id)) +
    geom_bar(stat = "identity") +
    theme_bw() +
    scale_fill_manual(name = "Allele", values = c("red", "darkblue"), labels = c("1", "0")) +
    #scale_fill_manual(name = "Allele", values = c("#4DAF4A", "#E41A1C"), labels = c("1", "0")) +
    ylab("Share of allele counts") +
    xlab("Heterozygous lines") +
    geom_hline(yintercept = 0.5, color = "darkgrey", size = 2) +
    #ggtitle(paste0("Allelic imbalance, SNP #", snp_rank)) +
    theme(axis.text.y = element_text(size=12), axis.text.x = element_text(size=10, angle = 45, hjust = 1),
          axis.title.y = element_text(size=12), axis.title.x = element_text(size=12),
          legend.text=element_text(size=12), legend.title=element_text(size=12),
          strip.text.x = element_text(size = 12), strip.text.y = element_text(size = 12),
          plot.title = element_text(color="black", face="bold", size=12, hjust=0.5))
  
  # rel_count: region read count as a percentage of the sample's genome-wide read count;
  # line_id: sample_id without a trailing "_1"/"_2"
  df %<>% mutate(rel_count = as.numeric(REGION.READ.COUNT) * 100 / as.numeric(GENOMEWIDE.READ.COUNT),
                 line_id = gsub("_[1-2]$", "", sample_id))
  
  p2 = ggplot(df, aes(x = TEST.SNP.HAPLOTYPE,
                      y = rel_count,
                      #y = REGION.READ.COUNT,
                      color = TEST.SNP.HAPLOTYPE)) +
    geom_boxplot(width = 0.8, alpha = 0.6) +
    geom_jitter(width = 0.25, size = 2) +
    theme_bw() +
    scale_color_manual(name = "Genotype", values = c("darkblue", "darkorchid1", "red")) +
    #scale_color_brewer(palette = "Set1", name = "Haplotype") +
    xlab("Haplotype (all lines)") +
    #  ylab("Total counts (RPM)") +
    ylab("Relative region counts") +
    ggtitle(paste0(id, " , AI=", ai)) +
    theme(axis.text.y = element_text(size=12), axis.text.x = element_text(size=12),
          axis.title.y = element_text(size=12), axis.title.x = element_text(size=12),
          legend.text=element_text(size=12), legend.title=element_text(size=12),
          strip.text.x = element_text(size = 12), strip.text.y = element_text(size = 12),
          plot.title = element_text(color="black", face="bold", size=12, hjust=0.5))
  
  list(p2, p1, df_f1)
  
  
}

#####################################################################################################
######################################## motifs #####################################################
#####################################################################################################

annotate_peaks_by_motif_presence = function(ab_tp, peaks = AI_peaks, pwm_id = NA, p.thres = 0.0001) {
  
  # Return the peaks of one condition as a data.frame with a logical with_motif column: TRUE when
  # FIMO found a match of pwm_id (NA = the condition's default motif) with p <= p.thres in the peak.
  # `peaks` is a list of GRanges keyed by condition; its default, AI_peaks, must exist globally.
  # Peaks are matched to FIMO hits through their `name`.
  
  peakset = data.frame(peaks[[ab_tp]])
  
  motif_hits = load_fimo_motifs(ab_tp, pwm = pwm_id, pval_thres = p.thres)
  names_with_motif = unique(motif_hits$name[motif_hits$name %in% peakset$name])
  
  peakset$with_motif = FALSE
  peakset$with_motif[peakset$name %in% names_with_motif] = TRUE
  
  peakset
}


plot_motif_central_enrichment = function(ab_tp, pwm = NA, shift = 1) {
  
  # Density plot of motif-to-summit distances for one condition. For each peak only the FIMO
  # match(es) closest to the summit are used.
  #
  # pwm:   motif ID to plot; NA = the condition's default motif (motifs_list). Previously this
  #        argument was accepted but never passed on, so the default motif was always used.
  # shift: passed to load_fimo_motifs() as coord_shift.
  # Returns a ggplot object.
  
  fimo = load_fimo_motifs(ab_tp, pwm = pwm, coord_shift = shift)
  
  # select the motif closest to the summit per peak (ties kept)
  closest = fimo %>%
    group_by(name) %>%
    filter(abs(dist2summit) == min(abs(dist2summit))) %>%
    ungroup()
  
  ggplot(closest, aes(x = dist2summit)) + geom_density(color = cbPalette[2], size = 2) + theme_bw() +
    xlab("Distance to peak summit ") + ylab("Density") +
    theme(axis.text.y = element_text(size=12), axis.text.x = element_text(size=12), 
          axis.title.x = element_text(size=12), axis.title.y = element_text(size=12),
          plot.title = element_text(size= 12, hjust = 0.5), legend.position = "none")
  

}


# annotate CHT table by motif overlaps
get_variants_in_motifs = function(ab_tp, cht, flank = 0, pthres = 0.0001) {
  
  # Flag CHT test variants of one condition that fall inside a motif match (FIMO p <= pthres),
  # with every motif extended by `flank` bp on both sides.
  # Returns the condition's rows of `cht` with an added logical in_motif column.
  
  cht_sel = cht %>% filter(condition == ab_tp)
  
  motif_ranges = load_fimo_motifs(ab_tp, pval_thres = pthres) %>%
    mutate(start = start - flank, end = end + flank) %>%
    select(seqnames, start, end) %>%
    GRanges()
  variant_ranges = cht_sel %>%
    mutate(seqnames = TEST.SNP.CHROM, start = TEST.SNP.POS, end = TEST.SNP.POS) %>%
    GRanges()
  
  variant_rows_in_motif = unique(queryHits(findOverlaps(variant_ranges, motif_ranges)))
  
  cht_sel["in_motif"] = FALSE
  cht_sel[variant_rows_in_motif, "in_motif"] = TRUE
  cht_sel
  
}


get_share_of_peaks_with_disrupted_motifs = function(ab_tp, cht, padj = 0.0001, motif_flank = 0){
  
  # Compare how often the top variant of AI vs. non-AI peaks lies in a motif, for one condition.
  # Uses the top variant(s) per peak (select_top_cht_variants) and flags those inside motif matches
  # (despite its name, `padj` is the FIMO p-value threshold; motifs extended by motif_flank bp).
  # Returns a one-row data.frame with peak counts for AI/non-AI x in-motif/not-in-motif and the
  # percentage of AI peaks (ratio_AIpeaks) and of non-AI peaks (ratio_nonAIpeaks) in a motif.
  
  top_vars = select_top_cht_variants(ab_tp, cht)
  top_vars = get_variants_in_motifs(ab_tp, top_vars, pthres = padj, flank = motif_flank)
  
  # number of distinct peaks among rows where `keep` is TRUE (NA rows are skipped)
  count_peaks = function(keep) length(unique(top_vars$peak_id[which(keep)]))
  
  is_ai = top_vars$signif_strongAI
  in_motif = top_vars$in_motif
  
  n_ai_motif = count_peaks(is_ai & in_motif)
  n_ai_nomotif = count_peaks(is_ai & !in_motif)
  n_nonai_motif = count_peaks(!is_ai & in_motif)
  n_nonai_nomotif = count_peaks(!is_ai & !in_motif)
  
  data.frame(condition = ab_tp,
             n_sign_motif = n_ai_motif,
             n_signif_nomotif = n_ai_nomotif,
             n_nonsignif_motif = n_nonai_motif,
             n_nonsignif_nomotif = n_nonai_nomotif,
             ratio_AIpeaks = round(n_ai_motif / (n_ai_motif + n_ai_nomotif) * 100, 2),
             ratio_nonAIpeaks = round(n_nonai_motif / (n_nonai_nomotif + n_nonai_motif) * 100, 2))
  
}


prepare_snps_in_motif_2alleles_4plotting = function(ab_tp, cht, remove_indels = T, right_flank = 0, left_flank = 0) {
  
  # Collect variants with signif_strongAI that lie in motifs called in either allele inside peaks,
  # add var_pos (1-based position of the variant relative to the motif start) and pass them with
  # the motif length and the complement lookup to aggregate_strands() for plotting.
  # right_flank and left_flank are currently unused.
  
  ai_variants = parse_motifs_in_two_alleles(ab_tp, cht) %>%
    filter(in_peak & signif_strongAI) %>%
    mutate(var_pos = variant_start - start + 1)
  
  motif_len = get_motif_length(ab_tp)
  
  aggregate_strands(ai_variants, motif_len, rev_compliment, remove_indels)
  
}



# Join motif hits with the variants that fall inside them, one row per (hit, variant) overlap.
#
# variants_df    - CHT variant table; needs TEST.SNP.CHROM, TEST.SNP.POS and the columns kept below
# motifs_df      - FIMO hits with seqnames, start, end, strand, score, matched_sequence
# rflank, lflank - bp added to `end` / subtracted from `start` of each hit (genomic coordinates,
#                  not strand-aware)
#
# Returns the (flank-extended) motif columns next to the variant columns, with TEST.SNP.POS
# renamed to start_variant and `var_pos` = 1-based offset of the variant from the hit start.
get_variants_in_motifs_info = function(variants_df, motifs_df, rflank = 0, lflank = 0) {

  keep_motif_cols = c( "seqnames", "start", "end", "strand", "score", "matched_sequence")
  keep_variant_cols = c("snp_id", "TEST.SNP.POS", "TEST.SNP.REF.ALLELE", "TEST.SNP.ALT.ALLELE", "AI", "AI_abs", "ALPHA", "BETA", "padjust", "is_indel", "indel_length", "peak_id", "condition", "ab", "tp")

  extended_motifs = motifs_df %>% mutate(start = start - lflank, end = end + rflank)
  variant_ranges = variants_df %>% mutate(chr = TEST.SNP.CHROM, start = TEST.SNP.POS, end = TEST.SNP.POS)

  hits = findOverlaps(GRanges(extended_motifs %>% select(seqnames, start, end)), GRanges(variant_ranges))

  matched_motifs = extended_motifs[queryHits(hits), keep_motif_cols]
  matched_variants = variants_df[subjectHits(hits), keep_variant_cols] %>%
    rename(start_variant = "TEST.SNP.POS")

  cbind.data.frame(matched_motifs, matched_variants) %>%
    mutate(var_pos = start_variant - start + 1)

}



# Information-content (bits) logo of a motif PWM read from "<pwm dir>/<motif_name>.txt".
#
# motif_name   - file stem of the PWM file
# type         - entry of the evaluated `path` object to use (e.g. "known")
# path         - R expression, as text, that evaluates to the PWM directories; evaluated here so a
#                locally loaded `config` is visible to it
# rev_comp     - plot the reverse complement (rows and positions reversed)
# x_axis       - accepted for interface compatibility; not used by this plot
# letters_fill - named colours per nucleotide (A, C, G, T) used for the letters
#
# Returns a ggplot object.
get_motif_pfm_logo = function(motif_name, type = "known", path = "config$data$motif_databases$individual_pwms", rev_comp = FALSE, x_axis = FALSE, letters_fill = letter_colors){
  
  # type = c("known_pwms", "denovo_pwms")
  
  if(!exists("config")){config = load_config()}
  path = eval(parse(text = path))[type]
  path_full = file.path(path, paste0(motif_name, ".txt"))
  
  # file: one row per motif position -> transpose to 4 x L with rows A, C, G, T
  m = read.table(path_full)
  m = t(m)
  rownames(m) = letters
  
  if(rev_comp){
    # reversing rows (A,C,G,T -> T,G,C,A) and positions gives the reverse complement;
    # drop = FALSE keeps a matrix for single-position motifs
    m = m[rev(seq_len(nrow(m))), rev(seq_len(ncol(m))), drop = FALSE]
    rownames(m) = letters
  }
  
  # use the caller-supplied colours (previously the global letter_colors was always used)
  view_motifs(m, use.type = "ICM", colour.scheme = letters_fill, sort.positions = T) +
    theme_classic() +
    scale_x_continuous(breaks = 1:ncol(m), labels = 1:ncol(m), name = "Position") +
    ylab("Bits") +
    ggtitle(motif_name) +
    theme(axis.text.y = element_text(size=12), axis.text.x = element_text(size=12),
        axis.title.y = element_text(size=12), axis.title.x = element_text(size=12),
        legend.position="none",
        plot.title = element_text(color="black", face="bold", size=16, hjust=0.5))
  
}


# Draw a seqLogo of a motif PWM read from "<pwm dir>/<motif_name>.txt".
#
# motif_name   - file stem of the PWM file
# type         - entry of the evaluated `path` object to use (e.g. "known")
# path         - R expression, as text, resolving to the PWM directories (may refer to `config`)
# rev_comp     - draw the reverse complement (rows and positions reversed)
# x_axis       - draw the x axis
# letters_fill - colours for the four nucleotides
get_motif_logo = function(motif_name, type = "known", path = "config$data$motif_databases$individual_pwms", rev_comp = FALSE, x_axis = FALSE, letters_fill = letter_colors){

  # make sure `config` is available when the path expression is evaluated
  if(!exists("config")){config = load_config()}
  pwm_dir = eval(parse(text = path))[type]

  # one row per motif position in the file; transpose to nucleotides x positions
  pwm_mat = t(read.table(file.path(pwm_dir, paste0(motif_name, ".txt"))))

  if(rev_comp){
    row_order = rev(1:nrow(pwm_mat))
    col_order = rev(1:ncol(pwm_mat))
    pwm_mat = pwm_mat[row_order, col_order]
  }

  seqLogo(makePWM(pwm_mat), yaxis = FALSE, xaxis = x_axis, fill = letters_fill)
}


# Number of positions (rows of the PWM file) of the motif of the TF assayed in `ab_tp`.
# The file is "<dir>/<TF name>.txt", where <dir> is entry `type` of the object obtained by
# evaluating the R expression in `path`, and the TF name is TFs[ab_tp].
get_motif_length = function(ab_tp, type = "known", path = "config$data$motif_databases$individual_pwms"){

  if(!exists("config")){config = load_config()}
  tf_name = TFs[ab_tp]

  pwm_dir = eval(parse(text = path))[type]
  pwm_table = read.table(file.path(pwm_dir, paste0(tf_name, ".txt")))

  nrow(pwm_table)
}


# Sequence logo of the alleles favoured by allelic imbalance at each motif position.
#
# ab_tp  - condition id; selects the TF label (TFs[ab_tp]) and the motif length
# df     - strand-aggregated allele table (e.g. from aggregate_strands()) with columns
#          var_pos, allele, rel_affinity
# x_axis - draw the x axis on the logo
# return_total_count - TRUE: return per-position variant counts (position, var_count, TF);
#          FALSE: return per-position, per-nucleotide counts (var_pos, allele, count)
#
# Only alleles with rel_affinity > 0 (the allele with the larger read share) are counted.
# Positions 1..motif length without any variant get a uniform 0.25 column in the logo and a
# count of 0 in the per-position table.
plot_ai_pfm = function(ab_tp, df, x_axis = T, return_total_count = T) {
  
  nucleotides = c("A", "C", "G", "T")
  
  # select alleles with positive AI
  df %<>% filter(rel_affinity > 0)
  print(paste0("Number of variants to make PFM: ", nrow(df)))
  
  # count matrix: one row per motif position, one column per allele
  df %<>% group_by(var_pos, allele) %>% summarize(n_var = n()) %>% spread(allele, n_var) %>% ungroup()
  # Pin the columns to exactly A, C, G, T. spread() only creates columns for alleles that
  # occur, so a base that is never preferred (or a multi-base indel allele) would otherwise
  # shift the positional column indexing used below.
  for (nt in setdiff(nucleotides, names(df))) {
    df[[nt]] = NA_integer_
  }
  df = df[, c("var_pos", nucleotides)]
  df[is.na(df)] = 0
  
  # number of variants supporting each position
  print("Number of variants supporting each position:")
  count = data.frame(position = df$var_pos, var_count = rowSums(df[, nucleotides]), TF = TFs[ab_tp])
  
  # number of each preferred allele at each position
  allele_count = df %>% gather(allele, count, -var_pos)
  
  # normalise each position to frequencies
  df[, nucleotides] = df[, nucleotides] / rowSums(df[, nucleotides])
  
  # positions of the motif without any variant
  motif_length = get_motif_length(ab_tp)
  missing_positions = setdiff(seq_len(motif_length), df$var_pos)
  
  if(length(missing_positions) > 0) {
    uniform = data.frame(var_pos = missing_positions, A = 0.25, C = 0.25, G = 0.25, T = 0.25)
    df = rbind.data.frame(df, uniform)
    count = rbind.data.frame(count, data.frame(position = missing_positions, var_count = rep(0, length(missing_positions)), TF = TFs[ab_tp]))
  }
  
  df %<>% arrange(var_pos) %>% select(-var_pos)
  
  # logo expects nucleotides in rows, positions in columns
  seqLogo(makePWM(t(df)), yaxis = FALSE, xaxis = x_axis, fill = letter_colors)
  
  # return count matrix (total or per allele)
  if(return_total_count) count else allele_count
  
}



# Motif hits (from FIMO scans of the REF and ALT allele sequences of each SNP) that cover
# the SNP, joined to the CHT results of condition `ab_tp`.
#
# radius      - bp on each side of the SNP in the scanned sequences (length 2*radius + 1)
# path2_base1, path2_base2 - components of the FIMO output path for each allele
# peak_radius - a variant is flagged `in_peak` when dist2summit < peak_radius
parse_motifs_in_two_alleles_SNPonly = function(ab_tp, cht, radius = 15, path2_base1 = "/all_SNPs_alleles/FIMO/", path2_base2 = "combined_motifs/fimo.tsv", peak_radius = 250) {

  window_len = radius*2 + 1
  variant_loc = radius + 1   # 1-based position of the SNP within a scanned sequence

  read_allele_scan = function(allele) {
    scan_file = file.path(path2_base1, paste0(allele, "_", window_len, "bp"), path2_base2)
    load_fimo_motifs(ab_tp, path2 = scan_file, TF_specific = F, parse_motif_coordinates = F)
  }
  ref_scan = read_allele_scan("REF")
  alt_scan = read_allele_scan("ALT")

  # union of hits found in either allele; hit span = widest span across the two alleles
  motif_hits = merge(ref_scan, alt_scan, by = c("motif_id", "motif_alt_id", "sequence_name"), all = TRUE)
  motif_hits = rename(motif_hits, snp_id = sequence_name)
  motif_hits = mutate(motif_hits,
                      start = pmin(start.x, start.y, na.rm = TRUE),
                      stop = pmax(stop.x, stop.y, na.rm = TRUE),
                      motif_length = stop - start + 1)
  motif_hits = motif_hits %>%
    mutate(in_motif = start <= variant_loc & stop >= variant_loc) %>%
    filter(in_motif)

  cht_cond = cht %>%
    filter(condition == ab_tp) %>%
    mutate(in_peak =  dist2summit < peak_radius) %>%
    select(condition, snp_id, signif, signif_strongAI, AI, AI_abs, in_peak, dist2summit, peak_id, peak_summit)

  merge(cht_cond, motif_hits, by = "snp_id")

}



# Motif hits overlapping each variant, from FIMO scans of the REF and ALT allele sequences,
# joined to the CHT results.
#
# radius       - bp on each side of the variant in the scanned sequences
# peak_radius  - a variant is flagged `in_peak` when dist2summit < peak_radius
# path2_base1, path2_base2 - components of the FIMO output path for each allele
# subset_motif - use the single selected motif per TF (forwarded as subset_by_tf_motif)
# pwm_id       - forwarded to load_fimo_motifs() as `pwm`
# subset_cht   - keep only CHT rows of condition `ab_tp`
# subset_tf    - keep only hits whose motif_alt_id equals TF_list[ab_tp]
#                (intended to be used with subset_motif = FALSE)
# TF_list      - condition -> TF name lookup
parse_motifs_in_two_alleles = function(ab_tp, cht, radius = 15, peak_radius = 250, 
                                       path2_base1 = "/all_variants_alleles/FIMO/", path2_base2 = "combined_motifs/fimo.tsv", 
                                       subset_motif = T, pwm_id = NA, subset_cht = T, subset_tf = F, TF_list = TFs) {

  # Hits of one allele's scan that overlap the variant. The variant coordinates are
  # recovered from the scanned region as [reg_start + radius, reg_end - radius].
  allele_hits = function(allele) {
    scan_file = file.path(path2_base1, paste0(allele, "_radius", radius, "bp"), path2_base2)
    hits = load_fimo_motifs(ab_tp, pwm = pwm_id, path2 = scan_file, TF_specific = F,
                            parse_motif_coordinates = T, by_peak = F, subset_by_tf_motif = subset_motif)
    hits %>%
      mutate(variant_start = as.numeric(reg_start) + radius, variant_end = as.numeric(reg_end) - radius) %>%
      mutate(in_motif = start <= variant_end & end >= variant_start) %>%
      filter(in_motif)
  }

  ref_hits = allele_hits("REF")
  alt_hits = allele_hits("ALT")

  join_keys = c("motif_id", "motif_alt_id", "name", "seqnames", "start", "end", "strand", "variant_start")
  motif_hits = merge(ref_hits, alt_hits, by = join_keys, all = TRUE, suffixes = c(".ref", ".alt")) %>%
    rename(snp_id = name)

  if (subset_tf) {
    motif_hits = motif_hits %>% filter(motif_alt_id == TF_list[ab_tp])
  }

  cht_cols = cht %>%
    mutate(in_peak =  dist2summit < peak_radius) %>%
    select(condition, snp_id, TEST.SNP.REF.ALLELE, TEST.SNP.ALT.ALLELE, AI, AI_abs, ALPHA, BETA,
           TOTAL.READ.COUNT, TOTAL.AS.READ.COUNT,
           padjust ,signif, signif_strongAI, in_peak, dist2summit, peak_id, peak_summit, is_indel, indel_length)

  if (subset_cht) {
    cht_cols = cht_cols %>% filter(condition == ab_tp)
  }

  merge(cht_cols, motif_hits, by = "snp_id")

}

# Put variants on both motif strands into motif orientation and reshape to one row per allele.
#
# variants_in_motifs - variant/motif table with strand, var_pos, is_indel,
#                      TEST.SNP.REF.ALLELE, TEST.SNP.ALT.ALLELE, ALPHA, BETA, snp_id
# motif_length       - motif length used to mirror var_pos for minus-strand hits
#                      (NOTE: evaluated inside mutate(), so a column named `motif_length`
#                      in the table takes precedence)
# rev_compl          - named lookup base -> complement used to reverse-complement alleles
# remove_indels      - drop rows with is_indel TRUE first
#
# Returns a long table with allele_type ("REF"/"ALT"), allele, and:
#   affinity       - ALPHA for REF, BETA for ALT
#   share_affinity - that allele's share ALPHA/(ALPHA+BETA) or BETA/(ALPHA+BETA)
#   rel_affinity   - share_affinity - 0.5
aggregate_strands = function(variants_in_motifs, motif_length, rev_compl = rev_compliment, remove_indels = T) {
  
  if(remove_indels) { variants_in_motifs %<>% filter(is_indel == F) }
  
  # Reverse complement of every allele string using the caller's `rev_compl` lookup
  # (the global `rev_compliment` was used previously, ignoring this argument).
  # vapply keeps a character result even for zero-length input.
  reverse_complement = function(alleles) {
    vapply(strsplit(as.character(alleles), split = ""),
           function(bases) paste(rev(rev_compl[bases]), collapse = ""),
           character(1), USE.NAMES = FALSE)
  }
  
  tmp_plus = variants_in_motifs %>% filter(strand == "+") 
  
  # mirror positions of minus-strand hits onto the motif orientation
  tmp_minus = variants_in_motifs %>% filter(strand == "-") %>% 
    mutate(var_pos = motif_length - var_pos + 1)
  
  tmp_minus$TEST.SNP.REF.ALLELE = reverse_complement(tmp_minus$TEST.SNP.REF.ALLELE)
  tmp_minus$TEST.SNP.ALT.ALLELE = reverse_complement(tmp_minus$TEST.SNP.ALT.ALLELE)
  
  tmp = rbind.data.frame(tmp_plus, tmp_minus)
  
  tmp %>% gather(allele_type, allele, TEST.SNP.REF.ALLELE:TEST.SNP.ALT.ALLELE) %>%
    arrange(snp_id, var_pos) %>%
    mutate(allele_type = gsub("TEST.SNP.(.*).ALLELE", "\\1", allele_type),
           affinity = ifelse(allele_type == "REF", ALPHA, BETA),
           share_affinity = ifelse(allele_type == "REF", ALPHA/(ALPHA + BETA), BETA/(ALPHA + BETA)),
           rel_affinity = ifelse(allele_type == "REF", ALPHA/(ALPHA + BETA) - 0.5, BETA/(ALPHA + BETA) - 0.5))
  
}

# Strand-aggregated allele table for significant strong-AI variants of one condition that
# overlap a FIMO motif hit (optionally extended by flanks).
#
# ab_tp         - condition id; cht - CHT results table
# rev_compl     - base -> complement lookup, forwarded to aggregate_strands()
# remove_indels - drop indels before aggregation
# right_flank, left_flank - bp added to the end / start of each motif hit; also added to the
#                 motif length used to mirror minus-strand positions
prepare_variants = function(ab_tp, cht, rev_compl = rev_compliment, remove_indels = T, right_flank = 0, left_flank = 0) {
  
  cht_sel = cht %>% filter(condition == ab_tp & signif_strongAI)
  fimo = load_fimo_motifs(ab_tp)
  ml = get_motif_length(ab_tp) + right_flank + left_flank
  
  variants_in_motifs = get_variants_in_motifs_info(cht_sel, fimo, rflank = right_flank, lflank = left_flank)
  # forward the caller's lookup rather than the global `rev_compliment`
  aggregate_strands(variants_in_motifs, ml, rev_compl, remove_indels)
  
}


# Strand-aggregated allele table for significant strong-AI variants of one condition that
# overlap a FIMO motif hit (optionally extended by flanks). Same computation as
# prepare_variants().
#
# ab_tp         - condition id; cht - CHT results table
# rev_compl     - base -> complement lookup, forwarded to aggregate_strands()
# remove_indels - drop indels before aggregation
# right_flank, left_flank - bp added to the end / start of each motif hit; also added to the
#                 motif length used to mirror minus-strand positions
prepare_variants_2alleles = function(ab_tp, cht, rev_compl = rev_compliment, remove_indels = T, right_flank = 0, left_flank = 0) {
  
  cht_sel = cht %>% filter(condition == ab_tp & signif_strongAI)
  fimo = load_fimo_motifs(ab_tp)
  ml = get_motif_length(ab_tp) + right_flank + left_flank
  
  variants_in_motifs = get_variants_in_motifs_info(cht_sel, fimo, rflank = right_flank, lflank = left_flank)
  # forward the caller's lookup rather than the global `rev_compliment`
  aggregate_strands(variants_in_motifs, ml, rev_compl, remove_indels)
  
}


# Strand-aggregated allele table for non-significant variants (padjust > 0.1) of one
# condition that overlap a FIMO motif hit.
#
# ab_tp         - condition id; cht - CHT results table
# rev_compl     - base -> complement lookup, forwarded to aggregate_strands()
# remove_indels - drop indels before aggregation
prepare_ns_variants = function(ab_tp, cht, rev_compl = rev_compliment, remove_indels = T) {
  
  cht_sel = cht %>% filter(condition == ab_tp & padjust > 0.1)
  fimo = load_fimo_motifs(ab_tp)
  ml = get_motif_length(ab_tp)
  
  variants_in_motifs = get_variants_in_motifs_info(cht_sel, fimo)
  # forward the caller's lookup rather than the global `rev_compliment`
  aggregate_strands(variants_in_motifs, ml, rev_compl, remove_indels)
  
}



# Matrix of AI-preferred allele counts: rows A, C, G, T; one column per motif position.
# NOTE: despite the name, values are raw counts (not probabilities); positions without any
# variant are filled with zeros.
#
# df           - strand-aggregated allele table with columns var_pos, allele, rel_affinity
# motif_length - number of motif positions. When NULL (default) it is taken from
#                get_motif_length(ab_tp), with `ab_tp` looked up outside this function as
#                before; pass it explicitly to avoid depending on a global `ab_tp`.
make_ppm = function(df, motif_length = NULL) {
  
  nucleotides = c("A", "C", "G", "T")
  
  # select alleles with positive AI
  df %<>% filter(rel_affinity > 0)
  
  # count matrix: one row per motif position
  df %<>% group_by(var_pos, allele) %>% summarize(n_var = n()) %>% spread(allele, n_var) %>% ungroup()
  # Pin the columns to exactly A, C, G, T: spread() only creates columns for alleles that
  # occur, and the zero rows added below must have the same columns for rbind().
  for (nt in setdiff(nucleotides, names(df))) {
    df[[nt]] = NA_integer_
  }
  df = df[, c("var_pos", nucleotides)]
  df[is.na(df)] = 0
  
  if (is.null(motif_length)) {
    motif_length = get_motif_length(ab_tp)
  }
  
  # positions of the motif without any variant get zero counts
  missing_positions = setdiff(seq_len(motif_length), df$var_pos)
  
  if(length(missing_positions) > 0) {
    zeros = data.frame(var_pos = missing_positions, A = 0, C = 0, G = 0, T = 0)
    df = rbind.data.frame(df, zeros)
  }
  
  df %<>% arrange(var_pos) %>% select(-var_pos)
  t(df)
}


# Frequency matrix of AI-preferred alleles: rows A, C, G, T; one column per motif position.
# NOTE: despite the name, each position is normalised to sum to 1 (frequencies, not counts);
# positions without variants get a uniform 0.25. Per-position variant counts are printed.
#
# ab_tp - condition id; selects the TF label (TFs[ab_tp]) and the motif length
# df    - strand-aggregated allele table with columns var_pos, allele, rel_affinity
make_pcm = function(ab_tp, df) {
  
  nucleotides = c("A", "C", "G", "T")
  
  # select alleles with positive AI
  df %<>% filter(rel_affinity > 0)
  print(paste0("Number of variants to make PFM: ", nrow(df)))
  
  # count matrix: one row per motif position
  df %<>% group_by(var_pos, allele) %>% summarize(n_var = n()) %>% spread(allele, n_var) %>% ungroup()
  # Pin the columns to exactly A, C, G, T. spread() only creates columns for alleles that
  # occur, so a never-preferred base (or an indel allele) would otherwise break the
  # nucleotide columns used below.
  for (nt in setdiff(nucleotides, names(df))) {
    df[[nt]] = NA_integer_
  }
  df = df[, c("var_pos", nucleotides)]
  df[is.na(df)] = 0
  
  # number of variants supporting each position
  print("Number of variants supporting each position:")
  count = data.frame(position = df$var_pos, var_count = rowSums(df[, nucleotides]), TF = TFs[ab_tp])
  print(count)
  
  # normalise each position to frequencies
  df[, nucleotides] = df[, nucleotides] / rowSums(df[, nucleotides])
  
  # positions of the motif without any variant
  motif_length = get_motif_length(ab_tp)
  missing_positions = setdiff(seq_len(motif_length), df$var_pos)
  
  if(length(missing_positions) > 0) {
    uniform = data.frame(var_pos = missing_positions, A = 0.25, C = 0.25, G = 0.25, T = 0.25)
    df = rbind.data.frame(df, uniform)
  }
  
  df %<>% arrange(var_pos) %>% select(-var_pos)
  t(df)
}





# 
# plot_ai_pfm_neutral_variants = function(ab_tp, df, x_axis = T, return_total_count = T) {
# 
#   # select alleles with positive AI
#   print(paste0("Number of variants to make PFM: ", nrow(df)))
# 
#   # make count matrix
#   df %<>% group_by(var_pos, allele) %>% summarize(n_var = n()) %>% spread(allele, n_var) %>% ungroup()
#   df[is.na(df)] = 0
# 
#   # number of variants supporting each position
#   print("Number of variants supporting each position:")
#   count = rowSums(df[ , 2:5])
#   #names(count) = df$var_pos
#   count = data.frame(position = df$var_pos, var_count = count, TF = TFs[ab_tp])
#   #print(count)
# 
#   # number of each preferred allele at each position
#   allele_count = df %>% gather(allele, count, 2:5)
# 
#   # make PFM
#   df[, 2:5] = df[, 2:5] / rowSums(df[, 2:5 ])
# 
#   motif_length = get_motif_length(ab_tp)
# 
#   # for missing positions -
#   nn = 1:motif_length
#   missing_positions = nn[!nn %in% df$var_pos]
#   #print(paste0("Missing positions: ",  paste(missing_positions, collapse = ", ")))
# 
#   if(length(missing_positions) > 0) {
# 
#     tmp = data.frame(matrix(rep(rep(0.25, 4), length(missing_positions)), nrow = length(missing_positions)))
#     names(tmp) = c("A", "C", "G", "T")
#     tmp = cbind.data.frame(data.frame(var_pos = missing_positions), tmp)
#     # add missing positions
#     df = rbind.data.frame(df, tmp)
#     # and to the count matrix
#     count = rbind.data.frame(count, data.frame(position = missing_positions, var_count = rep(0, length(missing_positions)), TF = TFs[ab_tp]))
#   }
# 
# 
#   df %<>% arrange(var_pos) %>% select(-var_pos)
# 
#   # prepare for plottting logo
#   mat = t(df)
#   p = makePWM(mat)
#   seqLogo(p, yaxis = FALSE, xaxis = x_axis, fill = letters_fill)
# 
#   # return count matrix (total or per allele)
#   if(return_total_count) {
#     return(count)
#   } else{
#     return(allele_count)
#   }
# 
# }



# Bar chart of the number of variants supporting each motif position.
#
# res - table with columns `position` and `var_count` (e.g. the per-position count table
#       returned by plot_ai_pfm()); every bar is labelled with its count just above it.
plot_number_variants_per_motif_position = function(res) {

  text_14 = element_text(size=14)
  axis_positions = res$position

  bars = ggplot(res, aes(x = as.integer(position), y = var_count)) +
    geom_bar(stat = "identity", width = 0.5, fill = "darkblue") +
    theme_bw() +
    xlab("Position") +
    ylab("# of variants")

  bars +
    geom_text(aes(x = as.integer(position), y = var_count + 1, label = var_count), size = 5) +
    scale_x_continuous(labels = axis_positions, breaks = axis_positions) +
    theme(legend.position = c(0.25, 0.9), legend.background = element_blank(),
          axis.text.y = text_14, axis.text.x = text_14,
          axis.title.x = text_14, axis.title.y = text_14,
          legend.text = text_14, legend.title = text_14)
}

# Stacked bar chart of preferred-allele counts per motif position, coloured by nucleotide,
# with the per-position total printed above each bar.
#
# res - long table with columns var_pos, allele, count (e.g. the per-allele table returned
#       by plot_ai_pfm(..., return_total_count = FALSE))
plot_number_of_preferred_alleles_per_motif_position = function(res) {

  totals = res %>% group_by(var_pos) %>% summarize(n = sum(count))
  text_14 = element_text(size=14)
  axis_positions = res$var_pos

  stacked = ggplot(res, aes(x = as.integer(var_pos), y = count)) +
    geom_bar(aes(fill = allele), stat = "identity", width = 0.8) +
    scale_x_continuous(labels = axis_positions, breaks = axis_positions) +
    theme_bw() +
    xlab("Position") +
    ylab("# of variants") +
    scale_fill_manual(values = letter_colors, guide="none")

  stacked +
    geom_text(aes(x = as.integer(var_pos), y = n + 1, label = n), data = totals, size = 5) +
    theme(axis.text.y = text_14, axis.text.x = text_14,
          axis.title.x = text_14, axis.title.y = text_14)
}




# Box/jitter plot of each allele's share of reads at one motif position of one TF.
#
# variants_in_motifs - long allele table with columns ab, var_pos, allele, share_affinity
# TF                 - value of the `ab` column to keep
# position           - motif position (var_pos) to keep
# letter_cols        - named colours per nucleotide
plot_allele_preferences_at_motif_position = function(variants_in_motifs, TF, position, letter_cols = letter_colors) {
  
  # `.env$` ensures the function arguments are compared even if the table has same-named columns
  tmp = variants_in_motifs %>% filter(ab == .env$TF & var_pos == .env$position)
  cols = letter_cols[sort(unique(tmp$allele))]
  
  ggplot(tmp, aes(x = allele, y = share_affinity, fill = allele, color = allele)) + 
    geom_hline(yintercept = 0.5, linetype = "dashed", color = "darkgrey") +
    geom_boxplot(width = 0.5, alpha = 0.5) + 
    geom_jitter(width = 0.4, size = 2) +
    scale_fill_manual(name = "Allele", values = cols) +
    scale_color_manual(name = "Allele", values = cols) +
    xlab("Allele") +
    ylab("Share of allele counts") +
    # title uses the function arguments (previously referenced undefined `tf` / `pos`)
    ggtitle(paste(TF, ", position", position)) +
    theme_bw() +
    theme(axis.text.y = element_text(size=14), axis.text.x = element_text(size=14), 
          axis.title.x = element_text(size=14), axis.title.y = element_text(size=14),
          plot.title = element_text(size= 16, hjust = 0.5), legend.position = "none")
  
}


# Distance of every variant of condition `ab_tp` to its nearest motif hit.
#
# ab_tp     - condition id; both `cht` and `fimo` are subset to it (both need a `condition` column)
# cht       - CHT results table with TEST.SNP.CHROM, TEST.SNP.POS, peak_id
# fimo      - motif hit table with seqnames, start, end, peak_id
# same_peak - keep only variants whose nearest motif lies in the variant's own peak
#
# Adds:
#   dist2motif    - 0 if the variant lies inside a motif, otherwise gap + 1 (a SNP at 9 and a
#                   motif starting at 10 get 1); NA if no nearest motif was found
#   motif_peak_id - peak_id of the nearest motif (NA if none)
get_variant_distance2TFmotif = function(ab_tp, cht, fimo, same_peak = F) {
  
  cht %<>% filter(condition == ab_tp)
  cht_gr = cht %>% mutate(chr = TEST.SNP.CHROM, start = TEST.SNP.POS, end = TEST.SNP.POS) %>% GRanges()
  
  fimo %<>% filter(condition == ab_tp)
  fimo_gr = fimo %>% select(seqnames, start, end) %>% GRanges()
  
  nearest = distanceToNearest(cht_gr, fimo_gr)
  ov = findOverlaps(cht_gr, fimo_gr)
  
  # distanceToNearest() can return fewer hits than variants (e.g. a variant on a chromosome
  # without motifs), so scatter its results by queryHits into full-length vectors instead of
  # assigning a possibly shorter vector to the column.
  dist2motif = rep(NA_real_, nrow(cht))
  nearest_motif = rep(NA_integer_, nrow(cht))
  dist2motif[queryHits(nearest)] = mcols(nearest)$distance + 1
  nearest_motif[queryHits(nearest)] = subjectHits(nearest)
  dist2motif[queryHits(ov)] = 0
  
  cht$dist2motif = dist2motif
  cht$motif_peak_id = fimo$peak_id[nearest_motif]
  
  if(same_peak) {
    cht %<>% filter(peak_id == motif_peak_id) %>% 
             select(-motif_peak_id)
  }
  
  cht 
}



# Compare the per-peak distance between AI variants and TF motifs with randomly drawn variants.
#
# ab_tp       - condition id; `cht` is subset to it
# cht         - CHT results table (snp_id, peak_id, signif_strongAI, ...)
# dist_breaks, dist_labels - passed to cut() to bin each peak's minimum variant-motif distance
# N_rep       - number of background draws; each draws as many rows of `cht` as there are
#               unique significant variants
# same_peak_filter - only count motifs in the same peak as the variant
# fimo        - motif hit table forwarded to get_variant_distance2TFmotif(), which requires it
#
# Returns list(<AI-peak summary>, <stacked background summaries>). Each summary has one row
# per distance bin plus a "no motif" row, with n, share (within peaks that kept a motif) and
# share_full (of all peaks considered), plus `type` and `condition`.
get_distance2motif_summary = function(ab_tp, cht, dist_breaks, dist_labels, N_rep = 100, same_peak_filter = F, fimo) {
  
  # Bin each peak by the distance of its closest variant to a motif and append a
  # "no motif" row for peaks that lost all their variants to the motif filter.
  summarize_peaks = function(annotated, n_peaks_total) {
    n_peaks_motif = length(unique(annotated$peak_id))
    n_peaks_no_motif = n_peaks_total - n_peaks_motif
    
    binned = annotated %>% 
      group_by(peak_id) %>% 
      mutate(min_dist = min(dist2motif)) %>%
      filter(dist2motif == min_dist) %>%
      select(peak_id, dist2motif) %>% unique() %>% ungroup() %>%
      mutate(N_tot = n(), dist_bin = cut(dist2motif, breaks = dist_breaks, labels = dist_labels)) %>%
      group_by(dist_bin) %>%
      summarize(n = n(), share = n / mean(N_tot), share_full = n / n_peaks_total)
    
    rbind.data.frame(binned, 
                     data.frame(dist_bin = "no motif", n = n_peaks_no_motif, share = NA,
                                share_full = n_peaks_no_motif / n_peaks_total))
  }
  
  cht %<>% filter(condition == ab_tp)
  
  # significant variants
  cht_sign = cht %>% filter(signif_strongAI)
  N_var_sign = length(unique(cht_sign$snp_id))
  N_peak_sign = length(unique(cht_sign$peak_id)) 
  
  sign_annot = get_variant_distance2TFmotif(ab_tp, cht_sign, fimo = fimo, same_peak = same_peak_filter)
  sign_sum = summarize_peaks(sign_annot, N_peak_sign)
  sign_sum$type = "AI peaks"
  sign_sum$condition = ab_tp
  
  # background: random rows of the condition, as many as there are significant variants
  background_sum = lapply(seq_len(N_rep), function(i) {
    
    print(i)
    cht_bg = cht[sample(seq_len(nrow(cht)), size = N_var_sign), ]
    N_peak_bg = length(unique(cht_bg$peak_id)) 
    bg_annot = get_variant_distance2TFmotif(ab_tp, cht_bg, fimo = fimo, same_peak = same_peak_filter) 
    summarize_peaks(bg_annot, N_peak_bg)
    
  }) %>% bind_rows()
  
  background_sum$type = "background"
  background_sum$condition = ab_tp
  
  list(sign_sum, background_sum)
  
}



#####################################################################################################
######################################## co-affected peaks ##########################################
#####################################################################################################




# Variants tested in two conditions, joined by snp_id, restricted to pairs whose peak summits
# are at most 250 bp apart (i.e. the same / overlapping peak in both conditions).
#
# ab_tp1, ab_tp2 - the two condition ids
# cht            - CHT results table containing both conditions
# motifs         - motif hit table forwarded to get_variant_distance2TFmotif()
# filter_top     - forwarded to get_top_variants() as `filter`
get_coaffected_peaks_variants = function(ab_tp1, ab_tp2, cht, motifs, filter_top = F) {

  summit_bins = c("<250 bp (overlapping peaks)", "250-1000 bp", ">1000 bp")

  with_motif_dist = lapply(c(ab_tp1, ab_tp2), function(cond) get_variant_distance2TFmotif(cond, cht, motifs))
  top_flagged = lapply(with_motif_dist, get_top_variants, filter = filter_top)

  paired = merge(top_flagged[[1]], top_flagged[[2]], by = "snp_id") %>%
    mutate(peak_dist = abs(peak_summit.x - peak_summit.y),
           bin_dist = cut(peak_dist, breaks = c(-251, 250, 1000, 5000), include.lowest = T, labels = summit_bins))

  paired %>%
    select(snp_id, peak_id.x, peak_id.y, peak_dist, bin_dist, padjust.x, padjust.y, AI.x, AI.y,  AI_abs.x, AI_abs.y, top_var.x, top_var.y, dist2motif.x, dist2motif.y,
           TEST.SNP.CHROM.x, REGION.START.x, REGION.END.x, REGION.START.y, REGION.END.y) %>%
    mutate(bin_dist = factor(bin_dist, levels = summit_bins)) %>%
    # keep only the first summit-distance bin: variants shared by (nearly) the same peak
    filter(bin_dist == "<250 bp (overlapping peaks)")

}



# Scatter plot of allelic imbalance of shared variants in two conditions: significant pairs
# in colour (by summit-distance bin) with a linear fit, non-significant pairs in grey.
#
# res_sign, res_ns - outputs of get_coaffected_peaks_variants() for significant and
#                    non-significant variants (peak_id.x/.y, AI.x/.y, bin_dist)
# x1, y1 / x2, y2  - positions of the significant / non-significant annotation text
# ab_tp1, ab_tp2   - the two conditions, used for the axis labels via ab_tp_labels. When NULL
#                    (default) they are looked up outside the function, as before.
plot_AI_pairwise_correlations = function(res_sign, res_ns, x1 = 0.2, y1 = 0.75, x2 = 0.5, y2 = 0.25, ab_tp1 = NULL, ab_tp2 = NULL){
  
  # fall back to the previous behaviour: resolve ab_tp1 / ab_tp2 from the enclosing scope
  if (is.null(ab_tp1)) ab_tp1 = get("ab_tp1", envir = parent.env(environment()))
  if (is.null(ab_tp2)) ab_tp2 = get("ab_tp2", envir = parent.env(environment()))
  
  res_cor = res_sign %>%
    dplyr::select(peak_id.x, peak_id.y, AI.x, AI.y, bin_dist) %>% unique()
  
  n_sign = nrow(res_cor)
  n_peaks = res_cor %>% dplyr::select(peak_id.x, peak_id.y) %>% unique() %>% nrow()
  
  # percentage of variants whose AI points the same way in both conditions
  concordant = (res_cor$AI.x > 0.5 & res_cor$AI.y > 0.5) | (res_cor$AI.x < 0.5 & res_cor$AI.y < 0.5)
  concordance = round(sum(concordant) / n_sign * 100, 1)
  lab = paste0(concordance, "% concordance\n", n_sign, " AI variants \n(", n_peaks, " peak pairs)")
  
  res_cor_ns = res_ns %>%
    dplyr::select(peak_id.x, peak_id.y, AI.x, AI.y, bin_dist) %>% unique() 
  
  r_ns = cor(res_cor_ns$AI.x, res_cor_ns$AI.y)
  n_peaks_ns = res_ns %>% dplyr::select(peak_id.x, peak_id.y) %>% unique() %>% nrow()
  lab_ns = paste0("R=", round(r_ns, 2), ", ", nrow(res_cor_ns), " n.s. variants \n(", n_peaks_ns, " peak pairs)")
  
  ggplot(res_cor, aes(x = AI.x, y = AI.y, color = bin_dist)) + 
    geom_point_rast(size = 0.1, color = "lightgrey", alpha = 0.8, data = res_cor_ns, aes(x = AI.x, y = AI.y)) +
    geom_point(size = 1.5) + 
    geom_smooth(method = "lm", se = F, color= "black", linewidth = 0.3) +
    theme_bw() +
    scale_color_brewer(palette = "Set2", name = "Distance between \npeak summits") +
    xlab(paste0("Allele Imbalance, ", ab_tp_labels[ab_tp1])) + 
    ylab(paste0("Allele Imbalance, ",ab_tp_labels[ab_tp2])) +
    guides(colour = guide_legend(override.aes = list(size=6), nrow=2, byrow=TRUE)) +
    annotate("text", x = x1, y = y1, hjust = 0, label = lab, size = 5) + 
    annotate("text", y = y2, x = x2, hjust = 0, label = lab_ns, size = 5, color = "darkgrey") + 
    theme(axis.text.y = element_text(size=12), axis.text.x = element_text(size=12), 
          axis.title.x = element_text(size=14), axis.title.y = element_text(size=14),
          legend.text=element_text(size=12), legend.title=element_text(size=12),
          legend.position = "bottom")
  
}



# Pair TF peaks of condition `ab_tp` with overlapping ATAC regions of the matching time-point.
#
# ab_tp      - condition id; its time-point is timepoints[ab_tp] with dashes removed
# cht        - CHT results table, passed to get_peaks_from_cht()
# atac_ai_gr - GRanges of ATAC regions with a `time` metadata column
#
# Returns one row per overlapping (TF peak, ATAC region) pair: TF metadata columns suffixed
# ".tf" followed by ATAC metadata columns suffixed ".atac".
overlap_ai_tf_peaks_with_ai_atac = function(ab_tp, cht, atac_ai_gr) {

  atac_time = gsub("-", "", timepoints[ab_tp])

  tf_peaks = get_peaks_from_cht(ab_tp, cht)
  names(mcols(tf_peaks)) = paste0(names(mcols(tf_peaks)), ".tf")

  atac_peaks = atac_ai_gr[atac_ai_gr$time == atac_time]
  names(mcols(atac_peaks)) = paste0(names(mcols(atac_peaks)), ".atac")

  pairs = findOverlaps(tf_peaks, atac_peaks)
  tf_side = mcols(tf_peaks[queryHits(pairs)])
  atac_side = mcols(atac_peaks[subjectHits(pairs)])

  cbind.data.frame(tf_side, atac_side)

}


# Pair TF peaks of condition `ab_tp` with overlapping H3K4me3 regions.
#
# ab_tp           - condition id
# cht             - CHT results table, passed to get_peaks_from_cht()
# k4me3_ai_gr     - GRanges of H3K4me3 regions
# match_timepoint - FALSE (default, existing behaviour): use all H3K4me3 regions.
#                   TRUE: keep only regions whose `time` equals timepoints[ab_tp] with dashes
#                   removed (e.g. "6-8h" -> "68h").
#
# Returns one row per overlapping pair: TF metadata columns suffixed ".tf" followed by
# H3K4me3 metadata columns suffixed ".k4me3".
overlap_ai_tf_peaks_with_ai_k4me3 = function(ab_tp, cht, k4me3_ai_gr, match_timepoint = FALSE) {
  
  # get TF peaks
  AI_peaks = get_peaks_from_cht(ab_tp, cht)
  names(mcols(AI_peaks)) = paste0(names(mcols(AI_peaks)), ".tf")
  
  # The time-point subset used to be computed and then discarded (all regions were
  # overlapped); it is now applied only on request so the default result is unchanged.
  if (match_timepoint) {
    tp = gsub("-", "", timepoints[ab_tp])
    k4me3_ai_gr = k4me3_ai_gr[k4me3_ai_gr$time == tp]
  }
  names(mcols(k4me3_ai_gr)) = paste0(names(mcols(k4me3_ai_gr)), ".k4me3")
  
  ov = findOverlaps(AI_peaks, k4me3_ai_gr)
  
  cbind.data.frame(mcols(AI_peaks[queryHits(ov)]), mcols(k4me3_ai_gr[subjectHits(ov)]))
  
}


# Pair TF peaks of condition `ab_tp` with overlapping H3K27ac regions.
#
# ab_tp           - condition id
# cht             - CHT results table, passed to get_peaks_from_cht()
# k27ac_ai_gr     - GRanges of H3K27ac regions
# match_timepoint - FALSE (default, existing behaviour): use all H3K27ac regions.
#                   TRUE: keep only regions whose `time` equals timepoints[ab_tp] with dashes
#                   removed (e.g. "6-8h" -> "68h").
#
# Returns one row per overlapping pair: TF metadata columns suffixed ".tf" followed by
# H3K27ac metadata columns suffixed ".k27ac".
overlap_ai_tf_peaks_with_ai_k27ac = function(ab_tp, cht, k27ac_ai_gr, match_timepoint = FALSE) {
  
  # get TF peaks
  AI_peaks = get_peaks_from_cht(ab_tp, cht)
  names(mcols(AI_peaks)) = paste0(names(mcols(AI_peaks)), ".tf")
  
  # The time-point subset used to be computed and then discarded (all regions were
  # overlapped); it is now applied only on request so the default result is unchanged.
  if (match_timepoint) {
    tp = gsub("-", "", timepoints[ab_tp])
    k27ac_ai_gr = k27ac_ai_gr[k27ac_ai_gr$time == tp]
  }
  names(mcols(k27ac_ai_gr)) = paste0(names(mcols(k27ac_ai_gr)), ".k27ac")
  
  ov = findOverlaps(AI_peaks, k27ac_ai_gr)
  
  cbind.data.frame(mcols(AI_peaks[queryHits(ov)]), mcols(k27ac_ai_gr[subjectHits(ov)]))
  
}

#####################################################################################################
######################################## statistics #################################################
#####################################################################################################


# Fisher's exact test of association between two two-level groupings within one condition.
#
# df         - table with a `condition` column and the columns named by group1 / group2
# ab_tp      - condition to test (rows with other or NA conditions are dropped)
# group1     - column name of the first grouping; group1_val - its two levels
# group2     - column name of the second grouping; group2_val - its two levels
#
# Returns a one-row data.frame with the four cell counts (n11 = group1_val[1] & group2_val[1],
# n12 = group1_val[1] & group2_val[2], n21 = group1_val[2] & group2_val[1],
# n22 = group1_val[2] & group2_val[2]), the shares r1 = n11/(n11+n12) and r2 = n21/(n21+n22),
# the p-value and the odds-ratio estimate. Rows with NA in either grouping are not counted.
fisher_test_two_groups = function(df, ab_tp, group1, group1_val, group2, group2_val) {
  
  # which() drops NA conditions, matching dplyr::filter()
  df = df[which(df[["condition"]] == ab_tp), , drop = FALSE]
  
  g1 = df[[group1]]
  g2 = df[[group2]]
  # Count rows in one cell. Comparisons involving NA are not counted (indexing with a
  # logical matrix used to turn them into NA rows that were counted).
  count_cell = function(v1, v2) as.numeric(sum(g1 == v1 & g2 == v2, na.rm = TRUE))
  
  n1 = count_cell(group1_val[1], group2_val[1])
  n2 = count_cell(group1_val[1], group2_val[2])
  n3 = count_cell(group1_val[2], group2_val[1])
  n4 = count_cell(group1_val[2], group2_val[2])
  
  # matrix() fills column-wise, giving the transpose of the n11..n22 layout; the Fisher
  # p-value and odds ratio are unchanged by transposition.
  mat = matrix(c(n1, n2, n3, n4),
               nrow = 2)
  
  ft = fisher.test(mat)  
  
  data.frame(
    condition = ab_tp,
    comp1 = paste(group1, paste(group1_val, collapse = "_vs_"), sep = "_"),
    comp2 = paste(group2, paste(group2_val, collapse = "_vs_"), sep = "_"),
    n11 = n1,
    n12 = n2,
    n21 = n3, 
    n22 = n4,
    r1 = n1 / (n1 + n2),
    r2 = n3 / (n3 + n4),
    pval = ft$p.value,
    odds_ratio = ft$estimate,
    row.names = NULL
  )
  
}


# Fisher's exact test: are AI peaks of `ab_tp` more or less often co-bound by the other TFs
# than non-AI peaks?
#
# ab_tp     - tested condition; peaks[[ab_tp]] must have a logical `AI_peak` column
# peaks     - named list of GRanges, one per condition
# cond_list - conditions that may act as co-binding partners; the tested condition and, for
#             Mef2 / Bin, the other time-point of the same TF are excluded
#
# A peak is "co-bound" when it overlaps the union of the partner conditions' peaks.
fisher_test_AIpeaks_cobinding = function(ab_tp, peaks = peaksets, cond_list = ab_tp_list){

  # the same TF at its other time-point is not a co-binding partner
  other_timepoint = c("mef2/1012" = "mef2/68", "bin/1012" = "bin/68",
                      "mef2/68" = "mef2/1012", "bin/68" = "bin/1012")

  excluded = if (ab_tp %in% names(other_timepoint)) c(ab_tp, other_timepoint[ab_tp]) else ab_tp
  partner_ids = setdiff(cond_list, excluded)

  tested = peaks[[ab_tp]]
  partners = Reduce(union, peaks[partner_ids])

  ai_peaks = tested[tested$AI_peak]
  non_ai_peaks = tested[!tested$AI_peak]

  # number of distinct peaks of `gr` overlapping any partner peak
  n_cobound = function(gr) length(unique(queryHits(findOverlaps(gr, partners))))

  ai_cob = n_cobound(ai_peaks)
  non_ai_cob = n_cobound(non_ai_peaks)
  ai_alone = length(ai_peaks) - ai_cob
  non_ai_alone = length(non_ai_peaks) - non_ai_cob

  ft = fisher.test(matrix(c(ai_cob, ai_alone, non_ai_cob, non_ai_alone), nrow = 2))

  data.frame(
    condition = ab_tp,
    comp1 = "isAI_TRUE_vs_FALSE",
    comp2 = "is_co_binding_TRUE_vs_FALSE",
    n11 = ai_cob,
    n12 = ai_alone,
    n21 = non_ai_cob,
    n22 = non_ai_alone,
    r1 = ai_cob / (ai_cob + ai_alone),
    r2 = non_ai_cob / (non_ai_cob + non_ai_alone),
    pval = ft$p.value,
    odds_ratio = ft$estimate,
    row.names = NULL
  )

}



# Read a Basenji prediction table and derive allelic-imbalance summaries.
#
# filename - whitespace-delimited file with a header containing the six
#            Basenji_{ref,alt}_position{_minus1,,_plus1} columns plus overlaps_motif,
#            overlaps_peak and overlaps_TSS
#
# Adds:
#   Basenji_REF / Basenji_ALT - prediction summed over the variant position and its two neighbours
#   Basenji_AI                - REF / (REF + ALT)
#   Basenji_abs_AI            - |Basenji_AI - 0.5|
#   overlaps_motif, overlaps_peak - 0 -> "no_overlap", otherwise "overlaps_motif" / "overlaps_peak"
#   TSS                       - overlaps_TSS with "no_TSS" -> "Distal" and "TSS" -> "Proximal"
# All rows are returned.
read_input = function(filename) {
  df = read.table(filename, header=TRUE)
  
  df$Basenji_REF = df$Basenji_ref_position_minus1 + df$Basenji_ref_position + df$Basenji_ref_position_plus1
  df$Basenji_ALT = df$Basenji_alt_position_minus1 + df$Basenji_alt_position + df$Basenji_alt_position_plus1
  df$Basenji_AI = df$Basenji_REF / (df$Basenji_REF + df$Basenji_ALT)
  df$Basenji_abs_AI = abs(df$Basenji_AI - 0.5)
  
  df$overlaps_motif = ifelse(df$overlaps_motif == 0, "no_overlap", "overlaps_motif")
  df$overlaps_peak = ifelse(df$overlaps_peak == 0, "no_overlap", "overlaps_peak")
  # replace "no_TSS" first so its "TSS" substring is not turned into "Proximal"
  df$TSS = gsub("TSS", "Proximal", gsub("no_TSS", "Distal", df$overlaps_TSS))
  
  df
}

# Keep, within each peak (peak_ID), the row(s) with the largest Basenji_abs_AI (ties are all
# kept). Returns an ungrouped data.frame.
take_best_Basenji_in_peak = function(df) {
  by_peak = group_by(df, peak_ID)
  best_rows = filter(by_peak, Basenji_abs_AI == max(Basenji_abs_AI))
  as.data.frame(best_rows)
}

# Best-per-peak selection (see take_best_Basenji_in_peak) after removing every row whose
# variant_peak_ID already appears in `subset`.
take_next_best_Basenji_in_peak = function(df, subset) {
  already_selected = df[["variant_peak_ID"]] %in% subset$variant_peak_ID
  take_best_Basenji_in_peak(df[!already_selected, , drop = FALSE])
}


# Scatter plot of Basenji-predicted vs experimental allelic imbalance with a linear trend,
# dashed 0.5 guides, per-quadrant counts/percentages and the Pearson correlation.
#
# df    - table with numeric columns Basenji_AI and AI (both expected in [0, 1])
# title - plot title
plot_correlation = function(df, title) {

  n_rows = nrow(df)
  pred = df[["Basenji_AI"]]
  obs = df[["AI"]]

  # "N = <count> (<percent of all rows>%)"; rows where the condition is NA are not counted
  quadrant_label = function(in_quadrant) {
    n_in = sum(in_quadrant, na.rm = TRUE)
    paste0("N = ", n_in, " (", round(n_in / n_rows * 100, 1), "%)")
  }

  ggplot(df, aes(x=Basenji_AI, y=AI)) + 
    geom_point(colour="grey40", alpha=0.5, size=2) +
    geom_line(stat="smooth", method = "lm", alpha = 0.5) +
    xlim(0,1) +
    ylim(0,1) +
    xlab("Basenji predicted Allelic Imbalance") +
    ylab("Experimental Allelic Imbalance") +
    ggtitle(title) +
    geom_vline(xintercept = 0.5, colour = "#C92B27", linetype="dashed") +
    geom_hline(yintercept = 0.5, colour = "#C92B27", linetype="dashed") +
    annotate("text", x = 0.15, y = 0.15, label = quadrant_label(pred < 0.5 & obs < 0.5), size=4) +
    annotate("text", x = 0.15, y = 0.85, label = quadrant_label(pred < 0.5 & obs > 0.5), size=4) +
    annotate("text", x = 0.85, y = 0.15, label = quadrant_label(pred > 0.5 & obs < 0.5), size=4) +
    annotate("text", x = 0.85, y = 0.85, label = quadrant_label(pred > 0.5 & obs > 0.5), size=4) +
    annotate("text", x = 0.15, y = 0.7, label = paste0("r = ", round(cor(df$Basenji_AI, df$AI), 4)), size=4) +
    theme_bw() + 
    theme(panel.grid = element_line(colour = "grey80", linewidth = 1), axis.text = element_text(size = 12)) +
    theme(axis.title = element_text(size = 12), plot.title = element_text(size=12)) +
    theme(panel.grid.minor = element_line(linewidth = 0.25), panel.grid.major = element_line(linewidth = 0.5)) +
    theme(legend.position = "none")
}

# Stacked bar chart of variant counts for each combination of two columns.
#
# df:          data frame holding the two categorical columns.
# col1:        name (string) of the column plotted on the x axis.
# col2:        name (string) of the column used for the stacked fill.
# fill_colors: fill palette; needs at least as many colours as col2 levels.
# Returns the ggplot object. The fill legend is titled with col2.
plot_counts_barplot = function(df, col1, col2, fill_colors = c("grey70", "#FFA736", "#FF2341")) {
  counts = as.data.frame(table(df[[col1]], df[[col2]]))
  colnames(counts) = c(col1, col2, "counts")

  # .data[[...]] (instead of get()) gives proper axis/legend titles and
  # avoids relying on get()'s lookup inside the data mask.
  p = ggplot(counts, aes(x = .data[[col1]], y = .data[["counts"]], fill = .data[[col2]])) +
    geom_bar(position = 'stack', stat = 'identity') +
    geom_text(aes(label = .data[["counts"]]), position = position_stack(vjust = 0.5)) +
    xlab("") +
    ylab("N variants") +
    scale_fill_manual(values = fill_colors) +
    theme_bw() +
    theme(panel.grid = element_line(colour = "grey80", linewidth = 1), axis.text = element_text(size = 12)) +
    theme(axis.title = element_text(size = 12), plot.title = element_text(size = 12)) +
    theme(panel.grid.minor = element_line(linewidth = 0.25), panel.grid.major = element_line(linewidth = 0.5)) +
    theme(legend.title = element_text(size = 12))

  p
}




# Experimental vs Basenji-predicted allelic imbalance, coloured by condition,
# with one variant labelled by a repelled text label.
#
# df:      data frame with Basenji_AI, AI, variant_ID and condition columns.
# variant: variant_ID of the point to label.
# nudge_x, nudge_y: label offsets passed to geom_text_repel().
# Returns the ggplot object. The colour scale has 6 Set1 colours, so at most
# 6 conditions are supported.
plot_correlation_with_point = function(df, variant, nudge_x, nudge_y) {
  highlighted = subset(df, variant_ID == variant)

  ggplot(df, aes(x = Basenji_AI, y = AI, label = variant_ID)) +
    # `label` is not a geom_point aesthetic; it is inherited from the plot mapping.
    geom_point(aes(colour = condition), size = 3) +
    geom_line(stat = "smooth", method = "lm", alpha = 0.5) +
    # Namespaced: ggrepel is not attached by the file's library() calls.
    ggrepel::geom_text_repel(data = highlighted, nudge_x = nudge_x, nudge_y = nudge_y, size = 5,
                             box.padding = 0.5, point.padding = 0.5, force = 100,
                             segment.size = 0.8, segment.color = "grey0", direction = "x") +
    xlim(0, 1) +
    ylim(0, 1) +
    xlab("Basenji predicted Allelic Imbalance") +
    ylab("Experimental Allelic Imbalance") +
    ggtitle(paste0("significant AI, Basenji > 0.1 AI variant: ", variant)) +
    geom_vline(xintercept = 0.5, colour = "#C92B27", linetype = "dashed") +
    geom_hline(yintercept = 0.5, colour = "#C92B27", linetype = "dashed") +
    scale_colour_manual(values = brewer.pal(n = 6, name = "Set1")) +
    theme_bw() +
    theme(panel.grid = element_line(colour = "grey80", linewidth = 1), axis.text = element_text(size = 12)) +
    theme(axis.title = element_text(size = 12), plot.title = element_text(size = 12)) +
    theme(panel.grid.minor = element_line(linewidth = 0.25), panel.grid.major = element_line(linewidth = 0.5)) +
    theme(legend.position = "right")
}

# Reverse-complement a position matrix stored in motif$mat.
#
# motif: list-like object whose $mat is a 4 x L matrix. Rows are assumed to be
#   in A, C, G, T order (TODO confirm for every caller), columns are positions.
# Returns a copy of motif with $mat reversed along positions and complemented
#   (row A <- old T, C <- old G, ...). Rows are labelled A, C, G, T.
rev_comp_motif = function(motif) {
  rev_motif = motif
  motif_mat = as.data.frame(motif$mat)
  # Reverse positions. drop = FALSE keeps a single-position motif as a table,
  # and seq_len() avoids the 1:0 trap for zero columns.
  motif_mat = motif_mat[, rev(seq_len(ncol(motif_mat))), drop = FALSE]
  # Complement: the A,C,G,T rows are relabelled T,G,C,A, then reordered.
  rownames(motif_mat) = c("T", "G", "C", "A")
  motif_mat = motif_mat[c("A", "C", "G", "T"), , drop = FALSE]
  rev_motif$mat = as.matrix(motif_mat)

  rev_motif
}








# Gene IDs near the summits of peaks with significant, strong allelic
# imbalance in the given condition(s).
#
# cht:         per-SNP table with condition, peak_id, signif_strongAI,
#              TEST.SNP.CHROM and peak_summit columns.
# cond:        condition value(s) to keep.
# dist_summit: half-width of the TSS region passed to annotatePeak().
# Returns the unique geneId values reported by ChIPseeker::annotatePeak().
get_test_geneID = function(cht, cond, dist_summit) {
  cond_rows = subset(cht, condition %in% c(cond))
  peak_rows = unique(cond_rows[, c("peak_id", "signif_strongAI", "TEST.SNP.CHROM", "peak_summit")])
  strong_ai_peaks = subset(peak_rows, signif_strongAI == TRUE)

  summit_ranges = GRanges(seqnames = strong_ai_peaks$TEST.SNP.CHROM,
                          ranges = strong_ai_peaks$peak_summit)
  summit_annotation = annotatePeak(summit_ranges,
                                   tssRegion = c(-dist_summit, dist_summit),
                                   TxDb = txdb,
                                   annoDb = "org.Dm.eg.db")

  unique(as.data.frame(summit_annotation)$geneId)
}

# Gene IDs near the summits of peaks without significant strong allelic
# imbalance in the given condition(s); used as the GO background.
#
# cht:               per-SNP table (same columns as for get_test_geneID).
# cond:              condition value(s) to keep.
# dist_summit:       half-width of the TSS region passed to annotatePeak().
#                    It was previously ignored in favour of a hard-coded 1000.
# test_peaks_geneID: IDs excluded from the background.
# Returns the unique geneId values reported by ChIPseeker::annotatePeak().
get_background_geneID = function(cht, cond, dist_summit, test_peaks_geneID) {
  # NOTE(review): this compares peak_id against test_peaks_geneID. If callers
  # pass gene IDs (e.g. the output of get_test_geneID) the exclusion never
  # matches, and only signif_strongAI == FALSE separates the sets. Confirm
  # what the caller passes.
  background_peaks = subset(unique(subset(cht, condition %in% c(cond))[, c("peak_id", "signif_strongAI", "TEST.SNP.CHROM", "peak_summit")]), !(peak_id %in% test_peaks_geneID) & (signif_strongAI == FALSE) )
  background_peaks_GRange = GRanges(seqnames = background_peaks$TEST.SNP.CHROM, ranges = background_peaks$peak_summit)
  background_peaks_Peak_anno = annotatePeak(background_peaks_GRange, tssRegion=c(-dist_summit, dist_summit), TxDb=txdb, annoDb = "org.Dm.eg.db")
  background_peaks_geneID = as.data.frame(background_peaks_Peak_anno)$geneId

  unique(background_peaks_geneID)
}



# Build a topGOdata object for one ontology from test and background genes.
#
# test_geneID:       genes of interest (FBgn or Entrez IDs).
# background_geneID: gene universe (FBgn or Entrez IDs).
# GO_term:           ontology passed to topGO ("BP", "MF" or "CC").
# Returns list(topGOdata object, 0/1 factor of test membership over the
#   universe).
create_GOdata = function(test_geneID, background_geneID, GO_term) {
  # annFUN.org returns GO -> genes; inverseList turns it into gene -> GO.
  gene_to_GO = inverseList(annFUN.org(whichOnto = GO_term, mapping = "org.Dm.eg.db"))
  entrez_keyed = check_annotation_ID(gene_to_GO)

  test_ids = transfor_gene_ID(test_geneID, entrez_keyed)
  universe_ids = transfor_gene_ID(background_geneID, entrez_keyed)
  membership = test_vector(universe_ids, test_ids)

  go_data = new("topGOdata",
                ontology = GO_term,
                allGenes = membership,
                annot = annFUN.gene2GO,
                gene2GO = gene_to_GO)

  list(go_data, membership)
}


# Map FlyBase gene IDs (FBgn...) to Entrez IDs via org.Dm.eg.db.
# The lookup uses keytype "ENSEMBL". When an ID has several Entrez matches the
# first is kept; unmapped IDs become NA. Returns a vector named by the input
# IDs.
FBgn_to_entrez = function(genes_list) {
  fbgn_keys = as.character(genes_list)
  mapIds(org.Dm.eg.db,
         keys = fbgn_keys,
         keytype = "ENSEMBL",
         column = "ENTREZID",
         multiVals = "first")
}

# Checks whether a gene -> GO annotation list is keyed by Entrez IDs.
# Only the first name is inspected. Returns FALSE when it contains "FBgn"
# (FlyBase keys), otherwise TRUE (including unnamed or empty lists).
check_annotation_ID = function(GO_annotation) {
  first_id = names(GO_annotation)[1]
  !isTRUE(grepl("FBgn", first_id))
}


# Convert gene IDs to Entrez when any of them looks like a FlyBase ID and the
# annotation is Entrez-keyed (annotation_ID == TRUE). Otherwise keep them
# unchanged. NA IDs (including failed conversions) are dropped.
transfor_gene_ID = function(genes_list, annotation_ID) {
  needs_conversion = any(grepl("FBgn", genes_list)) && annotation_ID
  ids = if (needs_conversion) FBgn_to_entrez(genes_list) else genes_list
  ids[!is.na(ids)]
}



# Build topGO's allGenes vector: a factor with 1 for background genes that
# are in the test list and 0 otherwise, named by the background gene IDs.
test_vector = function(background_ID, test_ID) {
  in_test = background_ID %in% test_ID
  setNames(factor(as.integer(in_test)), background_ID)
}

# Selects the top GO terms for plotting.
#
# GO_results: table with Term, Significant, Fold_Enrichment and FDR columns.
#   FDR may be numeric or character (e.g. topGO-style "< 1e-30").
# max_sign, min_sign: keep terms with min_sign < Significant < max_sign.
# max_FDR:    keep terms with FDR <= max_FDR (compared numerically).
# n_top_GO:   number of terms returned.
# Returns the n_top_GO most significant terms, ordered by decreasing
#   log10_FDR (-log10 FDR). Term is a factor ordered by |log10 Fold_Enrichment|.
select_top_GO = function(GO_results, max_sign, min_sign, max_FDR, n_top_GO) {
  fdr_raw = GO_results$FDR
  fdr = suppressWarnings(as.numeric(fdr_raw))
  # Non-missing FDRs that cannot be parsed (e.g. "< 1e-30") are treated as
  # 1e-30. This only touches FDR, not NAs in unrelated columns.
  fdr[is.na(fdr) & !is.na(fdr_raw)] = 1e-30
  GO_results$log10_FDR = -log10(fdr)

  # Numeric comparison: a character FDR would otherwise be compared
  # lexicographically (e.g. "1e-05" <= 0.05 is FALSE).
  keep = GO_results$Significant < max_sign &
    GO_results$Significant > min_sign &
    fdr <= max_FDR
  GO_results = GO_results[which(keep), , drop = FALSE]
  rownames(GO_results) = NULL

  GO_results = GO_results[order(GO_results$log10_FDR, decreasing = TRUE), , drop = FALSE]
  GO_results$Term = reorder(GO_results$Term, abs(log10(GO_results$Fold_Enrichment)))
  head(GO_results, n = n_top_GO)
}

# Enrichment dot plot for top GO terms.
#
# results_df:   table with Term, Fold_Enrichment, Significant and log10_FDR
#               columns (e.g. the output of select_top_GO).
# n_background: number of background genes (shown in the title).
# n_test:       number of tested genes (shown in the title).
# cond:         condition label(s), joined with "_" in the title.
# Returns the ggplot object: x = log2 fold enrichment, size = Significant,
#   colour = log10_FDR.
ggplot_GO_enrichment = function(results_df, n_background, n_test, cond) {
  results_df$log_FC = log2(results_df$Fold_Enrichment)

  # `linewidth` replaces the `size` aesthetic for lines (ggplot2 >= 3.4),
  # matching the other plotting helpers in this file.
  ggplot(results_df, aes(x=log_FC, y=Term)) +
    geom_vline(xintercept = 0, linetype= "longdash", linewidth = 1) +
    geom_point(aes(size=Significant, colour=log10_FDR)) +
    theme_bw() +
    theme(panel.grid = element_line(colour = "grey80", linewidth = 1)) +
    theme(axis.text.x = element_text(size = 8), axis.text.y = element_text(size = 8), axis.title = element_text(size = 8), plot.title = element_text(size = 8)) +
    theme(legend.text = element_text(size = 8)) +
    scale_size_continuous(range = c(3, 9)) +
    scale_colour_gradient(low = "#1d78c7", high = "#d35a62") +
    ylab("") +
    xlab("log2 Fold Enrichment") +
    ggtitle(paste0(paste(cond, collapse = '_'), "\nTested genes: ", n_test, "     Genes in background: ", n_background))
}


