
#' TSS Heatmap Count Matrix
#'
#' Generate count matrix to make TSS heatmap
#'
#' @include tsrexplorer.R
#' @include annotate.R
#'
#' @import tibble
#' @importFrom dplyr bind_rows rename ntile filter between select mutate group_by summarize pull ungroup left_join desc all_of
#' @importFrom magrittr %>%
#' @importFrom tidyr complete
#' @importFrom forcats fct_reorder
#'
#' @param experiment tsrexplorer object with annotated TSSs
#' @param samples Either 'all' or the names of one or more samples to analyze
#' @param upstream Bases upstream to consider
#' @param downstream Bases downstream to consider
#' @param threshold Minimum score (reads) required per TSS
#' @param anno_type Whether the heatmap is built on transcripts or genes
#'   ("transcriptId", the default, or "geneId")
#' @param quantiles Number of quantiles to break data down into
#' @param use_cpm Whether to use the CPM normalized values or not
#'
#' @return tibble with one row per sample, feature and position in
#'   -upstream:downstream. Columns: sample, feature (factor ordered by
#'   decreasing total score summed across samples), position, score (0 where
#'   no TSS passed the filters), log2_score (log2(score + 1)) and ntile
#'   (per-sample quantile of the feature's total score).
#'
#' @rdname tss_heatmap_matrix-function
#'
#' @export

tss_heatmap_matrix <- function(
	experiment,
	samples = "all",
	upstream = 1000,
	downstream = 1000,
	threshold = 1,
	anno_type = c("transcriptId", "geneId"),
	quantiles = 1,
	use_cpm = FALSE
) {
	## Resolve the annotation column. Without this the default value is a
	## length-2 vector, which rename() cannot map onto a single column.
	anno_type <- match.arg(anno_type)

	## Pick the CPM or raw TSS tables.
	tss_tables <- if (use_cpm) {
		experiment@annotated$TSSs$cpm
	} else {
		experiment@annotated$TSSs$raw
	}

	## identical() keeps this a scalar test when several sample names are
	## given (`samples == "all"` would be a length > 1 `if` condition).
	if (identical(samples, "all")) samples <- names(tss_tables)

	unknown_samples <- setdiff(samples, names(tss_tables))
	if (length(unknown_samples) > 0) {
		stop(
			"Sample(s) not found in the annotated TSSs: ",
			paste(unknown_samples, collapse = ", "),
			call. = FALSE
		)
	}
	sample_data <- tss_tables[samples]

	## Stack samples and keep TSSs that pass the score threshold and fall
	## inside the requested window around the feature start.
	annotated_tss <- sample_data %>%
		bind_rows(.id = "sample") %>%
		rename(feature = all_of(anno_type)) %>%
		filter(
			score >= threshold,
			between(distanceToTSS, -upstream, downstream)
		) %>%
		select(sample, feature, distanceToTSS, score) %>%
		mutate(feature = factor(feature))

	## Order features by total score summed across all samples (highest first).
	feature_order <- annotated_tss %>%
		group_by(feature) %>%
		summarize(total_sum = sum(score)) %>%
		mutate(feature = fct_reorder(feature, desc(total_sum))) %>%
		pull(feature) %>%
		levels

	## Assign each feature to a quantile of its total score within each sample;
	## features absent from a sample count as 0.
	feature_quantiles <- annotated_tss %>%
		select(-distanceToTSS) %>%
		group_by(sample, feature) %>%
		summarize(total_sum = sum(score)) %>%
		ungroup %>%
		complete(
			sample,
			feature = feature_order,
			fill = list(total_sum = 0)
		) %>%
		group_by(sample) %>%
		mutate(ntile = ntile(total_sum, quantiles)) %>%
		ungroup %>%
		select(-total_sum)

	## Expand to every sample x feature x position so the heatmap has no gaps.
	tss_matrix <- annotated_tss %>%
		complete(
			sample,
			feature = feature_order,
			distanceToTSS = -upstream:downstream,
			fill = list(score = 0)
		) %>%
		mutate(log2_score = log2(score + 1)) %>%
		rename(position = distanceToTSS) %>%
		left_join(y = feature_quantiles, by = c("sample", "feature")) %>%
		mutate(feature = factor(feature, levels = feature_order))

	return(tss_matrix)
}

#' Plot Heatmap
#'
#' Plot heatmap from count matrix generated by tss_heatmap_matrix or tsr_heatmap_matrix
#'
#' @import tibble
#' @import ggplot2
#' @importFrom dplyr mutate case_when pull
#' @importFrom forcats fct_rev
#' @importFrom magrittr %>%
#'
#' @param heatmap_matrix TSS or TSR heatmap matrix from tss_heatmap_matrix or tsr_heatmap_matrix
#' @param max_value Max log2 value to truncate heatmap color (a single non-negative number)
#' @param ncol Number of columns when plotting multiple samples
#' @param ... Arguments passed to geom_tile
#'
#' @return ggplot2 object of TSS heatmap
#'
#' @rdname plot_heatmap-function
#'
#' @export

plot_heatmap <- function(heatmap_matrix, max_value = 5, ncol = 1, ...) {
	stopifnot(
		is.numeric(max_value),
		length(max_value) == 1,
		max_value >= 0
	)

	## Cap scores at max_value so the colour scale saturates (NA stays NA).
	heatmap_matrix <- heatmap_matrix %>%
		mutate(log2_score = pmin(log2_score, max_value))

	## Whole-number breaks up to max_value, plus max_value itself as the top
	## break. This keeps breaks and labels the same length and puts the ">"
	## label on the cap even when max_value is not a whole number.
	legend_breaks <- unique(c(seq(0, floor(max_value)), max_value))
	legend_labels <- c(
		as.character(legend_breaks[-length(legend_breaks)]),
		paste0(">", max_value)
	)
	legend_name <- "Log2(Score + 1)"

	## Generate heatmap; fill and colour share one scale so they merge into a
	## single legend.
	p <- ggplot(heatmap_matrix, aes(x = position, y = fct_rev(feature), fill = log2_score, color = log2_score)) +
		geom_tile(...) +
		theme_minimal() +
		theme(
			axis.text.y = element_blank(),
			panel.grid = element_blank()
		) +
		scale_fill_viridis_c(
			limits = c(0, max_value),
			breaks = legend_breaks,
			labels = legend_labels,
			name = legend_name
		) +
		scale_color_viridis_c(
			limits = c(0, max_value),
			breaks = legend_breaks,
			labels = legend_labels,
			name = legend_name
		) +
		labs(
			fill = "log2(Score + 1)",
			x = "Position",
			y = "Feature"
		)

	## Facet by quantile x sample when there is more than one quantile.
	n_quantiles <- heatmap_matrix %>%
		pull(ntile) %>%
		unique %>%
		length

	if (n_quantiles > 1) {
		p <- p + facet_grid(fct_rev(factor(ntile)) ~ sample)
	} else if (n_quantiles == 1) {
		p <- p + facet_wrap(~ sample, ncol = ncol)
	}

	return(p)
}

#' TSR Heatmap Count Matrix
#'
#' Generate count matrix to make TSR heatmap
#'
#' @include tsrexplorer.R
#' @include annotate.R
#'
#' @import tibble
#' @importFrom dplyr bind_rows rename select mutate case_when group_by summarize ungroup filter between left_join ntile pull desc all_of
#' @importFrom magrittr %>%
#' @importFrom tidyr complete
#' @importFrom purrr pmap
#' @importFrom forcats fct_reorder
#'
#' @param experiment tsrexplorer object with annotated TSRs
#' @param samples Either 'all' or the names of one or more samples to analyze
#' @param upstream Bases upstream to consider
#' @param downstream Bases downstream to consider
#' @param feature_type Whether the heatmap is built on transcripts or genes
#'   ("transcriptId", the default, or "geneId")
#' @param quantiles Number of quantiles to split data into
#' @param threshold Minimum score (reads) required per TSR
#' @param use_cpm Whether to use CPM normalized or raw counts
#'
#' @return tibble with one row per sample, feature and position in
#'   -upstream:downstream. Each TSR's score is spread over every position it
#'   covers. Columns: sample, feature (factor ordered by decreasing total TSR
#'   score summed across samples), position, score (0 where uncovered),
#'   log2_score (log2(score + 1)) and ntile (per-sample quantile of the
#'   feature's total score).
#'
#' @rdname tsr_heatmap_matrix-function
#'
#' @export

tsr_heatmap_matrix <- function(
	experiment,
	samples = "all",
	upstream = 1000,
	downstream = 1000,
	feature_type = c("transcriptId", "geneId"),
	quantiles = 1,
	threshold = 1,
	use_cpm = FALSE
) {
	## Resolve the annotation column. Without this the default value is a
	## length-2 vector, which rename() cannot map onto a single column.
	feature_type <- match.arg(feature_type)

	## Pick the CPM or raw TSR tables.
	tsr_tables <- if (use_cpm) {
		experiment@annotated$TSRs$cpm
	} else {
		experiment@annotated$TSRs$raw
	}

	## identical() keeps this a scalar test when several sample names are
	## given (`samples == "all"` would be a length > 1 `if` condition).
	if (identical(samples, "all")) samples <- names(tsr_tables)

	unknown_samples <- setdiff(samples, names(tsr_tables))
	if (length(unknown_samples) > 0) {
		stop(
			"Sample(s) not found in the annotated TSRs: ",
			paste(unknown_samples, collapse = ", "),
			call. = FALSE
		)
	}
	sample_data <- tsr_tables[samples]

	## Stack samples, drop TSRs below the score threshold, and express each
	## TSR's edges as distances from the feature start. On "-" strand the
	## start is geneEnd, so coordinates are mirrored to keep startDist <= endDist.
	annotated_tsr <- sample_data %>%
		bind_rows(.id = "sample") %>%
		rename(feature = all_of(feature_type)) %>%
		filter(score >= threshold) %>%
		select(
			sample, strand, start, end, feature,
			geneStart, geneEnd, score
		) %>%
		mutate(
			startDist = case_when(
				strand == "+" ~ start - geneStart,
				strand == "-" ~ -(end - geneEnd)
			),
			endDist = case_when(
				strand == "+" ~ end - geneStart,
				strand == "-" ~ -(start - geneEnd)
			)
		) %>%
		select(sample, startDist, endDist, score, feature)

	## Order features by total TSR score summed across samples (highest first).
	feature_order <- annotated_tsr %>%
		mutate(feature = factor(feature)) %>%
		group_by(feature) %>%
		summarize(total_sum = sum(score)) %>%
		mutate(feature = fct_reorder(feature, desc(total_sum))) %>%
		pull(feature) %>%
		levels

	## Assign each feature to a quantile of its total score within each sample;
	## features absent from a sample count as 0.
	feature_quantiles <- annotated_tsr %>%
		select(-startDist, -endDist) %>%
		group_by(sample, feature) %>%
		summarize(total_sum = sum(score)) %>%
		ungroup %>%
		complete(
			sample,
			feature = feature_order,
			fill = list(total_sum = 0)
		) %>%
		group_by(sample) %>%
		mutate(ntile = ntile(total_sum, quantiles)) %>%
		ungroup %>%
		select(-total_sum)

	## Spread each TSR's score over every position it covers, clip to the
	## window, then fill the remaining sample x feature x position cells with 0.
	annotated_tsr <- annotated_tsr %>%
		pmap(function(sample, startDist, endDist, score, feature) {
			tibble(
				sample = sample,
				position = seq(startDist, endDist, 1),
				score = score,
				feature = feature
			)
		}) %>%
		bind_rows %>%
		filter(between(position, -upstream, downstream)) %>%
		mutate(log2_score = log2(score + 1)) %>%
		complete(
			sample,
			feature = feature_order,
			position = -upstream:downstream,
			fill = list(score = 0, log2_score = 0)
		) %>%
		left_join(y = feature_quantiles, by = c("sample", "feature")) %>%
		mutate(feature = factor(feature, levels = feature_order))

	return(annotated_tsr)
}
