#!/usr/bin/python3


#
# Version: 20180210a
#


#
# Run DESeq2 equivalence test result filter (with some analysis)
#

#
# Output is written to stdout so redirect it to a file
# in order to save it.
#


import sys
import os
import re
import math
import gzip
import textwrap
import pickle
import argparse
import statistics

#
# Default analysis thresholds. main() passes both to xrunAnalysis().
#
#   gpr_padjustcutoff: a within-tissue comparison is counted when its
#                      padjust value is <= this cutoff.
#   gpr_maxfoldchange: fold-change limit on the linear scale; it is
#                      compared against 2**|log2FoldChange|.
#
gpr_padjustcutoff = 0.1
gpr_maxfoldchange = 2.0


#
# The following file is the source of the per-gene expression (TPM)
# values. This program is run initially with the '-m' command line
# parameter in order to read this file, together with the DESeq2
# results below, and make a pickle file.
#
gfn_gadam  = 'GeneTPM.180123.edit1.tissue_mean.long_gene_names.txt'


#
# Path to directory that has DESeq2 pairwise comparison results.
# Only files named 'run.<comparison>.results.txt' are read.
#
gfn_deseq2 = 'run.all.180126'

#
# Default pickle filename. It is written by '-m' and read by '-a';
# the '-p' command line parameter overrides it.
#
gfn_pickle = 'gene_equivexpress.pickle'


#
# Passed to xrunAnalysis() as flag_exclude. The analysis code in this
# file does not read it.
#
gflg_exclude = 0


def main():
  """Parse the command line and run the requested action.

  With '-m' the TPM table and the DESeq2 result files are read and
  stored in a pickle file. Otherwise, with '-a', the analysis is run
  from the pickle file and the report is written to stdout. When both
  flags are given only '-m' is acted on. When neither is given a usage
  message is written to stderr so that the run does not end silently.

  Returns:
    0 in all cases.
  """
  argparser = argparse.ArgumentParser()
  argparser.add_argument( '-a', '--run-analysis',       help='run analysis', action='store_true' )
  argparser.add_argument( '-m', '--make-pickle-file',   help='read data files and make a pickle file', action='store_true' )
  argparser.add_argument( '-p', '--pickle-filename',    help='input pickle filename (default: %s)' % gfn_pickle, default=gfn_pickle )

  args = argparser.parse_args()

  # An empty '-p' value falls back to the built-in default filename.
  fn_pickle = args.pickle_filename or gfn_pickle

  if( args.make_pickle_file ):
    dadam  = {}
    lhadam = []
    print( 'read Adam\'s expression data next...', file=sys.stderr )
    xreadAdam( gfn_gadam, lhadam, dadam )
    ddeseq2 = {}
    print( 'read DESeq2 data next...', file=sys.stderr )
    xreadDeseq2( gfn_deseq2, ddeseq2 )
    # Record the input locations so the analysis report can cite them.
    dfiles = { 'tpm': gfn_gadam, 'deseq2': gfn_deseq2 }
    ddata = { 'lhadam': lhadam, 'dadam': dadam, 'ddeseq2': ddeseq2, 'dfiles': dfiles }
    print( 'write pickle file next...', file=sys.stderr )
    xwritePickle( ddata, fn_pickle )
  elif( args.run_analysis ):
    print( 'run analysis', file=sys.stderr )
    xrunAnalysis( fn_pickle, gflg_exclude, gpr_padjustcutoff, gpr_maxfoldchange )
  else:
    # Neither action was requested. Say so on stderr (stdout is
    # reserved for the report) instead of exiting silently.
    argparser.print_usage( file=sys.stderr )
    print( 'nothing to do: specify -m/--make-pickle-file or -a/--run-analysis', file=sys.stderr )
  return( 0 )


def xreadAdam( fn, lhadam, dadam ):
  """Read the whitespace-delimited expression table into lhadam and dadam.

  A line whose first token is 'gene' is the header line: its remaining
  tokens are appended to lhadam. On every other non-blank line the
  first token is the gene name and the remaining tokens are converted
  to float. The gene key is the part of the name before the first '_'.
  If a key occurs more than once the first occurrence is kept.

  Args:
    fn:     path of the table file.
    lhadam: list that receives the header column names (modified in place).
    dadam:  dict that receives gene key -> list of floats (modified in place).

  Returns:
    0.

  Raises:
    ValueError: a value token on a data line is not a valid float.
  """
  # 'with' closes the file even when a line fails to parse.
  with open( fn, 'r' ) as fp:
    for inline in fp:
      toks = inline.split()
      if( not toks ):
        # blank line: nothing to parse
        continue
      if( toks[0] == 'gene' ):
        lhadam.extend( toks[1:] )
        continue
      gnam = toks[0].split( '_' )[0]
      dadam.setdefault( gnam, [ float( s ) for s in toks[1:] ] )
  return( 0 )


def _xparseDeseq2Value( tok ):
  """Convert one DESeq2 table field to float; the token 'NA' becomes NaN."""
  if( tok == 'NA' ):
    return( float( 'nan' ) )
  return( float( tok ) )


def xreadDeseq2( fn_deseq2, ddeseq2 ):
  """Read every 'run.<comparison>.results.txt' file in a directory.

  Each file is a whitespace-delimited table. A line whose first token
  is 'baseMean' is the header line and is skipped, and blank lines are
  skipped too. On a data line, field 0 is the full gene name and
  fields 1, 2 and 6 are stored as baseMean, log2FoldChange and padjust.
  The gene key is the part of the full gene name before the first '_'.
  The first record seen for a (gene key, comparison) pair is kept.

  Files are processed in sorted filename order, so the insertion order
  of the comparisons (and anything derived from it) is reproducible.
  os.listdir() itself returns names in arbitrary order.

  Args:
    fn_deseq2: path of the directory holding the result files.
    ddeseq2:   dict that receives
               gene key -> comparison name -> record (modified in place).

  Returns:
    0.

  Raises:
    SystemExit: a file ends in '.results.txt' but does not match the
                'run.<comparison>.results.txt' pattern (exit code -1).
    IndexError, ValueError: a data line is too short or holds a
                non-numeric field.
  """
  for fil in sorted( os.listdir( fn_deseq2 ) ):
    if( not re.search( r'\.results\.txt$', fil ) ):
      continue
    mobj = re.match( r'^run.(.+).results.txt$', fil )
    if( mobj is None ):
      print( 'Error: unexpected results filename: %s' % ( fil ), file=sys.stderr )
      sys.exit( -1 )
    # group 1 is always a non-empty string when the pattern matches
    nmcmp = mobj.group( 1 )
    # 'with' closes the file even when a line fails to parse.
    with open( os.path.join( fn_deseq2, fil ), 'r' ) as fp:
      for inline in fp:
        stok = inline.split()
        if( not stok or stok[0] == 'baseMean' ):
          continue
        drec = { 'baseMean': _xparseDeseq2Value( stok[1] ),
                 'log2FoldChange': _xparseDeseq2Value( stok[2] ),
                 'padjust': _xparseDeseq2Value( stok[6] ),
                 'fullname': stok[0] }
        nmgene = stok[0].split( '_' )[0]
        ddeseq2.setdefault( nmgene, {} ).setdefault( nmcmp, drec )

  return( 0 )


def xwritePickle( object, fn ):
  """Serialize an object to a file with pickle.

  Args:
    object: the value to store. The parameter name shadows the builtin
            of the same name; it is kept so that existing keyword
            callers keep working.
    fn:     path of the output file; an existing file is overwritten.

  Returns:
    0.
  """
  # 'with' closes the file even when pickle.dump() raises.
  with open( fn, 'wb' ) as fp:
    pickle.dump( object, fp )
  return( 0 )


def xreadPickle( fn ):
  """Load and return the object stored in a pickle file.

  Args:
    fn: path of the pickle file.

  Returns:
    The unpickled object.
  """
  # NOTE: unpickling can execute arbitrary code, so only load pickle
  # files from a trusted source (e.g. written by this program's '-m' run).
  # 'with' closes the file even when pickle.load() raises.
  with open( fn, 'rb' ) as fp:
    return( pickle.load( fp ) )


def xanalyzeGene( nmgene, altname, ddeseq2, padjustcutoff, maxfoldchange, dgene ):
  """Summarize the within-tissue comparisons of one gene.

  A comparison name has the form '<left>~<right>'. The tissue of each
  side is the text before its first '_', and only comparisons whose two
  sides have the same tissue are used. For each such tissue the function
  tracks the largest padjust and |log2FoldChange| over all comparisons,
  whether any padjust or log2FoldChange is NaN, and the same maxima
  restricted to comparisons with padjust <= padjustcutoff (the '_sig'
  values; these comparisons are the ones counted in 'num_equiv').

  For every tissue a 'DATA_TISSUE:' line and a 'ONE_TISSUE:' line are
  printed to stdout, and a summary record is stored in
  dgene[nmgene][tissue] unless one is already present.

  Args:
    nmgene:        gene key in ddeseq2.
    altname:       alternative gene name copied into the output.
    ddeseq2:       gene key -> comparison name -> record with the keys
                   'padjust', 'log2FoldChange' and 'baseMean'.
    padjustcutoff: padjust threshold.
    maxfoldchange: linear fold-change threshold used for the flag.
    dgene:         dict that receives the summaries (modified in place).

  Returns:
    0.
  """
  dacc = {}   # tissue -> running values, in order of first appearance
  dgene.setdefault( nmgene, {} )
  for nmcmp, rec in ddeseq2[nmgene].items():
    lside = nmcmp.split( '~' )
    nmtis = lside[0].split( '_' )[0]
    # keep within-tissue comparisons only
    if( lside[1].split( '_' )[0] != nmtis ):
      continue
    acc = dacc.setdefault( nmtis, { 'lsig': [], 'lcmpdat': [], 'max_padjust_all': 0.0, 'max_abslfc_all': 0.0, 'max_padjust_sig': 0.0, 'max_abslfc_sig': 0.0, 'nan_padjust_all': 'n', 'nan_abslfc_all': 'n' } )
    padj = rec['padjust']
    lfc  = rec['log2FoldChange']

    if( math.isnan( padj ) ):
      acc['nan_padjust_all'] = 'y'
    else:
      if( padj > acc['max_padjust_all'] ):
        acc['max_padjust_all'] = padj
      if( padj <= padjustcutoff ):
        # comparison passes the cutoff: count it and update the '_sig' maxima
        acc['lsig'].append( nmcmp )
        if( padj > acc['max_padjust_sig'] ):
          acc['max_padjust_sig'] = padj
        if( abs( lfc ) > acc['max_abslfc_sig'] ):
          acc['max_abslfc_sig'] = abs( lfc )

    if( math.isnan( lfc ) ):
      acc['nan_abslfc_all'] = 'y'
    elif( abs( lfc ) > acc['max_abslfc_all'] ):
      acc['max_abslfc_all'] = abs( lfc )

    acc['lcmpdat'].append( '%s:%.5f:%.5f:%.1f' % ( nmcmp, lfc, padj, rec['baseMean'] ) )

  for nmtis, acc in dacc.items():
    print( 'DATA_TISSUE: %s %s' % ( nmgene, ' '.join( acc['lcmpdat'] ) ) )

    # the maxima are log2 values; convert to linear fold changes
    fc_all = 2.0**acc['max_abslfc_all']
    fc_sig = 2.0**acc['max_abslfc_sig']
    maxfcallflag = 'y' if fc_all > maxfoldchange else 'n'
    maxpadjflag  = 'y' if acc['max_padjust_all'] > padjustcutoff else 'n'
    nequiv = len( acc['lsig'] )

    #
    # store for xreportAllTissueGenes()
    #
    # Note: when nequiv is zero the 'max_*_sig' values still have their
    # initial value of zero.
    #
    dgene[nmgene].setdefault( nmtis, { 'name_tissue': nmtis, 'num_equiv': nequiv, 'max_padjust_sig': acc['max_padjust_sig'], 'max_fc_sig': fc_sig, 'nan_padjust_all': acc['nan_padjust_all'], 'max_padjust_all': acc['max_padjust_all'], 'nan_fc_all': acc['nan_abslfc_all'], 'max_fc_all': fc_all, 'maxfcallflag': maxfcallflag, 'maxpadjflag': maxpadjflag, 'altname': altname, 'lcomparisons': list( acc['lsig'] ) } )

    line = 'ONE_TISSUE: %s %s %s %s %s %s %s %d %.3f %.1f %.3f %.1f' % ( nmgene, altname, nmtis, acc['nan_abslfc_all'], acc['nan_padjust_all'], maxfcallflag, maxpadjflag, nequiv, acc['max_padjust_all'], fc_all, acc['max_padjust_sig'], fc_sig )
    line += ''.join( [ ' %s' % ( nmcmp ) for nmcmp in acc['lsig'] ] )
    print( line )
  return( 0 )


def xreportAllTissueGenes( dgene, padjustcutoff, maxfoldchange ):
  """Print one 'ALL_TISSUE:' line per gene that passes the filter.

  For each gene the per-tissue summaries in dgene[gene] are combined:
  the NaN and threshold flags are 'y' when any tissue has them set, and
  the padjust and fold-change values are the maxima over the tissues.
  A gene is skipped when no tissue has a positive 'num_equiv' or when
  its maximum fold change is greater than maxfoldchange. The line ends
  with '<count>:<number of tissues>' pairs, highest count first, which
  form a histogram of the per-tissue 'num_equiv' values.

  Args:
    dgene:         gene -> tissue -> summary record.
    padjustcutoff: not used by this function.
    maxfoldchange: genes whose maximum fold change exceeds this are skipped.

  Returns:
    0.
  """
  for nmgene, dtis in dgene.items():
    lrec = list( dtis.values() )

    # largest per-tissue count, and the first non-empty alternative name
    mxcnt   = max( [ 0 ] + [ rec['num_equiv'] for rec in lrec ] )
    altname = next( ( rec['altname'] for rec in lrec if len( rec['altname'] ) > 0 ), '' )

    # lhist[k] is the number of tissues whose count is k
    lhist = [ 0 ] * ( mxcnt + 1 )
    for rec in lrec:
      lhist[rec['num_equiv']] += 1

    nan_padj_all = 'y' if any( rec['nan_padjust_all'] == 'y' for rec in lrec ) else 'n'
    nan_fc_all   = 'y' if any( rec['nan_fc_all'] == 'y' for rec in lrec ) else 'n'
    mx_fc_flag   = 'y' if any( rec['maxfcallflag'] == 'y' for rec in lrec ) else 'n'
    mx_padj_flag = 'y' if any( rec['maxpadjflag'] == 'y' for rec in lrec ) else 'n'
    mx_padj_all  = max( [ 0.0 ] + [ rec['max_padjust_all'] for rec in lrec ] )
    mx_fc_all    = max( [ 0.0 ] + [ rec['max_fc_all'] for rec in lrec ] )

    if( mxcnt == 0 or mx_fc_all > maxfoldchange ):
      continue

    line = 'ALL_TISSUE: %s %s %s %s %s %s %.3f %.1f' % ( nmgene, altname, nan_fc_all, nan_padj_all, mx_fc_flag, mx_padj_flag, mx_padj_all, mx_fc_all )
    for icnt, ntis in reversed( list( enumerate( lhist ) ) ):
      if( ntis > 0 ):
        line += ' %d:%d' % ( icnt, ntis )
    print( line )

  return( 0 )


def xgetCommonNameGene( nmgene, ddeseq2 ):
  """Return the common name embedded in a gene's full name.

  The full name is taken from the first comparison record stored for
  the gene and is split on '_'. The second-to-last field is returned.
  A full name without any '_' is returned unchanged.

  Args:
    nmgene:  gene key in ddeseq2.
    ddeseq2: gene key -> comparison name -> record with a 'fullname' key.

  Returns:
    The common name string.
  """
  lcmp   = [ *ddeseq2[nmgene] ]
  lfield = ddeseq2[nmgene][lcmp[0]]['fullname'].split( '_' )
  # with a single field the index len-2 is -1, which is that same field
  if( len( lfield ) > 1 ):
    return( lfield[-2] )
  return( lfield[-1] )


def xrunAnalysis( fn_pickle, flag_exclude, padjustcutoff, maxfoldchange ):
  """Run the analysis from a pickle file and write the report to stdout.

  The report starts with four '#' header lines naming the input
  locations and the two thresholds. Then every gene in the DESeq2 data
  is analyzed with xanalyzeGene(), and the all-tissue summary is printed
  by xreportAllTissueGenes().

  Args:
    fn_pickle:     path of the pickle file to load with xreadPickle().
    flag_exclude:  not used by this function.
    padjustcutoff: padjust threshold handed to the analysis functions.
    maxfoldchange: fold-change threshold handed to the analysis functions.

  Returns:
    0.
  """
  ddata = xreadPickle( fn_pickle )
  # All four entries must be present; only the last two are used below.
  lhadam, dadam, ddeseq2, dfiles = ( ddata[key] for key in ( 'lhadam', 'dadam', 'ddeseq2', 'dfiles' ) )

  # Each header value is looked up just before its line is printed.
  dparam  = { 'padjust': padjustcutoff, 'maxfc': maxfoldchange }
  lheader = ( ( '# File TPM data:       %s', dfiles, 'tpm' ),
              ( '# Directory DESeq2:    %s', dfiles, 'deseq2' ),
              ( '# Padjust cutoff:      %f', dparam, 'padjust' ),
              ( '# MaxFoldChangeCutoff: %f', dparam, 'maxfc' ) )
  for fmt, dsrc, key in lheader:
    print( fmt % ( dsrc[key] ) )

  dgene = {}
  for nmgene in ddeseq2:
    xanalyzeGene( nmgene, xgetCommonNameGene( nmgene, ddeseq2 ), ddeseq2, padjustcutoff, maxfoldchange, dgene )

  xreportAllTissueGenes( dgene, padjustcutoff, maxfoldchange )

  return( 0 )


#
# Run main() only when this file is executed as a script, not when it
# is imported. The value returned by main() is not used as the exit status.
#
if __name__ == '__main__':
  main()

