### Plots cross species gamete panel in Fig. 4(~c?)
### intermediate file is generated by calculate_new_cross_species_snp_barcode_statistics

import sys
import numpy
import pylab
import parse_midas_data
import matplotlib.colors as colors
import matplotlib.cm as cmx
from math import log10,ceil,fabs
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from numpy.random import randint
from scipy.stats import beta
import figure_utils
import parse_timecourse_data
import stats_utils

# Species included in the figure (the "super good" list), together with the
# display names used for the x tick labels.
good_species_list = parse_midas_data.parse_super_good_species_list()
# Alternative species list: parse_midas_data.parse_good_species_list()
good_species_pretty_list = figure_utils.get_pretty_species_names(good_species_list)
# Maps each raw species name to the pretty name at the same list position.
pretty_species_name_map = dict(zip(good_species_list, good_species_pretty_list))

# Load fixations
sys.stderr.write("Loading fixations...\n")
import calculate_barcode_within_species_fixations as calculate_within_species_fixations

# Collect every within-species SNV change as a
# (species_name, contig, location) key.
fixation_longsnps = set()
for species_name in good_species_list:
	
	# The loader also returns gene changes; only the SNV changes are used here.
	snp_changes, _ = calculate_within_species_fixations.load_within_species_fixations(species_name,allowed_epochs=parse_timecourse_data.initial_plus_previous_epoch_intervals)
	
	# Each SNV change starts with (contig, location, ...). Insertion order is
	# irrelevant for a set, so the records are neither copied nor sorted.
	for snp_change in snp_changes:
		fixation_longsnps.add((species_name, snp_change[0], snp_change[1]))

# Global matplotlib defaults applied to every figure drawn by this script.
for rc_key, rc_value in (
	('font.size', 7),
	('lines.linewidth', 0.5),
	('legend.frameon', False),
	('legend.fontsize', 'small'),
):
	mpl.rcParams[rc_key] = rc_value

types = ['nonshared']
for type_idx in xrange(0,len(types)):
	
	type = types[type_idx]
	filename = ('new_cross_species_snp_barcode_statistics_%s.txt' % type)

	# Non core genome SNVs
	#filename = 'new_cross_species_snp_barcode_statistics.txt'
	# Core genome SNVs only
	file = open(filename,"r")

	linkage_data = {}

	for line in file:	 
		items = line.split()
		species_name = items[0]
		
		if species_name not in good_species_list:
			continue
		
		lmin = long(items[1])
		testable = long(items[2])
		onegametes = long(items[3])
		twogametes = long(items[4])
		threeplusgametes = long(items[5])
		testable_fourgametes = long(items[6])
		fourgametes = long(items[7])
		
		threegametes = threeplusgametes - fourgametes
		
		Dprimes = []
		if len(items)>8:
			fourgamete_items = items[8:]
			for item in fourgamete_items:
		
				subitems = item.split(",")
				focal_snp_items = subitems[0].strip().split("|")
				target_snp_items = subitems[1].strip().split("|")
				
				ns = numpy.array([long(subitem) for subitem in subitems[2:]])
				min_n = ns.min()
				ntot = ns.sum()
				
				f00 = ns[0]*1.0/ntot
				f10 = ns[1]*1.0/ntot
				f01 = ns[2]*1.0/ntot
				f11 = ns[3]*1.0/ntot
				
				D = f11*f00-(f10*f01)
				if D>=0:
					Dmax = min([(f00+f01)*(f01+f11), (f10+f11)*(f00+f10)])
				else:
					Dmax = -1*min([(f11+f10)*(f11+f01), (f00+f01)*(f00+f10)])
					
				Dprime = D/Dmax
				
				Dprimes.append(Dprime)
	
				focal_contig = focal_snp_items[0]
				focal_location = long(focal_snp_items[1])
		
				target_contig = target_snp_items[0]
				target_location = long(target_snp_items[1])
		
		# Need at least 5 SNV pairs to make a comparison! 
		# (we're dividing by two because we are recording snv pairs in both orders)
		if testable/2.0 < 4.5:
			continue
		
		if species_name not in linkage_data:
			linkage_data[species_name] = {0:[0], 1:[0], 201:[0], 2001:[0]}
		
		Dprimes = numpy.array(Dprimes)
		
		linkage_data[species_name][lmin] = (testable,onegametes,twogametes,threegametes,fourgametes,Dprimes)
	

	species_list = linkage_data.keys()
	total_testable = [linkage_data[species_name][201][0]+linkage_data[species_name][2001][0] for species_name in species_list]
	total_testable, species_list = (numpy.array(x) for x in zip(*sorted(zip(total_testable, species_list), key=lambda pair: (pair[0]),reverse=True))) 
	
	print len(species_list)

	pylab.figure(type_idx+1,figsize=(7,1.5))
	fig = pylab.gcf()
	outer_grid	= gridspec.GridSpec(3,1,height_ratios=[0.05, 0.25,0.5],hspace=0.2)

	legend_axis = plt.Subplot(fig, outer_grid[0])
	fig.add_subplot(legend_axis)

	legend_axis.set_ylim([0,1])
	legend_axis.set_xlim([0,1])

	legend_axis.spines['top'].set_visible(False)
	legend_axis.spines['right'].set_visible(False)
	legend_axis.spines['left'].set_visible(False)
	legend_axis.spines['bottom'].set_visible(False)

	legend_axis.set_xticks([])
	legend_axis.set_yticks([])
   

	legend_axis.plot([-2], [-1],'s',color='k',markersize=3,label='All',markeredgewidth=0)
	legend_axis.plot([-2], [-1],'s',color='g',markersize=3,label='Perfect LD',markeredgewidth=0)
	legend_axis.plot([-2], [-1],'s',color='b',markersize=3,label="Complete LD ",markeredgewidth=0)
	legend_axis.plot([-2], [-1],'s',color='#e34a33',markersize=3,label='4 haplotypes',markeredgewidth=0)

	#legend_axis.plot([-2], [-1],'s',color='#b30000',markersize=3,label='4 haplotypes + SNV diff',markeredgewidth=0)
	
	legend_axis.legend(loc='lower right',frameon=False,fontsize=7,numpoints=1,ncol=5,handlelength=1)   


	fraction_axis = plt.Subplot(fig, outer_grid[2])
	fig.add_subplot(fraction_axis)
	fraction_axis.set_ylim([0,1])
	fraction_axis.set_ylabel('LD category',fontsize=7)
	fraction_axis.set_yticks([])
	log_axis = plt.Subplot(fig, outer_grid[1])
	fig.add_subplot(log_axis)
	#log_axis.set_ylim([0.5,3e06])
	#log_axis.set_ylabel('SNV pairs',fontsize=7)
	log_axis.set_ylim([1,2e06])
	log_axis.set_ylabel('# linked\n SNV pairs')
	
	TOTAL_SPECIES = 0
	TOTAL_CHECKED = 0
	TOTAL_FOURGAMETES = 0
	
	Dprime_data = []
	for species_idx in xrange(0,len(species_list)):
	
		species_name = species_list[species_idx]
		print species_name
	
		Dprimes = linkage_data[species_name][0][-1]
		if len(Dprimes)>=100:
			Dprime_data.append((species_name,Dprimes))
	
		lmins = [1,201,2001]
		#lmins = sorted(linkage_data[species_name].keys())[2:] # Don't do "all"
		dx = (0.6)/len(lmins)
		dxs = [-0.3+i*dx for i in xrange(0,len(lmins))]
		
		for lmin,dx in zip(lmins,dxs):
		
			if linkage_data[species_name][lmin][0]<1:
				fraction_axis.bar([species_idx+dx], [1],width=0.2,color='0.7',linewidth=0,align='edge')
				continue
				
			#print linkage_data[species_name][lmin]	
				
			testable,onegametes,twogametes,threegametes,fourgametes,Dprimes = linkage_data[species_name][lmin]
			
			if testable>0:
				pass
				log_axis.bar([species_idx+dx], [testable/2.0],width=0.2,color='k',log=True,linewidth=0,align='edge')
			else:
				continue
			
			fraction_fourgametes = fourgametes*1.0/testable
			fraction_threegametes = (threegametes*1.0)/testable
			fraction_onetwogametes = (onegametes+twogametes)*1.0/testable
			fraction_longfourgametes = (fourgametes)*1.0/testable
		
			if lmin==2001:
				TOTAL_SPECIES += 1
				TOTAL_CHECKED += testable
				TOTAL_FOURGAMETES += fourgametes
		
			print species_name, lmin, testable, fraction_onetwogametes, fraction_threegametes, fraction_fourgametes
		
			if fraction_onetwogametes>1e-02:
				fraction_axis.bar([species_idx+dx], 1.0,width=0.2,color='#4daf4a',linewidth=0,align='edge')
	
			if fraction_threegametes>1e-02:
				fraction_axis.bar([species_idx+dx], [fraction_threegametes+fraction_fourgametes],width=0.2,color='#377eb8',linewidth=0,align='edge')
		
			if fraction_fourgametes>1e-02:
				print species_name, 'longfourgametes',fraction_longfourgametes
				fraction_axis.bar([species_idx+dx], [fraction_longfourgametes],width=0.2,linewidth=0,color='r',align='edge') #'#b30000')
		
	xticklabels = []
	for species_name in species_list:
		xticklabels.append(pretty_species_name_map[species_name])
	
	log_axis.set_xlim([-1,len(species_list)])  
	log_axis.set_yticks([1,1e02,1e04,1e06])
	log_axis.set_xticks( numpy.arange(0,len(species_list)))
	log_axis.set_xticklabels([])
	#log_axis.set_xticklabels(species_list,rotation=90)
	
	#log_axis.set_xticklabels(species_list,rotation=90)
	log_axis.tick_params(axis='x', labelsize=6,direction='out',length=3,pad=1)
	log_axis.get_xaxis().tick_bottom()
	log_axis.get_yaxis().tick_left()
	log_axis.spines['top'].set_visible(False)
	log_axis.spines['right'].set_visible(False)

	fraction_axis.set_xlim([-1,len(species_list)])	
	fraction_axis.set_xticks(numpy.arange(0,len(species_list)))
	#fraction_axis.set_xticklabels([])
	fraction_axis.set_xticklabels( xticklabels,rotation=90)
	
	fraction_axis.tick_params(axis='x', labelsize=6,direction='out',length=3,pad=1)
	fraction_axis.get_xaxis().tick_bottom()
	fraction_axis.get_yaxis().tick_left()
	fraction_axis.spines['top'].set_visible(False)
	fraction_axis.spines['right'].set_visible(False)

	fig = pylab.gcf()
	fig.savefig('%s/cross_species_gametes_%s.pdf' % (parse_midas_data.analysis_directory, type),bbox_inches='tight')
	
	# Now make Dprime figure
	num_panels = len(Dprime_data)
	num_cols = 3
	num_rows = long(ceil(num_panels*1.0/num_cols))
	
	pylab.figure(10, figsize=(7,2*num_rows))
	fig2 = pylab.gcf()
	outer_grid	= gridspec.GridSpec(num_rows,num_cols, hspace=0.3)

	axes = []
	for idx in xrange(0,num_panels):
		
		row_idx = idx//num_cols
		col_idx = idx % num_cols
		
		species_name, Dprimes = Dprime_data[idx]
		
		axis = plt.Subplot(fig2, outer_grid[row_idx,col_idx])
		fig2.add_subplot(axis)
		
		axis.set_title("%s (n4=%d)" % (pretty_species_name_map[species_name], len(Dprimes)),fontsize=mpl.rcParams['font.size'])
		
		xs, ns = stats_utils.calculate_unnormalized_survival_from_vector(Dprimes, min_x=-0.01, max_x=1.01)
	    
		cdf = 1-ns*1.0/ns[0]

		#axis.step(xs,cdf,'-',where='pre')
		num_bins = min([50,(len(Dprimes)//10)])
		axis.hist(Dprimes,bins=num_bins)
		#axis.set_ylim([0,1])
		axis.set_xlim([0,1])

		axis.set_yticks([])
		if col_idx==0:
			axis.set_ylabel('CDF')
		else:
			axis.set_yticklabels([])
			
		if row_idx==0:
			axis.set_xticklabels([])
		
		if col_idx==1 and row_idx==(num_rows-1):
			axis.set_xlabel("D'")

	fig2.savefig('%s/supplemental_dprimes.pdf' % (parse_midas_data.analysis_directory),bbox_inches='tight')
	
print "FINAL STATISTICS:"

print TOTAL_SPECIES, "species examined"
print TOTAL_CHECKED, "SNV pairs checked" 
print TOTAL_FOURGAMETES, "fourgamete pairs", TOTAL_FOURGAMETES*1.0/TOTAL_CHECKED
		