
# The functions in this script are courtesy of Jason Belsky.

# Draw a bold figure-panel label at the origin of an otherwise empty plot.
#
#   label:            text to display
#   label_cex_ratio:  multiplier applied to the base label size (cex = 2)
#
# Note: this sets par(mar) to zero on the active device and does not restore it.
add_figure_label <- function(label, label_cex_ratio = 1) {

  # The label owns the whole panel, so strip all four margins
  zero_margins <- rep(0, 4)
  par(mar = zero_margins)

  # Blank canvas: no points, no box, no axes
  plot(0, 0, type = "n", bty = "n", axes = FALSE)

  # Bold (font = 2) label placed at the origin
  label_size <- 2 * label_cex_ratio
  text(x = 0, y = 0, labels = label, cex = label_size, font = 2)

}



#' Creates a schematic displaying gene locations in a given chromosomal region.
#'
#' Each gene is designated as a gray box, with genes on the Watson (+) strand
#' displayed in the top row and genes on the Crick (-) strand displayed on
#' the bottom row.  Orientation of the gene name also denotes transcription
#' direction.  Currently this function will only make a gene schematic for
#' yeast genes in the sacCer2/SGD R61 genome version.
#'
#' Gene annotations are read from a data frame called \code{yeast_gene.df}
#' that must be visible from this function when it is called.  The columns
#' used are \code{chr}, \code{start}, \code{end}, \code{strand}, \code{name}
#' and \code{sgd_name}.  The data frame is only filtered locally; the caller's
#' copy is never modified.
#'
#' @param feature_chr Character, chromosome (e.g. "1")
#' @param feature_start Numeric, genomic start position
#' @param feature_end Numeric, genomic end position
#' @param cex_title Numeric, specifies the cex expansion factor for the title
#' @param bg_type Character, whether to make the background "white" or "transparent"
#'  (Default: "white").
#' @param fwd_gene_col,rev_gene_col Character, specifies the color for both
#'  the Watson (+) and Crick (-) genes (Default: "gray" for both).
#' @param proteinCoding Boolean, should only protein-coding genes be included as
#'  opposed to including all open-reading-frames (Default: TRUE).  Implemented
#'  by keeping only the rows whose \code{name} differs from \code{sgd_name}.
#' @param geneName Boolean, should the gene name be displayed (Default: TRUE)
#' @param omit_genes Character vector, gene names (matched against
#'  \code{sgd_name}) to exclude from the plot.  This is useful if there are
#'  overlapping genes.  Names that match no gene are ignored. (Default: NA)
#'
#' @return Called for its side effect of drawing a schematic showing gene
#'  locations in a given chromosomal location; returns \code{NULL} invisibly.
#'
#' @examples
#'  MakeGeneSchematic("1", 40000, 41000)
MakeGeneSchematic = function(feature_chr, feature_start, feature_end,
                             cex_title = 1, bg_type = "white",
                             fwd_gene_col = "gray", rev_gene_col = "gray",
                             proteinCoding = TRUE, geneName = TRUE, omit_genes = NA
){
  
  # Set up an empty panel spanning exactly the requested region
  plot(0, 0, type = "n", bty = "n", bg = bg_type,
       xlim = c(feature_start, feature_end), xaxs = "i", xaxt = "n",
       ylim = c(0, 1), yaxs = "i", yaxt = "n",
       ann = FALSE
  )
  
  # Subset only on protein coding if selected
  if(proteinCoding){
    
    has_distinct_name = as.character(yeast_gene.df$name) != as.character(yeast_gene.df$sgd_name)
    
    # which() drops NA comparisons instead of producing all-NA rows
    yeast_gene.df = yeast_gene.df[which(has_distinct_name),]
    
  }
  
  # Omit any gene if necessary
  if(any(!is.na(omit_genes))){
    
    # A logical mask is used on purpose: negative indexing with -which(...)
    # removes EVERY row when none of the names match
    yeast_gene.df = yeast_gene.df[!(yeast_gene.df$sgd_name %in% omit_genes),]
    
  }
  
  # Convert to a GenomicRanges object
  gene.gr = GenomicRanges::GRanges(seqnames = yeast_gene.df$chr,
                                   ranges = IRanges::IRanges(start = yeast_gene.df$start, end = yeast_gene.df$end),
                                   strand = yeast_gene.df$strand
  )
  names(gene.gr) = yeast_gene.df$name
  
  # Create the feature gr
  feature.gr = GenomicRanges::GRanges(seqnames = feature_chr,
                                      ranges = IRanges::IRanges(start = feature_start, end = feature_end)
  )
  
  # Find the overlaps; subject hits are row indices into the filtered yeast_gene.df
  overlaps.hits = GenomicRanges::findOverlaps(feature.gr, gene.gr)
  subject_hits.v = S4Vectors::subjectHits(overlaps.hits)
  
  # Enter in the genes (the loop body is skipped when nothing overlaps)
  for(hit_idx in subject_hits.v){
    PlotGene(gene.v = yeast_gene.df[hit_idx,], y_low = 0, y_high = 1,
             x_start = feature_start, x_end = feature_end,
             cex_title = cex_title, geneName = geneName, x_pos_title = 50,
             fwd_gene_color = fwd_gene_col, rev_gene_color = rev_gene_col)
  }
  
  invisible(NULL)
  
}

#' Helper function to plot individual genes as rectangles.
#'
#' Helper function for \code{\link{MakeGeneSchematic}} to plot individual
#' genes.  Each gene is designated as a gray box, with genes on the Watson (+)
#' strand displayed in the top row and genes on the Crick (-) strand displayed
#' on the bottom row.  Orientation of the gene name also denotes transcription
#' direction.
#'
#' @param gene.v Vector, selected row from \code{yeast_gene.df} feature file
#'  containing information about the gene.  The columns used are
#'  \code{start}, \code{end}, \code{strand} and \code{sgd_name}.
#' @param y_low Numeric, the bottom \emph{y} rectangular coordinate.
#' @param y_high Numeric, the top \emph{y} rectangular coordinate.
#' @param x_start Numeric, the left \emph{x} rectangular coordinate of the plot.
#' @param x_end Numeric, the right \emph{x} rectangular coordinate of the plot.
#' @param x_pos_title Numeric, the number of inset bp specifying the gene name.
#' @param fwd_gene_color,rev_gene_color Character, fill color for Watson (+)
#'  and Crick (-) genes respectively.
#' @inheritParams MakeGeneSchematic
#'
#' @return Rectangle schematic depicting gene coordinates, drawn on the
#'  currently active plot.
#'
#' @examples
#'  PlotGene(gene.v = yeast_gene.df[1,], y_low = 0.5, y_high = 1,
#'           x_start = 40000, x_end = 41000, cex_title = 1,
#'           geneName = T, x_pos_title = 50,
#'           fwd_gene_color = "gray", rev_gene_color = "gray"
#'          )
PlotGene = function(gene.v, y_low, y_high, x_start, x_end,
                    cex_title, geneName, x_pos_title = 50,
                    fwd_gene_color, rev_gene_color
){
  
  # Get y_mid: Watson genes go above it, Crick genes below it
  y_mid = (y_high + y_low) / 2
  
  # Add in the text
  if(gene.v$strand == "+"){
    
    # Make the rectangle
    rect(gene.v$start, y_mid + 0.1, gene.v$end, y_high - 0.1, col = fwd_gene_color)
    
    if(geneName){
      if(gene.v$start >= x_start){
        # Gene start is visible: anchor the name just inside the start
        text(x = gene.v$start + x_pos_title, y = y_high - 0.15, adj = c(0, 1),
             labels = gene.v$sgd_name, font = 3, cex = cex_title)
      }else{
        # Gene start is off the left edge: anchor the name at the gene end instead
        text(x = gene.v$end - x_pos_title, y = y_high - 0.15, adj = c(1, 1),
             labels = gene.v$sgd_name, font = 3, cex = cex_title)
      }
    }
  }else{
    
    # Make the rectangle
    rect(gene.v$start, y_low + 0.1, gene.v$end, y_mid - 0.1, col = rev_gene_color)
    
    if(geneName){
      # Must compare against the x_end parameter (this previously read an
      # undefined variable, which made every reverse-strand label fail)
      if(gene.v$end <= x_end){
        # srt = 180 draws the name upside down to signal reverse transcription
        text(x = gene.v$end - x_pos_title, y = y_low + 0.15, adj = c(0, 1),
             labels = gene.v$sgd_name, srt = 180, font = 3, cex = cex_title)
      }else{
        text(x = gene.v$start + x_pos_title, y = y_low + 0.15, adj = c(1, 1),
             labels = gene.v$sgd_name, srt = 180, font = 3, cex = cex_title)
      }
    }
  }
  
}




# Open an empty, axis-free plotting panel for a chromatin schematic.
#
#   x_start, x_end:  horizontal extent of the panel
#   y_start, y_end:  vertical extent of the panel
#
# Both axes use the "i" style, so the limits are used exactly as given.
SetChromatinSchematic <- function(x_start = 0, x_end = 1, y_start = 0, y_end = 1) {

  horizontal_range <- c(x_start, x_end)
  vertical_range <- c(y_start, y_end)

  # No points, box, axes or annotation: later calls draw onto this canvas
  plot(0, 0, type = "n", bty = "n", ann = FALSE,
       xlim = horizontal_range, ylim = vertical_range,
       xaxs = "i", yaxs = "i",
       xaxt = "n", yaxt = "n"
  )

}

# Draw one filled ellipse per nucleosome on the currently active plot.
#
#   nuc.df:   data frame with a `pos` column (ellipse centre on the x axis)
#             and a `sig` column (signal used to shade the ellipse)
#   y_max:    signal value that maps to the strongest shade
#   y0:       y coordinate of the ellipse centre
#   yh:       vertical half-height of the ellipse
#   nuc_col:  base colour as "#RRGGBB"; two shade digits are appended to it
#
# The shade is round(100 * sig / y_max) limited to 0..99 and written as two
# decimal digits after nuc_col.  R reads those two characters as a hexadecimal
# alpha byte, so the strongest shade "99" is alpha 0x99 rather than fully opaque.
#
# Returns NULL invisibly.  An empty nuc.df draws nothing.
PlotNucleosome = function(nuc.df, y_max = 2, y0 = 0.5, yh = 0.2, nuc_col = "#FF0000"){
  
  # Set up the angle vector
  theta = seq(0, 2 * pi, length.out = 1000)
  
  # The y outline is identical for every nucleosome, so compute it once
  y = y0 + yh * sin(theta)
  
  # seq_len() keeps the loop from running on a zero-row data frame
  # (1:nrow() would iterate over c(1, 0) and fail on the missing values)
  for(i in seq_len(nrow(nuc.df))){
    
    # Get the position
    x0 = nuc.df$pos[i]
    
    # Find the coordinates for the nucleosome at each theta position
    # (horizontal half-width of 75 x-axis units)
    x = x0 + 75 * cos(theta)
    
    # Find the signal color shading, limited to the two-digit range 0..99;
    # a negative value would otherwise produce an invalid colour string
    sig_shade = round(100 * nuc.df$sig[i] / y_max)
    sig_shade = min(max(sig_shade, 0), 99)
    
    sig_shade_str = formatC(sig_shade, flag = "0#", format = "d", width = 2)
    
    # Plot the nucleosome
    polygon(x, y, col = paste(nuc_col, sig_shade_str, sep = ""))
  }
  
  invisible(NULL)
  
}



# RColorBrewer supplies brewer.pal(), used for the palette below
library(RColorBrewer)

# Five-colour "Pastel1" palette.
# NOTE(review): `cl` is reassigned to a viridis palette further down this file,
# so this value only applies to code that runs before that point.
cl <- brewer.pal(5, "Pastel1")
# NOTE(review): `num` is not referenced by any function visible in this file;
# confirm whether a calling script still depends on it.
num=1

# Compute a smoothed, depth-normalised 1-D density of MNase fragment midpoints.
#
#   bam_file_name:   BAM file, handed to ConvertPairedReadBAMToGR() and GetBamStats()
#   chr, start, end: region over which the density is evaluated
#   fragL, fragH:    inclusive lower / upper bounds on fragment width
#   bw:              kernel bandwidth passed to stats::density()
#
# Returns a numeric vector with one value per density evaluation point
# (end - start + 1 points between start and end), named by position.
GetMNaseFeatureDensity <- function(bam_file_name, chr, start, end, fragL, fragH, bw) {

  # Load the fragments, requesting an extra 250 bp upstream of the window
  all_frags.gr <- ConvertPairedReadBAMToGR(bam_file_name, chr,
                                           start_pos = start - 250, end_pos = end)

  # Keep only the fragments whose width lies in [fragL, fragH]
  frag_width.v <- BiocGenerics::width(all_frags.gr)
  in_size_range <- frag_width.v >= fragL & frag_width.v <= fragH
  sized_frags.gr <- all_frags.gr[in_size_range]

  # Each fragment is represented by its midpoint
  midpoint.v <- IRanges::mid(IRanges::ranges(sized_frags.gr))

  # Kernel density evaluated at one point per bp of the window
  n_points <- end - start + 1
  midpoint.dens <- suppressWarnings(
    stats::density(midpoint.v, bw = bw, from = start, to = end, n = n_points)
  )

  # Rescale so the values sum to 1000 (so this is not a true density)
  scaled_dens.v <- midpoint.dens$y * 1000 / sum(midpoint.dens$y)

  # Genome-wide totals taken from the BAM index statistics
  bam_stats.df <- GetBamStats(bam_file_name)
  genome_reads <- bam_stats.df[which(bam_stats.df$feature == "read_depth"), "value"]
  genome_bp <- bam_stats.df[which(bam_stats.df$feature == "genome_size"), "value"]

  # Reads expected in this window if reads were spread evenly over the genome
  expected_reads <- genome_reads * (end - start + 1) / genome_bp

  # Reads actually loaded for the window.  This deliberately counts ALL loaded
  # fragments rather than only the fragL..fragH subset, because the genome-wide
  # total from the index statistics is not size-selected either.
  observed_reads <- length(all_frags.gr)

  # Weight the scaled density by the observed / expected read ratio
  weighted_dens.v <- scaled_dens.v * observed_reads / expected_reads
  names(weighted_dens.v) <- midpoint.dens$x

  return(weighted_dens.v)

}

# Locate local maxima in a position-named signal vector.
#
#   cov.v:                numeric vector whose names are genomic positions
#   peak_width:           window span handed to splus2R::peaks()
#   isPeakMax:            forwarded as the `strict` argument of splus2R::peaks()
#   min_peak_sig_thresh:  peaks whose signal is not above this value are dropped
#
# Returns a data frame with columns `pos` (peak position) and `sig` (signal at
# that position).
GetDensityPeaks <- function(cov.v, peak_width = 75, isPeakMax = TRUE, min_peak_sig_thresh = 0) {

  # Logical flag per element of cov.v marking the local maxima
  is_peak.v <- splus2R::peaks(x = cov.v, span = peak_width, strict = isPeakMax)

  # Positions are carried in the names of cov.v
  peak_pos.v <- as.numeric(names(cov.v)[is_peak.v])

  # Pair every peak position with its signal, looked up by name
  candidate.df <- data.frame(pos = peak_pos.v,
                             sig = cov.v[as.character(peak_pos.v)]
  )

  # Keep only the peaks strictly above the threshold
  keep_rows.v <- which(candidate.df$sig > min_peak_sig_thresh)

  return(candidate.df[keep_rows.v, ])

}


# Read paired-end fragments from an indexed BAM file into a GRanges object.
#
#   bam_file_name:  path to the BAM file; its index is expected at
#                   "<bam_file_name>.bai"
#   chr:            chromosome to read from
#   start_pos:      first position of interest (default 1); the query is
#                   extended 250 bp upstream of it, but never below position 1
#   end_pos:        last position of interest; the default -1 means "up to the
#                   end of the chromosome", taken from the BAM header
#
# Only reads that are not on the minus strand are scanned.  Each one becomes a
# range starting at its `pos` with width `isize`.
#
# Returns an unstranded GRanges of fragments, or an empty GRanges when no read
# matches.
ConvertPairedReadBAMToGR <- function(bam_file_name, chr, start_pos = 1, end_pos = -1) {

  # Open the BAM file together with its index
  bai_file_name <- paste0(bam_file_name, ".bai")
  bam.bf <- Rsamtools::BamFile(bam_file_name, index = bai_file_name)

  # Named vector of chromosome lengths from the header
  chr_size.v <- Rsamtools::scanBamHeader(bam.bf)$targets

  # -1 is the sentinel for "read to the end of the chromosome"
  if (end_pos == -1) {
    end_pos <- chr_size.v[chr]
  }

  # Region to scan, padded 250 bp upstream and clipped at position 1
  query_start <- max(start_pos - 250, 1)
  query.gr <- GenomicRanges::GRanges(
    seqnames = chr,
    ranges = IRanges::IRanges(start = query_start, end = end_pos),
    strand = "*"
  )

  # Fetch position and insert size for the non-minus-strand reads in the region
  scan_param <- Rsamtools::ScanBamParam(
    what = c("pos", "isize"),
    which = query.gr,
    flag = Rsamtools::scanBamFlag(isMinusStrand = FALSE)
  )
  region_reads.l <- Rsamtools::scanBam(bam.bf, param = scan_param)[[1]]

  frag_start.v <- region_reads.l[["pos"]]

  # Nothing found: hand back an empty GRanges
  if (length(frag_start.v) == 0) {
    return(GenomicRanges::GRanges())
  }

  # One range per read, spanning `isize` bases from `pos`
  frags.gr <- GenomicRanges::GRanges(
    seqnames = factor(chr, levels = names(chr_size.v)),
    ranges = IRanges::IRanges(start = frag_start.v,
                              width = region_reads.l[["isize"]]),
    strand = "*"
  )
  GenomeInfoDb::seqlengths(frags.gr) <- chr_size.v

  return(frags.gr)

}


# Summarise genome size and total read count for an indexed BAM file.
#
# Runs `samtools idxstats` on the file (samtools must be on the PATH) and sums
# the per-chromosome lengths and read counts that it reports.
#
#   bam_file_name:  path to the BAM file
#
# Returns a two-row data frame with columns `feature` ("genome_size",
# "read_depth") and `value`.  Stops with an error if samtools produces no output.
GetBamStats = function(bam_file_name){
  
  # Quote the path so spaces or shell metacharacters in the file name cannot
  # break or alter the command; path.expand() keeps "~" working, since a
  # quoted tilde would no longer be expanded by the shell
  idxstats_cmd = paste("samtools idxstats", shQuote(path.expand(bam_file_name)))
  
  # Get the summary of the .bai file
  bai_file.v = system(command = idxstats_cmd, intern = TRUE)
  
  if(length(bai_file.v) == 0){
    stop("samtools idxstats returned no output for ", bam_file_name, call. = FALSE)
  }
  
  # Convert to list: one character vector of tab-separated fields per line
  bai_file.l = strsplit(bai_file.v, split = "\t")
  
  # Convert to a dataframe (four fields per line)
  bai_file.df = as.data.frame(matrix(unlist(bai_file.l), ncol = 4, byrow = TRUE))
  colnames(bai_file.df) = c("chr", "length", "read_num", "unaligned")
  
  # Convert the read number and length into a numeric
  # (as.character first, in case the columns were created as factors)
  bai_file.df$read_num = as.numeric(as.character(bai_file.df$read_num))
  bai_file.df$length = as.numeric(as.character(bai_file.df$length))
  
  # Get the total genome size and total read depth
  genome_size = sum(bai_file.df$length)
  read_depth = sum(bai_file.df$read_num)
  
  # Create an output dataframe
  stats.df = data.frame("feature" = c("genome_size", "read_depth"), 
                        "value" = c(genome_size, read_depth)
  ) 
  
  # Return the stats.df
  return(stats.df)
  
}

# Draw a gene schematic for a genomic region from a gene table on disk.
#
#   feature_chr, feature_start, feature_end:  region to display
#   y_low, y_high:    vertical extent handed to plot_gene()
#   cex_title:        text size of the gene labels
#   bg_type:          forwarded to plot() as `bg`
#   geneName:         whether plot_gene() should draw the gene labels
#   x_pos_title:      horizontal inset of the labels
#   gene_file_name:   headerless table; columns 1, 2, 3, 6 and 7 are used as
#                     chromosome, start, end, strand and label
#
# `proteinCoding` and `omit_genes` are accepted but not used by this function.
make_gene_schematic <- function(feature_chr, feature_start, feature_end, 
                               y_low = 0, y_high = 1, cex_title = 1, bg_type = "white",
                               proteinCoding = T, geneName = T, omit_genes = NA, x_pos_title = 50,
                               gene_file_name = "/data/home/mpg22/feature_files/sacCer3_ucsc_sgd_geneTable.bed"
                               
                               #gene_file_name = "/data/genome_feature_files/yeast/genes/sacCer3_ucsc_geneTable.bed"
){

  # Empty panel spanning exactly the requested region
  region_range <- c(feature_start, feature_end)
  plot(0, 0, type = "n", bty = "n", bg = bg_type, ann = FALSE,
       xlim = region_range, ylim = c(0, 1),
       xaxs = "i", yaxs = "i",
       xaxt = "n", yaxt = "n"
  )

  # Read the gene table
  gene.df <- read.table(gene_file_name, header = FALSE)

  # Genes as genomic ranges, named by the label column
  gene_span <- IRanges(start = gene.df[, 2], end = gene.df[, 3])
  gene.gr <- GRanges(seqnames = gene.df[, 1],
                     ranges = gene_span,
                     strand = gene.df[, 6]
  )
  names(gene.gr) <- gene.df[, 7]

  # The displayed region as a single range
  feature.gr <- GRanges(seqnames = feature_chr,
                        ranges = IRanges(start = feature_start, end = feature_end)
  )

  # Rows of gene.df that overlap the region
  overlaps.df <- as.data.frame(as.matrix(findOverlaps(feature.gr, gene.gr)))

  # Draw each overlapping gene; nothing is drawn when there are no overlaps
  for (gene_row in overlaps.df$subjectHits) {
    plot_gene(gene.df[gene_row, ], y_low, y_high,
              feature_start, feature_end, cex_title, geneName, x_pos_title)
  }

}

# Draw a single gene as a shaded rectangle, optionally with its label.
#
#   gene.v:        one-row data frame; columns 2, 3, 6 and 7 are read as
#                  start, end, strand and label
#   ylow, yhigh:   vertical extent; "+" genes are drawn in the upper half,
#                  all other genes in the lower half
#   xstart, xend:  visible x range, used to decide which gene end anchors the label
#   cex_title:     text size of the label
#   geneName:      whether to draw the label
#   x_pos_title:   horizontal inset of the label from the anchoring gene end
plot_gene <- function(gene.v, ylow, yhigh, xstart, xend, cex_title, geneName, x_pos_title = 50){

  gene_left <- gene.v[, 2]
  gene_right <- gene.v[, 3]
  ymid <- (yhigh + ylow) / 2
  on_plus_strand <- gene.v[, 6] == "+"

  # Box first: light grey in the upper half for "+", darker grey below otherwise
  if (on_plus_strand) {
    rect(gene_left, ymid + 0.1, gene_right, yhigh - 0.1, col = "grey90")
  } else {
    rect(gene_left, ylow + 0.1, gene_right, ymid - 0.1, col = "grey70")
  }

  # Nothing more to do when labels are switched off
  if (!geneName) {
    return(invisible(NULL))
  }

  # "+" genes anchor at the left end when the gene start is visible;
  # other genes anchor at the left end when the gene end is visible
  if (on_plus_strand) {
    label_y <- yhigh - 0.15
    anchor_left <- gene_left >= xstart
  } else {
    label_y <- ylow + 0.35
    anchor_left <- gene_right <= xend
  }

  if (anchor_left) {
    # Left-justified, just inside the left end of the gene
    text(x = gene_left + x_pos_title, y = label_y, adj = c(0, 1),
         labels = gene.v[, 7], font = 3, cex = cex_title)
  } else {
    # Right-justified, just inside the right end of the gene
    text(x = gene_right - x_pos_title, y = label_y, adj = c(1, 1),
         labels = gene.v[, 7], font = 3, cex = cex_title)
  }

}

######################################################################################################################################################
# dens_dot_plot: Makes a heatmap of data within a matrix
#
# PARAMETERS:
#	dot.m:		The length-midpoint matrix outputed by make_dens_matrix_GR
#			(rows are drawn along the y axis, columns along the x axis)
#	z_min:		The minimum intensity that should be plotted
#	z_max:		The maximum intensity that should be plotted
#	lowCol:		The color specifying the lowest intensity points
#	medCol:		Optional midpoint color; "" (the default) gives a two-color gradient
#	highCol:	The color specifying the highest intensity points
#	numColors:	The total number of colors for the plot
#	plot_title:	Title for the plot
#	x_label:	An x-label for the plot
#	y_label:	A y-label for the plot
#	use_row_names:	Whether to number rows as 1:nrow(dot.m) or use the designated rownames(dot.m)
#	plot_title_line: The line number location for the plot title
#	plot_box:	A Boolean on whether to add a box around the entire plot
#	x_axt, y_axt:	Whether to plot ("s") or suppress ("n") the x and y axes
# X COORDINATES:
#	Numeric column names of dot.m are used as the x coordinates.  Otherwise, if
#	numeric variables `new_start` and `new_end` are visible from this function
#	(legacy behaviour: they were expected to be set by the calling script),
#	new_start:new_end is used.  If they are not defined, the columns are
#	centred on zero with a spacing of 1.
# RETURNS:
#	A heatmap plot utilizing the image function (NULL is returned invisibly)
# NOTE:
#	colorpanel() must be available on the search path when this is called.
dens_dot_plot = function(dot.m, z_min = 0, z_max = 100, 
                         lowCol = "white", medCol = "", highCol = "blue", numColors = 100,
                         plot_title = "", x_label = "", y_label = "", 
                         use_row_names = FALSE,
                         plot_title_line = NA, plot_box = TRUE, x_axt = "s", y_axt = "s"){
  
  # For points that are either above or below z_max or z_min respectively, set them to
  # the z_min and z_max (otherwise, plot shows arbitrary colors)
  dot.m[which(dot.m >= z_max)] = z_max
  dot.m[which(dot.m <= z_min)] = z_min
  
  # Use the column names as x coordinates when they are numeric.
  # as.is = TRUE avoids the "'as.is' should be specified" warning and does not
  # change whether the result is numeric.
  xValues = NULL
  col_names.v = colnames(dot.m)
  if(!is.null(col_names.v) && is.numeric(type.convert(col_names.v, as.is = TRUE))){
    xValues = as.numeric(col_names.v)
  }
  
  # Otherwise the x coordinates have to be made up
  if(is.null(xValues)){
    
    if(exists("new_start", mode = "numeric") && exists("new_end", mode = "numeric")){
      
      # Legacy path: coordinates supplied through variables of the calling script
      xValues = new_start:new_end
      
    }else{
      
      # Assume a span of 1 and centre the columns on zero
      x_win = (ncol(dot.m) - 1) / 2
      xValues = -x_win:x_win
      
    }
    
  }
  
  # Row boundaries 0.5, 1.5, ..., nrow + 0.5 unless row names are requested
  if(!use_row_names || is.null(rownames(dot.m))){
    yValues = 0.5:(nrow(dot.m) + 0.5)
  }else{
    yValues = as.numeric(rownames(dot.m))
  }
  
  # Make the colorpanel (three-color gradient only when medCol is given)
  if(nchar(medCol) > 0){
    make_colorpanel = colorpanel(numColors, lowCol, medCol, highCol)
  }else{
    make_colorpanel = colorpanel(numColors, lowCol, highCol)
  }	
  
  # Make the heatmap utilizing the parameters specified above
  # (transposed so that the matrix columns run along the x axis)
  image(xValues, yValues, t(dot.m), col = make_colorpanel, zlim = c(z_min, z_max),
        xlab = x_label, ylab = y_label, xaxt = x_axt, yaxt = y_axt, bty = "n"
  ) 
  
  # Set the title
  title(main = plot_title, line = plot_title_line)
  
  # Add a box around the plot
  if(plot_box){ 
    box(which = "plot", lty = "solid")
  }
  
  invisible(NULL)
  
}

# viridis supplies the viridis() palette generator used below
library(viridis)
# 100-step palette (option "C"); this replaces any value `cl` held earlier in the script
cl <- viridis(100, alpha = 1, begin = 0, end = 1, option = "C") #C
# Default colours picked up by dens_dot_plotB: low end (ln), midpoint (int), high end (hm)
ln=  cl[5]
int = "white" #cl[5]
hm=  cl[50]


# dens_dot_plotB: heatmap of a matrix, like dens_dot_plot, but with the
# script-level colours `ln`, `int` and `hm` as the default gradient.
#
# PARAMETERS:
#	dot.m:		Matrix to plot (rows along the y axis, columns along the x axis)
#	z_min, z_max:	Intensity range; values outside it are clamped to the limits
#	lowCol:		Color for the lowest intensity points (default: ln)
#	medCol:		Midpoint color (default: int); "" gives a two-color gradient
#	highCol:	Color for the highest intensity points (default: hm)
#	numColors:	The total number of colors for the plot
#	plot_title:	Title for the plot
#	x_label, y_label: Axis labels
#	use_row_names:	Whether to use rownames(dot.m) as y coordinates instead of
#			row boundaries 0.5, 1.5, ..., nrow(dot.m) + 0.5
#	plot_title_line: The line number location for the plot title
#	plot_box:	A Boolean on whether to add a box around the entire plot
#	x_axt, y_axt:	Whether to plot ("s") or suppress ("n") the x and y axes
# X COORDINATES:
#	Numeric column names of dot.m are used as the x coordinates; otherwise the
#	columns are centred on zero with a spacing of 1.
# RETURNS:
#	A heatmap plot utilizing the image function (NULL is returned invisibly)
# NOTE:
#	colorpanel() must be available on the search path when this is called.
dens_dot_plotB = function(dot.m, z_min = 0, z_max = 100, 
                          #lowCol = cl[m], medCol = cl[int], highCol = cl[n], 
                          lowCol = ln, medCol = int, highCol = hm, 
                          numColors = 100,
                          plot_title = "", x_label = "", y_label = "", 
                          use_row_names = FALSE,
                          plot_title_line = NA, plot_box = TRUE, x_axt = "s", y_axt = "s"){
  
  # For points that are either above or below z_max or z_min respectively, set them to
  # the z_min and z_max (otherwise, plot shows arbitrary colors)
  dot.m[which(dot.m >= z_max)] = z_max
  dot.m[which(dot.m <= z_min)] = z_min
  
  # Use the column names as x coordinates when they are numeric.
  # as.is = TRUE avoids the "'as.is' should be specified" warning and does not
  # change whether the result is numeric.
  col_names.v = colnames(dot.m)
  if(!is.null(col_names.v) && is.numeric(type.convert(col_names.v, as.is = TRUE))){
    
    xValues = as.numeric(col_names.v)
    
  }else{
    
    # Assume a span of 1 and centre the columns on zero
    x_win = (ncol(dot.m) - 1) / 2
    xValues = -x_win:x_win
    
  }
  
  # Row boundaries 0.5, 1.5, ..., nrow + 0.5 unless row names are requested
  if(!use_row_names || is.null(rownames(dot.m))){
    yValues = 0.5:(nrow(dot.m) + 0.5)
  }else{
    yValues = as.numeric(rownames(dot.m))
  }
  
  # Make the colorpanel (three-color gradient only when medCol is given)
  if(nchar(medCol) > 0){
    make_colorpanel = colorpanel(numColors, lowCol, medCol, highCol)
  }else{
    make_colorpanel = colorpanel(numColors, lowCol, highCol)
  }	
  
  # Make the heatmap utilizing the parameters specified above
  # (transposed so that the matrix columns run along the x axis)
  image(xValues, yValues, t(dot.m), col = make_colorpanel, zlim = c(z_min, z_max),
        xlab = x_label, ylab = y_label, xaxt = x_axt, yaxt = y_axt, bty = "n"
  ) 
  
  # Set the title
  title(main = plot_title, line = plot_title_line)
  
  # Add a box around the plot
  if(plot_box){ 
    box(which = "plot", lty = "solid")
  }
  
  invisible(NULL)
  
}


# Draw a replication-origin (ARS) schematic for a genomic region.
#
#   feature_chr, feature_start, feature_end:  region to display
#   y_low, y_high:   vertical extent handed to plot_ars()
#   cex_title:       text size of the ARS labels
#   bg_type:         forwarded to plot() as `bg`
#   x_pos_title:     forwarded to plot_ars()
#   ars_file_name:   CSV feature file; columns 2, 4, 12, 13 and 14 are used as
#                    chromosome, strand, label, start and end
#
# Returns NULL invisibly.
make_ars_schematic = function(feature_chr, feature_start, feature_end, 
                              y_low = 0, y_high = 1, cex_title = 1, bg_type = "white",
                              x_pos_title = 50,
                              ars_file_name = "/data/genome_feature_files/yeast/replication_origins/oriDB/oridb_acs_feature_file_jab-curated-798-sites_sacCer3.csv"
){
  
  # Set up the plot
  plot(0, 0, type = "n", bty = "n", bg = bg_type,
       xlim = c(feature_start, feature_end), xaxs = "i", xaxt = "n",
       ylim = c(0, 1), yaxs = "i", yaxt = "n",
       ann = FALSE
  )
  
  # Load the ars dataframe
  ars.df = read.csv(ars_file_name, stringsAsFactors = FALSE)
  
  # Convert to a GenomicRanges object
  ars.gr = GRanges(seqnames = ars.df[,2],
                   ranges = IRanges(start = ars.df[,13], end = ars.df[,14]),
                   strand = ars.df[,4]
  )
  names(ars.gr) = ars.df[,12]
  
  # Create the feature gr
  feature.gr = GRanges(seqnames = feature_chr,	
                       ranges = IRanges(start = feature_start, end = feature_end)
  )
  
  # Find the overlaps
  overlaps.df = as.data.frame(as.matrix(findOverlaps(feature.gr, ars.gr)))
  
  # Enter in the arss (the loop body is skipped when nothing overlaps)
  for(ars_row in overlaps.df$subjectHits){
    # arsName is given an explicit value: the variable passed here before was
    # never defined, which only worked because plot_ars() does not read it
    plot_ars(ars.df[ars_row,], y_low, y_high, 
             feature_start, feature_end, cex_title, arsName = TRUE, x_pos_title)
  }
  
  invisible(NULL)
  
}

# Draw a single ARS as a filled box in the upper half of the band, with its label.
#
#   ars.v:        one-row data frame; columns 12, 13 and 14 are read as
#                 label, start and end
#   ylow, yhigh:  vertical extent; the box runs from the midpoint up to yhigh
#   cex_title:    text size of the label
#
# xstart, xend, arsName and x_pos_title are accepted but not used.
plot_ars <- function(ars.v, ylow, yhigh, xstart, xend, cex_title, arsName, x_pos_title = 50){

  ars_left <- ars.v[, 13]
  ars_right <- ars.v[, 14]

  # The box occupies the upper half of the band
  box_bottom <- (yhigh + ylow) / 2
  rect(ars_left, box_bottom, ars_right, yhigh, col = "darkseagreen2")

  # Italic label anchored at the left edge of the box, nudged slightly right
  text(x = ars_left, y = ylow + 0.465, labels = ars.v[, 12],
       adj = c(-0.1, 1), font = 3, cex = cex_title)

}






