library(tidyverse)
library(ggpubr)


#simulate_ivf_experiment_power.R
#two-colour palette (hex codes) kept available for plotting
colors <- c("#E69F00", "#666699")

###############################################################################
#ASSUMPTIONS: 

#  No mutation rate biases due to sex
#  Ignore inherited mutations from G0 and G1 
#  Assume equal number of queryable sites in both cohorts
#  Assume coverage values drawn from normal distribution with mean and sd from
#     empirical data
#______________________________________________________________________________

#baseline mutation rate passed to sim_mutations()
mu <- 0.5 * 10^-8
#number of sites used both for the expected mutation count and as the upper
#bound of simulated positions
genome_size <- 2230000000
#rate multipliers applied to mu for the IVF cohort: 1, 1.025, ..., 2
diffs <- seq(from = 1, to = 2, by = 0.025)
#simulated depths below this value are discarded
min_DP <- 10
#read-count threshold used in the binomial detection probability
dp_per_allele <- 3

setwd("/PATH/TO/WORKING/DIR")

#read the pedigree and the per-sample coverage summary, attach the coverage
#columns to each pedigree row, and keep only generation-2 samples
pedigree <- read.table(file = "pedigree.txt", header = TRUE) %>%
  as_tibble()
coverage.stats <- read.table(file = "varianceReadCoverage.txt", header = TRUE) %>%
  as_tibble()

pedigree <- pedigree %>%
  left_join(coverage.stats, by = c("sampleId" = "id")) %>%
  filter(Generation == 2)

#split the G2 samples by sire: 518 is treated as the natural cohort and
#121586 as the IVF cohort throughout the script
nat_samples <- pedigree %>% filter(Dad == 518)
ivf_samples <- pedigree %>% filter(Dad == 121586)

#cohort sizes
n_nat <- nrow(nat_samples)
n_ivf <- nrow(ivf_samples)

#per-sample coverage mean and standard deviation, in pedigree row order
nat_cov <- nat_samples$mean
nat_sd <- nat_samples$sd

ivf_cov <- ivf_samples$mean
ivf_sd <- ivf_samples$sd

#subroutine for simulating new mutations
#
#Simulates, for a single offspring, the de novo mutations that are both
#transmitted and detectable at that sample's sequencing depth.
#
#Arguments:
#  coverage  mean read depth of the sample
#  sd        standard deviation of the sample's read depth
#  mu        mutation rate; the expected number of new mutations per parental
#            genome is mu*genome_size
#
#Reads the globals genome_size, min_DP and dp_per_allele.
#
#Returns the sorted, rounded positions of the observed mutations as a numeric
#vector, or the character string "NA" when no mutation is observed.
#NOTE: the "NA" sentinel has length 1, so callers must not use length() on
#the return value to count mutations without checking for it first.
sim_mutations <- function(coverage, sd, mu) {

  #simulate number of new mutations from maternal and paternal chromosomes
  n_dnm_mat <- rpois(1, lambda = mu*genome_size)
  n_dnm_pat <- rpois(1, lambda = mu*genome_size)

  #simulate read depth at each mutated site and drop sites below min_DP
  draw_mat <- round(rnorm(n = n_dnm_mat, mean = coverage, sd = sd))
  draw_mat <- draw_mat[draw_mat >= min_DP]

  draw_pat <- round(rnorm(n = n_dnm_pat, mean = coverage, sd = sd))
  draw_pat <- draw_pat[draw_pat >= min_DP]

  #probability that a transmitted mutation is observed: more than
  #dp_per_allele of the reads at the site carry the mutant allele, with each
  #read sampling it with probability 0.5
  #NOTE(review): pbinom(q, ...) is P(X <= q), so this requires strictly more
  #than dp_per_allele reads; use dp_per_allele - 1 if ">=" was intended
  prob_observed_mat <- 1 - pbinom(q = dp_per_allele, size = draw_mat, prob = 0.5)
  prob_observed_pat <- 1 - pbinom(q = dp_per_allele, size = draw_pat, prob = 0.5)

  #determine how many of these new mutations are actually transmitted
  #(probability 0.5) and observed given the sequencing coverage.
  #Each parent is counted independently: zero-length vectors simply
  #contribute 0, so one parent having no callable site no longer discards the
  #other parent's mutations.
  n_trans_dnm_mat <- sum(rbinom(n = length(draw_mat), size = 1,
                                prob = 0.5*prob_observed_mat))
  n_trans_dnm_pat <- sum(rbinom(n = length(draw_pat), size = 1,
                                prob = 0.5*prob_observed_pat))
  n_trans_dnm_overall <- n_trans_dnm_mat + n_trans_dnm_pat

  #nothing observed: return the sentinel expected by existing callers
  if (n_trans_dnm_overall <= 0) {
    return("NA")
  }

  #simulate position of new mutations uniformly along the genome
  dnm_positions <- runif(n_trans_dnm_overall, min = 0, max = genome_size)
  sort(round(dnm_positions))
}

#simulate new mutations in NATURAL and IVF cohorts from G1 to G2
#
#For every rate multiplier in diffs, n_reps replicate experiments are
#simulated. Within a replicate each natural sample is simulated at rate mu and
#each IVF sample at rate mu*diff, using that sample's coverage mean and sd.
#The two cohorts are then compared with
#  - a one-sided t-test (IVF greater than natural),
#  - a one-sided Mann-Whitney U test (IVF greater than natural),
#  - negative-binomial and Poisson regressions of n_dnm on Cohort (p-value of
#    the CohortART coefficient as reported by summary()).
#One row per (diff, replicate) is stored in simDat, which is written to
#SimulationResults.txt once all multipliers are done.

#number of replicate experiments per rate multiplier
n_reps <- 1000

#the tests below need at least two samples per cohort
stopifnot(n_nat > 1, n_ivf > 1)

if (file.exists("SimulationResults.txt")) {
  file.remove("SimulationResults.txt")
}
header <- c("diff", "rep", "t.stat.ind", "t.df.ind", "t.p.ind", "mwu.stat.ind",
            "mwu.p.ind", "NegBinomReg.p", "poisReg.p")

#results are collected in a numeric matrix (cheap row assignment) and turned
#into a data frame once the simulation is finished
simDat <- matrix(data = NA_real_, nrow = length(diffs)*n_reps,
                 ncol = length(header), dimnames = list(NULL, header))

#number of mutations in a sim_mutations() result. sim_mutations() reports
#"no observed mutation" as the string "NA", whose length() is 1, so only
#numeric results are counted by their length.
count_dnm <- function(positions) {
  if (is.numeric(positions)) length(positions) else 0
}

output_iter <- 0

for (diff in diffs) {
  print(diff)
  nat_iter <- 0
  ivf_iter <- 0

  #typed, preallocated per-sample results for this multiplier
  nat_n_dnm <- numeric(n_reps*n_nat)
  nat_positions <- character(n_reps*n_nat)
  ivf_n_dnm <- numeric(n_reps*n_ivf)
  ivf_positions <- character(n_reps*n_ivf)

  for (r in seq_len(n_reps)) {
    for (i in seq_len(n_nat)) {
      nat_iter <- nat_iter + 1
      positions <- sim_mutations(coverage = nat_cov[i], sd = nat_sd[i], mu = mu)
      nat_n_dnm[nat_iter] <- count_dnm(positions)
      nat_positions[nat_iter] <- paste(positions, collapse = ",")
    }

    for (i in seq_len(n_ivf)) {
      ivf_iter <- ivf_iter + 1
      positions <- sim_mutations(coverage = ivf_cov[i], sd = ivf_sd[i], mu = mu*diff)
      ivf_n_dnm[ivf_iter] <- count_dnm(positions)
      ivf_positions[ivf_iter] <- paste(positions, collapse = ",")
    }
  }

  #rows are ordered by replicate, then by sample within replicate
  natural <- data.frame(rep = rep(seq_len(n_reps), each = n_nat),
                        sampleid = rep(seq_len(n_nat), times = n_reps),
                        n_dnm = nat_n_dnm,
                        positions = nat_positions,
                        stringsAsFactors = FALSE)
  ivf <- data.frame(rep = rep(seq_len(n_reps), each = n_ivf),
                    sampleid = rep(seq_len(n_ivf), times = n_reps),
                    n_dnm = ivf_n_dnm,
                    positions = ivf_positions,
                    stringsAsFactors = FALSE)

  #cohort labels for the combined (IVF first, then natural) data of one
  #replicate; "Natural" is the reference level so the regression coefficient
  #is named CohortART
  cohort <- factor(rep(c("ART", "Natural"), times = c(n_ivf, n_nat)),
                   levels = c("Natural", "ART"))

  for (r in seq_len(n_reps)) {
    output_iter <- output_iter + 1

    #compute power at sample level using t-test and MWU
    ivf_counts <- ivf$n_dnm[ivf$rep == r]
    nat_counts <- natural$n_dnm[natural$rep == r]
    comb.rep <- data.frame(n_dnm = c(ivf_counts, nat_counts), Cohort = cohort)

    out.t.ind <- t.test(ivf_counts, nat_counts, alternative = "greater")
    out.mwu.ind <- wilcox.test(ivf_counts, nat_counts, alternative = "greater")

    #glm.nb() is provided by MASS, which the script does not attach, so it is
    #called through its namespace
    out.nb.p <- summary(MASS::glm.nb(n_dnm ~ Cohort, data = comb.rep))$coefficients["CohortART", "Pr(>|z|)"]
    out.pois.p <- summary(glm(n_dnm ~ Cohort, family = poisson, data = comb.rep))$coefficients["CohortART", "Pr(>|z|)"]

    #store results for this replicate
    simDat[output_iter, ] <- c(diff, r,
                               as.numeric(out.t.ind$statistic),
                               as.numeric(out.t.ind$parameter),
                               as.numeric(out.t.ind$p.value),
                               as.numeric(out.mwu.ind$statistic),
                               as.numeric(out.mwu.ind$p.value),
                               as.numeric(out.nb.p),
                               as.numeric(out.pois.p))
  }

  #write.table(natural, file = paste0("natural_1000reps_diff_", diff, ".txt"), quote = FALSE, sep = "\t", row.names = FALSE)
  #write.table(ivf, file = paste0("ivf_1000reps_diff_", diff, ".txt"), quote = FALSE, sep = "\t", row.names = FALSE)
}

simDat <- as.data.frame(simDat)

write.table(simDat, file = "SimulationResults.txt", quote = FALSE, sep = " ", row.names = FALSE)

#_______________________________________________________________________________
#calculate power at each mutation diff value
#_______________________________________________________________________________

#Power of a test at one rate multiplier = fraction of that multiplier's
#replicates whose p-value is <= alpha. Missing p-values count as
#non-rejections; the denominator is the number of replicates actually present
#rather than a fixed constant.
power_at <- function(p_values, alpha = 0.05) {
  sum(p_values <= alpha, na.rm = TRUE) / length(p_values)
}

power.t <- numeric(length(diffs))
power.mwu <- numeric(length(diffs))
power.nbr <- numeric(length(diffs))
power.pois <- numeric(length(diffs))

for (d in seq_along(diffs)) {
  #tolerant comparison so the match does not depend on exact floating-point
  #equality of the multiplier values
  simDat.part <- simDat %>% filter(dplyr::near(diff, diffs[d]))
  power.t[d] <- power_at(simDat.part$t.p.ind)
  power.mwu[d] <- power_at(simDat.part$mwu.p.ind)
  power.nbr[d] <- power_at(simDat.part$NegBinomReg.p)
  power.pois[d] <- power_at(simDat.part$poisReg.p)
}

#combine data into a tibble; diff is expressed as the fractional increase over
#the baseline rate (multiplier - 1)
power <- tibble(diff = diffs - 1,
                power.t = power.t,
                power.mwu = power.mwu,
                power.nbr = power.nbr,
                power.pois = power.pois)
  
#generate plots
#NBR: power of the negative-binomial regression against the fractional rate
#increase; points joined by a dashed line
power %>%
  ggplot(mapping = aes(x = diff, y = power.nbr)) +
  geom_point(size = 2.5) +
  geom_line(linetype = "dashed") +
  labs(x = "Fractional Increase In Mutation Rate", y = "Power") +
  theme_classic()

#no plot is passed, so ggsave() writes the most recently created plot
ggsave(filename = "Power_NBR.pdf", width = 4, height = 3.5)

#MWU: power of the Mann-Whitney U test against the fractional rate increase;
#points joined by a dashed line
power %>%
  ggplot(mapping = aes(x = diff, y = power.mwu)) +
  geom_point(size = 2.5) +
  geom_line(linetype = "dashed") +
  labs(x = "Fractional Increase In Mutation Rate", y = "Power") +
  theme_classic()

#no plot is passed, so ggsave() writes the most recently created plot
ggsave(filename = "Power_MWU.pdf", width = 4, height = 3.5)


