#! /usr/bin/python

import re
import sys
import math
from numpy import *


#================== DEFINES FUNCTIONS HERE================


#Finds the weights of the different data points used for interpolating between them
def getWeights(elementLength,lengthList):
	"""Return the two data-point indices and their linear-interpolation weights.

	elementLength -- length of the query element.
	lengthList    -- lengths of the pre-computed data points. The scan stops at
	                 the first entry greater than elementLength, so the list is
	                 expected to be sorted in ascending order.

	Returns [[i1, i2], [w1, w2]]: indices into lengthList (and the matching
	p-vector list) and the weight of each one. Lengths outside the covered
	range are clamped to the first or last data point.

	Raises ValueError if lengthList is empty or if elementLength cannot be
	ordered against the data points (e.g. NaN).
	"""
	if len(lengthList) == 0:
		raise ValueError("lengthList must contain at least one length")

	if len(lengthList) == 1:
		#Only one data point: there is no neighbour to interpolate with, so
		#use that point twice instead of returning a non-existent index
		return [[0,0],[1,0]]

	#Shorter than (or equal to) the shortest data point: take the first one
	if elementLength <= lengthList[0]:
		return [[0,1],[1,0]]

	for i in range(len(lengthList)-1):
		if lengthList[i+1] > elementLength:
			l1=lengthList[i]
			l2=lengthList[i+1]
			#float() in the numerators forces true division under Python 2 as well
			return [[i, i+1],[(float(l2)-elementLength)/(l2-l1),(float(elementLength)-l1)/(l2-l1)]]

	if elementLength >= lengthList[-1]:
		#If element in query is longer than longest length pre-computed data-point, take data for longest available
		return [[-2,-1],[0,1]]

	#Only reachable when elementLength does not compare with the lengths
	#(e.g. NaN); fail loudly instead of printing and returning None
	raise ValueError("Cannot compute interpolation weights for length %r" % (elementLength,))

#Interpolates the p-vectors between the different datapoints
def interpolateP(elementLength, lengthList, pList):
	"""Return the p-vector for elementLength as a weighted sum of two data points.

	The indices and weights come from getWeights(elementLength, lengthList).
	Each selected row of pList gets a leading 1.0 prepended before weighting,
	so the result has one more entry than a pList row.
	"""
	indices, weights = getWeights(elementLength, lengthList)
	firstIdx, secondIdx = indices
	firstWeight, secondWeight = weights

	firstVector = array([1.0] + pList[firstIdx])
	secondVector = array([1.0] + pList[secondIdx])

	return firstVector * firstWeight + secondVector * secondWeight


#================== REGULAR CODE STARTS HERE ================

#Parsing arguments
#usage:
# calculateDistanceExpectation.py query.bed pMatrix.lst
argsOk = len(sys.argv) == 3

if not argsOk:
	print('calculateDistanceExpectation.py query.bed pMatrix.lst')

#If ok, let's go!!
if argsOk:
	queryPath = sys.argv[1]
	matrixPath = sys.argv[2]

	#Parses the probability matrix file.
	#'with' guarantees the file is closed even if a row fails to parse.
	lengthList = []
	pList = []
	with open(matrixPath, 'r') as matrixFile:
		#Parsing first line in matrix file: header with one bin name per column
		line = matrixFile.readline()
		lData = line.rstrip("\n").split("\t")
		binNames = lData[1:]
		#startswith() also copes with an empty file or an empty first field
		if not line.startswith('#'):
			print("Error: First line doesn't start with '#'")

		#Reading the rest: first column is a length, remaining columns are
		#the per-bin values for that length
		for line in matrixFile:
			if not line.strip():
				#Tolerate blank lines (e.g. a trailing newline at end of file)
				continue
			lData = line.rstrip("\n").split("\t")
			lengthList.append(int(lData[0]))
			pList.append([float(field) for field in lData[1:]])

	if not pList:
		#Nothing to interpolate from; stop with a clear message and a
		#non-zero exit status instead of an IndexError further down
		sys.exit("Error: no data rows found in %s" % matrixPath)

	#Accumulators for sum(p) and sum(p^2); one extra slot because
	#interpolateP prepends a constant 1.0 to every p-vector
	nColumns = len(pList[0]) + 1
	pSum = zeros(nColumns)
	p2Sum = zeros(nColumns)

	#Parses query.bed, calculating <N> and sigma for each bin.
	with open(queryPath, 'r') as queryFile:
		for line in queryFile:
			if not line.strip():
				continue
			lData = line.rstrip("\n").split("\t")

			#Element length = end - start (columns 3 and 2 of the query line)
			pVector = interpolateP(int(lData[2])-int(lData[1]), lengthList, pList)

			pSum += pVector
			p2Sum += pVector*pVector

	#Printing results: expectations for all bins, then sigmas for all bins,
	#everything tab-separated on a single line.
	#sigma^2 = sum(p) - sum(p^2) = sum(p*(1-p))
	binExp = ["%f" % pSum[i+1] for i in range(len(binNames))]
	binSigma = ["%.13f" % sqrt(pSum[i+1]-p2Sum[i+1]) for i in range(len(binNames))]
	print("%s\t%s" % ("\t".join(binExp), "\t".join(binSigma)))

