# Split a '~'-delimited gene string into its individual gene names.
#
# Every token is expected to start with a one-character prefix (the strand
# sign); that first character is removed from each token.
#
# Args:
#   str: character (or factor) vector; every element is split on '~'.
# Returns:
#   Character vector of the tokens with their first character removed, in
#   input order. character(0) when the input contains no tokens.
extractGeneList <- function(str){
  # as.character() lets factor columns through; use.names = FALSE keeps the
  # result unnamed even when `str` carries names
  tokens <- unlist(strsplit(as.character(str), '~', fixed = TRUE),
                   use.names = FALSE)
  if (length(tokens) == 0) return(character(0))
  # the pattern is anchored, so one substitution per token is sufficient
  sub('^.', '', tokens) # remove strand sign
}

# Distinguish "gene absent" (0) from "not observed" (NA) in a position matrix.
#
# Within each row, a run of NA cells that lies between two non-NA cells is
# set to 0, unless the non-NA cell closing the run on the right equals 1.
# NA cells before the first or after the last non-NA cell of a row are left
# as NA.
# NOTE(review): the callers store the position of a gene within its
# haplotype row, so a right-hand value of 1 presumably marks the start of a
# new sequence and the gap before it cannot be called "absent" -- confirm.
#
# Args:
#   inpmat: matrix (rows = samples, columns = ordered genes), NA = unknown.
# Returns:
#   A copy of `inpmat` with the qualifying NA cells replaced by 0.
detMissing <- function(inpmat){
  mat <- inpmat
  # rows are walked by index so that matrices without (or with duplicated)
  # row names are handled, and seq_len() copes with zero rows/columns
  for (r in seq_len(nrow(mat))) {
    vals <- mat[r, ]
    observed <- which(!is.na(vals))
    # fewer than two observed cells: nothing is enclosed on both sides
    if (length(observed) < 2L) next
    for (k in seq_len(length(observed) - 1L)) {
      left <- observed[k]
      right <- observed[k + 1L]
      # only gaps of at least one NA cell, closed by a value other than 1
      if (right - left > 1L && vals[right] != 1) {
        mat[r, (left + 1L):(right - 1L)] <- 0
      }
    }
  }
  return(mat)
}

# load data
# Positional command-line arguments (only those after the script options).
args <- commandArgs(trailingOnly = TRUE)

hapfile    <- args[1]  # tab-separated haplotype table
mapfile    <- args[2]  # comma-separated gene map
divfiledir <- args[3]  # directory holding the "<id>.cds.div.txt.gz" files
tag        <- args[4]  # prefix of the output PDFs, also used in plot titles

#mapfile = "/homes8/yingzhou/Projects/2022-immune-annotation/refData/Data-2023Jun05/gene.names.txt.gz"
#hapfile = "hprc.all.hap.f0.txt"
#divfiledir = "/homes8/yingzhou/Projects/2022-immune-annotation/analysis/3.allele.hap/2.3.intra.div/2.cal.div/hprc-all"
#tag = "hprc"


if(TRUE){
  # Haplotype table (tab-separated, no header): column V1 is used as the row
  # id and column V3 as the '~'-joined gene list.
  # Gene map (comma-separated, no header): keyed by its first column.
  da<-read.csv(file=hapfile, header=FALSE,sep="\t")
  mapda<-read.csv(file=mapfile, header=FALSE,sep=",")
  rownames(mapda) <- mapda$V1

  # Column order of the HLA matrix.
  hlagenes <- c("HLA-F","HLA-V","HLA-P","HLA-G","HLA-H",
                "HLA-T","HLA-K","HLA-U","HLA-A","HLA-W","HLA-Y",
                "HLA-J","HLA-L","HLA-N","HLA-E","HLA-C","HLA-B",
                "HLA-S","MICA","MICB", 
                "HLA-DRA","HLA-DRB3",
                "HLA-DRB4","HLA-DRB5","HLA-DRB1","HLA-DQA1",
                "HLA-DQB1","HLA-DQB2", "HLA-DQA2","HLA-DOB","TAP2","TAP1",
                "HLA-DMB","HLA-DMA","HLA-DOA", "HLA-DPA1",
                "HLA-DPB1","HLA-DPA2","HLA-DPB2")

  # Two columns are reserved per C4 gene (copy 1 and copy 2).
  c4genes <- c("C4A", "C4A", "C4B", "C4B")

  pseudogenes <- c("HLA-DPA2", "HLA-DPA3", "HLA-DPB2", "HLA-DRB2",
                   "HLA-DRB6", "HLA-DRB7", "HLA-DRB8", "HLA-DRB9",
                   "HLA-H", "HLA-J", "HLA-K", "HLA-L", "HLA-N", "HLA-P",
                   "HLA-S", "HLA-T", "HLA-U", "HLA-V", "HLA-W",
                   "HLA-X", "HLA-Y", "HLA-Z", "KIR2DP1", "KIR3DP1")

  classI <- c("HLA-A","HLA-B","HLA-C","HLA-E","HLA-F","HLA-G",
              "HLA-H","HLA-HFE","HLA-J","HLA-K","HLA-L","HLA-N",
              "HLA-P","HLA-S","HLA-T","HLA-U","HLA-V","HLA-W","HLA-Y")


  classII <- c("HLA-DMA","HLA-DMB","HLA-DOA","HLA-DOB","HLA-DPA1",
               "HLA-DPA2","HLA-DPB1","HLA-DPB2","HLA-DQA1","HLA-DQA2",
               "HLA-DQB1","HLA-DRA","HLA-DRB1","HLA-DRB2","HLA-DRB3",
               "HLA-DRB4","HLA-DRB5","HLA-DRB6","HLA-DRB7",
               "HLA-DRB8","HLA-DRB9", "HLA-DRB345")

  # Ids are forced to character so that matrix rows are always addressed by
  # name; a numeric or factor id would otherwise index rows by position.
  rowids <- as.character(da$V1)
  ids = unique(rowids)
  idn = length(ids)
  hlan = length(hlagenes)

  hlamat <- matrix(NA, nrow=idn, ncol=hlan)
  c4mat <- matrix(NA, nrow=idn, ncol=4)

  colnames(hlamat) = hlagenes
  rownames(hlamat) = ids

  colnames(c4mat) = c4genes
  rownames(c4mat) = ids

  # Column slots available for each C4 gene.
  c4acols <- which(c4genes == 'C4A')
  c4bcols <- which(c4genes == 'C4B')

  for(i in seq_len(nrow(da))){
    id = rowids[i]
    hap = extractGeneList(as.character(da$V3[i]))
    hlahap = hap[hap %in% hlagenes]
    c4hap = hap[hap %in% c4genes]

    # Store the position of every HLA gene within this row's HLA gene list.
    for(j in seq_along(hlahap)){
      hlamat[id, hlahap[j]] = j
    }

    # C4 copies are recorded independently of the HLA genes: a row that
    # carries C4 but no HLA gene must still be counted.
    c4a = sum(c4hap == 'C4A')
    c4b = sum(c4hap == 'C4B')
    if(c4a > length(c4acols) || c4b > length(c4bcols)){
      # More copies than reserved columns: cap instead of spilling C4A
      # copies into the C4B columns or indexing past the last column.
      warning("more C4 copies than available columns for ", id,
              "; extra copies are not shown", call. = FALSE)
    }
    if(c4a > 0) c4mat[id, c4acols[seq_len(min(c4a, length(c4acols)))]] = 1
    if(c4b > 0) c4mat[id, c4bcols[seq_len(min(c4b, length(c4bcols)))]] = 1
  }
}
## determine missing (NA) or not carry (0), then recode both matrices to the
## labels "missing" / "non-carrier" / "carrier"
if(TRUE){
  hlamat <- detMissing(hlamat)

  ## C4: a row without any recorded copy is "missing" as a whole; otherwise
  ## empty slots are "non-carrier" and recorded copies (value 1) "carrier".
  ## The labels are derived from the untouched numeric matrix so that no
  ## comparison depends on the matrix having been coerced to character.
  c4num <- c4mat
  c4mat <- matrix('non-carrier', nrow = nrow(c4num), ncol = ncol(c4num),
                  dimnames = dimnames(c4num))
  c4mat[which(c4num == 1)] <- 'carrier'
  c4mat[rowSums(!is.na(c4num)) == 0, ] <- "missing"

  ## fill in divergence
  #### clean index
  ## HLA: NA -> "missing", 0 -> "non-carrier", positive position -> "carrier"
  hlanum <- hlamat
  hlamat <- matrix("missing", nrow = nrow(hlanum), ncol = ncol(hlanum),
                   dimnames = dimnames(hlanum))
  hlamat[which(hlanum == 0)] <- 'non-carrier'
  hlamat[which(hlanum > 0)] <- 'carrier'
}
# add diversity score
# For every gene, read "<divfiledir>/<id>.cds.div.txt.gz" (the id comes from
# column 4 of the gene map) and write the per-sample mean of its third
# column, as a percentage floored to 0.25 steps, into the matrices.
if(TRUE){

  # Binned mean divergence (in percent) of sample `id` in `val`
  # (column V1 = sample id, second column = divergence value).
  # Returns NULL when `val` holds no record for `id`.
  binnedDivergence <- function(val, id){
    hit <- val$V1 == id
    if(sum(hit) == 0) return(NULL)
    v <- mean(val[hit, 2]) * 100
    floor(v/0.25)*0.25
  }

  for(gene in c(hlagenes, 'C4A', 'C4B')){
    tid = mapda[gene,4]
    # fail with a readable message instead of trying to open "NA.cds.div.txt.gz"
    if(is.na(tid)) stop("no entry for gene '", gene, "' in the gene map: ",
                        mapfile, call. = FALSE)
    divfile=paste0(divfiledir, '/', tid, ".cds.div.txt.gz")
    divval <- read.csv(divfile, header=FALSE, sep=" ")
    val = divval[, c(1,3)]
    if (gene %in% c("C4A", "C4B")){
      for(id in rownames(c4mat)){
        newvalue <- binnedDivergence(val, id)
        # carriers without a divergence record are shown as 0%
        if(is.null(newvalue)) newvalue = '0'
        # only the slots of this gene that are flagged as carrier get a value
        o = colnames(c4mat) == gene & c4mat[id, ] == 'carrier'
        c4mat[id, o] = paste0(newvalue, '%')
      }
    }
    else {
      for(id in rownames(hlamat)){
        newvalue <- binnedDivergence(val, id)
        # a record overrides whatever status the cell had; cells without a
        # record keep their status label
        if(!is.null(newvalue)) hlamat[id, gene] = paste0(newvalue, '%')
      }
    }
  }
}

#print(head(hlamat))

# merge matrix: insert the C4 columns after the first 20 columns of hlamat

# drop = FALSE keeps the pieces as matrices, so a single-row hlamat is still
# merged column-wise instead of being bound as three vectors
nbeforec4 <- 20
hlamat = cbind(hlamat[, seq_len(nbeforec4), drop = FALSE],
               c4mat,
               hlamat[, (nbeforec4 + 1):ncol(hlamat), drop = FALSE])

if(TRUE){
  # sorting based on the haplotype group
  # Each row gets a key of '0'/'1' characters, one per column, and the rows
  # are ordered by that key.

  rowL <- dim(hlamat)[1]
  colL <- dim(hlamat)[2]
  # '1' for every cell that is neither NA nor "non-carrier".
  # NOTE(review): cells labelled "missing" therefore also count as '1' --
  # confirm this is the intended grouping.
  occupied <- !is.na(hlamat) & hlamat != "non-carrier"
  hlakeys <- vapply(seq_len(rowL),
                    function(i) paste0(ifelse(occupied[i, ], '1', '0'),
                                       collapse = ''),
                    character(1))
  # drop = FALSE: a single-row matrix must stay a matrix
  hlamat <- hlamat[order(hlakeys), , drop = FALSE]
}


if(TRUE){

  # plotting
  # Writes two PDFs: "<tag>.HLA.cds.div.pdf" (full, with row labels) and
  # "<tag>.HLA.cds.div.compact.pdf" (compact, rotated column labels only).
  library('plot.matrix')
  fun_color_range <- colorRampPalette(c("#ffd635", "#6d001a"))

  # Data preparation is done once, before any device is opened.
  # Pseudogene columns are marked with a leading '*'.
  hlanames<-colnames(hlamat)
  ispseudo <- hlanames %in% pseudogenes
  hlanames[ispseudo] <- paste0('*', hlanames[ispseudo])
  colnames(hlamat) = hlanames
  # NOTE(review): this only removes rows with at most one non-NA cell; cells
  # holding the label "missing" are not NA, so confirm the filter does what
  # is intended. drop = FALSE keeps a single remaining row as a matrix.
  hlamat = hlamat[rowSums(!is.na(hlamat))>1, , drop = FALSE]

  # Divergence classes 0% .. 6% in 0.25 steps; one colour per class so that
  # `col` and `breaks` below have the same length and pair up one-to-one.
  bks = paste0(seq(0,6, by=0.25), '%')
  cols = fun_color_range(length(bks))

  #png(paste0(tag, ".HLA.cds.div.png"), width= 3200, height=5000, res=250)
  pdf(paste0(tag, ".HLA.cds.div.pdf"), width= 12, height=17)

  par(mar=c(12,10,2,6))

  plot(hlamat, xlab="", ylab="", main=paste0(toupper(tag),"- HLA CDS relative divergence"),
       axis.col = list(side=1, las=2), 
       axis.row = list(side=2, las=1),
       col = c('white', 'lightgray', cols), 
       breaks=c('missing','non-carrier', bks))
  dev.off()


  #png(paste0(tag, ".HLA.cds.div.compact.png"), width= 3200, height=2000, res=250)
  pdf(paste0(tag, ".HLA.cds.div.compact.pdf"), width= 12, height=9)
  par(mar=c(8,2,2,10))

  plot(hlamat, xlab="", ylab="", main=paste0(toupper(tag)," - HLA haplotypes"),
       axis.col=list(side=1, labels = FALSE), 
       axis.row = NULL, 
       cex.axis = 1.5,
       cex.main = 2,
       border='lightgray', 
       col = c('white', 'lightgray', cols), 
       breaks=c('missing','non-carrier', bks))
  # Column labels are drawn by hand below the plot region so that they can
  # be rotated.
  text(x = 1:dim(hlamat)[2],
     y = par("usr")[3] - 2.45,
     labels = colnames(hlamat),
     xpd = NA,
     ## Rotate the labels by 45 degrees.
     srt = 45,
     adj = 1,
     cex = 1.2)
  dev.off()
}
