#!/usr/bin/Rscript

# Load required packages
library(tidyverse)
library(data.table)
library(recalibration)

#' Read a GTEx median-TPM table and keep only expressed genes
#'
#' The first two lines of the file are skipped and the third is used as the
#' header. Rows whose values in columns 3 to 56 sum to zero are dropped, and
#' the remaining rows are named by the `Name` column with everything from the
#' first dot onwards removed.
#'
#' @param tpm_path Path to the tab-separated GTEx TPM file
#' @return Data frame of the retained rows (all original columns), with the
#'   version-stripped gene IDs as row names
load_gtex_data <- function(tpm_path) {
    raw_table <- read.table(tpm_path, header = TRUE, sep = "\t", skip = 2)

    # Columns 3:56 hold the per-tissue values; keep rows with a positive total
    tissue_totals <- rowSums(raw_table[3:56])
    kept <- raw_table[tissue_totals > 0, ]

    # "ENSG....<version>" -> "ENSG...." (first dot-separated field only)
    id_parts <- strsplit(kept$Name, "\\.")
    rownames(kept) <- vapply(
        id_parts,
        function(parts) parts[1],
        character(1)
    )

    kept
}

#' Read the GTEx colors table and index it by dot-formatted tissue names
#'
#' The file is read with `tissue_id` as row names, a `Kidney_Medulla` row is
#' added, and a `dotstring` column is derived from `tissue_site_detail` by
#' replacing every hyphen, parenthesis, comma and space with a dot. The
#' `dotstring` values then become the row names of the result.
#'
#' @param color_path Path to the tab-separated GTEx colors file
#' @return Data frame of tissue information with an extra `dotstring` column,
#'   row-named by that column
process_gtex_colors <- function(color_path) {
    tissue_info <- read.table(
        color_path,
        header = TRUE,
        sep = "\t",
        row.names = "tissue_id"
    )

    # Kidney medulla row: site detail, abbreviation, and two blank fields
    kidney_medulla <- c("Kidney - Medulla", "KDNMED", "", "")
    tissue_info["Kidney_Medulla", ] <- kidney_medulla

    # Replace each of "-", "(", ")", ",", " " with "." in the site detail
    dotted_names <- vapply(
        tissue_info$tissue_site_detail,
        function(detail) gsub("[-(), ]", ".", detail),
        character(1)
    )
    tissue_info$dotstring <- dotted_names
    rownames(tissue_info) <- tissue_info$dotstring

    tissue_info
}

#' Calculate pseudocount-adjusted TPM values
#'
#' For every tissue column the smallest strictly positive TPM value is used
#' as that tissue's pseudocount and added to every entry of the column.
#'
#' @param tpm_data Input TPM data frame; `tissue_cols` selects its tissue
#'   columns.
#' @param tissue_cols Columns of `tpm_data` holding per-tissue TPM values.
#'   Defaults to `3:56`, the range used throughout this script.
#' @return List with `adjusted_tpm` (data frame of TPM plus per-tissue
#'   pseudocount, same row names as `tpm_data`) and `mean_tpm` (row means of
#'   the unadjusted tissue values).
calculate_pseudocount_tpm <- function(tpm_data, tissue_cols = 3:56) {
    # Coerce column by column so the result stays a data frame with one row
    # per gene even when there is a single gene (sapply() would collapse a
    # one-row input into a plain vector).
    tissue_values <- tpm_data[, tissue_cols, drop = FALSE]
    tissue_values[] <- lapply(tissue_values, as.numeric)

    # Smallest positive value per tissue, taken on the numeric values and
    # ignoring NA so one missing entry cannot turn the whole column into NA.
    # A column without any positive value still yields Inf (with min()'s
    # warning), as before.
    smallest_tissuetpm <- vapply(
        tissue_values,
        function(tissue_column) {
            is_positive <- !is.na(tissue_column) & tissue_column > 0
            min(tissue_column[is_positive])
        },
        numeric(1)
    )

    tpm_adj <- sweep(tissue_values, 2, smallest_tissuetpm, "+")
    rownames(tpm_adj) <- rownames(tpm_data)

    list(
        adjusted_tpm = tpm_adj,
        mean_tpm = rowMeans(tissue_values)
    )
}

#' Rescale VG estimates from tissue expression to mean expression
#'
#' Keeps the rows of `vgh_df` that also appear in `tissue_tpm` and have at
#' least one non-missing value. Then, for each tissue column present in both
#' tables, fits `log10(vg) ~ log10(tpm)` on the genes with positive VG and
#' positive TPM and replaces those genes' VG by
#' `(mean_tpm / tpm)^slope * vg`.
#'
#' @param vgh_df VG data frame (row names are genes, columns are tissues)
#' @param tissue_tpm TPM data frame indexed the same way
#' @param mean_tpm Mean TPM values, named by gene
#' @return Adjusted VG data frame
adjust_vg_estimates <- function(vgh_df, tissue_tpm, mean_tpm) {
    all_genes <- rownames(vgh_df)

    # Keep genes known to the TPM table, then drop rows with no usable value
    vgh_df_adj <- vgh_df[all_genes %in% rownames(tissue_tpm), ]
    row_means <- apply(vgh_df_adj, 1, function(values) {
        mean(values, na.rm = TRUE)
    })
    vgh_df_adj <- vgh_df_adj[!is.na(row_means), ]

    tpm_tissues <- colnames(tissue_tpm)
    shared_tissues <- tpm_tissues[is.element(tpm_tissues, colnames(vgh_df))]

    for (tissue in shared_tissues) {
        vg_all <- vgh_df[, tissue]
        tpm_all <- tissue_tpm[all_genes, tissue]

        # which() discards NA comparisons, so only genes with both a positive
        # VG and a positive TPM enter the fit
        usable <- which(vg_all > 0 & tpm_all > 0)
        genes <- all_genes[usable]
        fit_data <- data.frame(
            vg = vg_all[usable],
            tpm = tpm_all[usable],
            row.names = genes
        )

        slope <- lm(log10(vg) ~ log10(tpm), data = fit_data)$coefficients[2]
        expression_ratio <- mean_tpm[genes] / fit_data$tpm

        vgh_df_adj[genes, tissue] <- expression_ratio^slope * fit_data$vg
    }

    vgh_df_adj
}

#' Calculate a weighted VG value for one gene
#'
#' Looks up the gene's VG values for the tissues that carry a weight, drops
#' tissues where either the weight or the VG value is missing, optionally
#' keeps only the `n` tissues with the highest correlation, and returns the
#' weighted mean.
#'
#' @param gene Row name of the gene in `vg_data`
#' @param weight Tissue weights, named by tissue: a named numeric vector or a
#'   one-row data frame whose column names are the tissues
#' @param correlations Correlations named by tissue; only used to choose the
#'   top `n` tissues when more than `n` usable tissues remain
#' @param n Number of top correlations to consider
#' @param vg_data VG table with genes as row names and tissues as columns.
#'   Defaults to the object named `vgh_df_adj` visible from this function's
#'   enclosing environment, which is what the function read before this
#'   argument existed; pass it explicitly when that object is not global.
#' @return Weighted VG estimate (`NaN` when no tissue has both a weight and a
#'   VG value)
calculate_weighted_vg <- function(gene, weight, correlations, n = 5,
                                  vg_data = vgh_df_adj) {
    # Work on plain named vectors whether a vector or a one-row data frame
    # was supplied
    weight <- unlist(weight)
    weight <- weight[!is.na(weight)]

    # drop = FALSE keeps the tissue name attached even when a single tissue
    # is left; otherwise the lookup collapses to an unnamed scalar and the
    # name-based filtering below loses it
    vg_row <- vg_data[gene, names(weight), drop = FALSE]
    vgvalues <- setNames(as.numeric(unlist(vg_row)), names(weight))

    # Keep only tissues with a VG value
    vgvalues <- vgvalues[!is.na(vgvalues)]
    weight <- weight[names(vgvalues)]

    # Select top n correlations if available
    if (length(weight) > n) {
        ranked <- correlations[names(vgvalues)]
        top_tissues <- names(ranked)[head(order(ranked, decreasing = TRUE), n)]
        vgvalues <- vgvalues[top_tissues]
        weight <- weight[top_tissues]
    }

    sum(weight * vgvalues) / sum(weight)
}

# Main analysis pipeline
#
# Loads the GTEx TPM table and tissue metadata, rescales the VG estimates in
# `vgh_df` to mean expression, replaces each tissue's VG values by a
# correlation-weighted combination of the other tissues, rescales back to
# tissue expression and writes the result to 'VGHimputed.tsv'.
#
# `vgh_df` is not created anywhere in this file; it must already be defined
# (VG estimates with genes as row names and tissues as columns) when main()
# is called.
main <- function() {
    if (!exists("vgh_df")) {
        stop(
            "object 'vgh_df' (VG estimates) must be defined before calling main()",
            call. = FALSE
        )
    }

    # Load GTEx expression data. File from gtexportal.org
    tpm_data <- load_gtex_data('GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_median_tpm.gct.gz')

    # Use GTEx colors file to match TPM and VG tissue names. File from:
    # https://github.com/stephenslab/gtexresults/blob/master/data/GTExColors.txt
    gtex_infodf <- process_gtex_colors('gtex_colors.txt')
    colnames(tpm_data)[3:56] <- gtex_infodf[colnames(tpm_data)[3:56], 'tissue_abbrv']

    # Calculate adjusted TPM values (with pseudocount)
    tpm_results <- calculate_pseudocount_tpm(tpm_data)
    tissue_tpm_adj <- tpm_results$adjusted_tpm
    mean_tpm <- tpm_results$mean_tpm

    # Adjust VG estimates to mean
    vgh_df_adj <- adjust_vg_estimates(vgh_df, tissue_tpm_adj, mean_tpm)

    # calculate_weighted_vg() looks up `vgh_df_adj` in its enclosing
    # environment. Here that table is local to main(), so use a copy of the
    # function whose enclosing environment is this call frame; the global
    # environment is left untouched.
    weighted_vg <- calculate_weighted_vg
    environment(weighted_vg) <- environment()

    # Calculate tissue correlations
    tissue_cor <- cor(vgh_df, method = 'spearman', use = 'pairwise.complete.obs')

    # Create imputed VG estimates
    tissues <- colnames(tissue_tpm_adj)
    tissues <- tissues[tissues %in% colnames(vgh_df)]
    vgh_df_impute <- vgh_df_adj

    # Exponent applied to TPM in the weights; 0 makes the weight of a tissue
    # equal to its correlation alone
    tpm_weight_exponent <- 0

    for (tissue in tissues) {
        # Correlation partners: every other tissue that also has a TPM
        # column (the 'MEAN' column and the current tissue are excluded)
        cor_names <- colnames(tissue_cor)
        other_tissues <- cor_names[
            cor_names %in% tissues & !cor_names %in% c(tissue, 'MEAN')
        ]
        if (length(other_tissues) == 0L) {
            next
        }
        tissuesCorrelations <- setNames(tissue_cor[tissue, other_tissues], other_tissues)

        # Calculate weights based on TPM and correlation. sweep() multiplies
        # column j by correlation j; a plain `data.frame * vector` would
        # recycle the correlations down the rows and mix up the tissues.
        other_tpm <- tissue_tpm_adj[, other_tissues, drop = FALSE]
        tissueWeights <- sweep(other_tpm^tpm_weight_exponent, 2, tissuesCorrelations, "*")

        # Impute VG estimates at mean expression
        # NOTE(review): this runs over every gene of the TPM table, so genes
        # absent from vgh_df_adj are appended as NaN rows -- confirm intended.
        tpm_genes <- rownames(tissueWeights)
        vgh_df_impute[tpm_genes, tissue] <- vapply(
            tpm_genes,
            function(gene) {
                weighted_vg(gene, tissueWeights[gene, , drop = FALSE], tissuesCorrelations)
            },
            numeric(1)
        )

        # Readjust from mean TPM to tissue expression
        # NOTE(review): genes of vgh_df that were filtered out earlier are
        # re-added here as NA rows -- confirm intended.
        df_tissue <- data.frame(
            row.names = rownames(vgh_df),
            vg = vgh_df[, tissue],
            tpm = tissue_tpm_adj[rownames(vgh_df), tissue]
        )
        genes <- rownames(df_tissue)
        model <- lm(log10(vg) ~ log10(tpm), data = subset(df_tissue, vg > 0 & tpm > 0))
        slope <- model$coefficients[2]
        vgh_df_impute[genes, tissue] <-
            (df_tissue$tpm / mean_tpm[genes])^slope * vgh_df_impute[genes, tissue]
    }

    # Save results
    write.table(vgh_df_impute, file = 'VGHimputed.tsv', quote = FALSE, sep = '\t')

    invisible(vgh_df_impute)
}

# Run main analysis: this top-level call executes the pipeline whenever the
# file is run with Rscript or source()d
main()