#!/usr/bin/env python
import argparse
import re

import numpy as np
import pandas as pd
from scipy.interpolate import PchipInterpolator

# Command-line interface. RawTextHelpFormatter keeps the hand-written line
# breaks in the description and help strings below.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

This script computes t_half, the estimated time to reach half the 
maximal (or minimal) logFC. This is estimated by using Piecewise
Cubic Hermite Interpolating Polynomial to fit a curve to the
values preceding the minimum or maximum.

""")

##################################################
# required args:
parser.add_argument("-i", "--df", help="""required, file path to logFC dataframe
column headings are in format, e.g. 01h, 02h, or 01h30m,02h10m, or 01h30m15s,02h10m35s
where h = hours, m = minutes, s = seconds.
                    """, required=True)
parser.add_argument("-o", "--out", help="""required, file path to t_half (output) dataframe
                    """, required=True)

###################################################
# optional args:

# Boolean flag (store_true, default False). When it is NOT given, the script
# prepends a 00h time point with logFC = 0 to every row before fitting.
parser.add_argument("--not_relative_to_00h",
                    help="""if --not_relative_to_00h, logFC was not computed relative to 00h,
so no 00h time point with logFC = 0 is prepended before fitting.
                    """,
                    action="store_true")

###################################################

args = parser.parse_args()

# Read the tab-separated logFC table; its first column becomes the row index.
df = pd.read_csv(args.df, index_col=0, sep="\t")

arr = df.values.copy()
n_rows, n_cols = arr.shape
row_ids = np.arange(n_rows)

# Per row: the column holding the largest |logFC| (first one on ties) and the
# signed logFC found there.
peak_cols = np.abs(arr).argmax(axis=1)
peak_vals = arr[row_ids, peak_cols]

# From the peak column onwards, overwrite each value with the peak value, so
# the series stays flat after its min/max.
after_peak = np.arange(n_cols)[np.newaxis, :] >= peak_cols[:, np.newaxis]
arr = np.where(after_peak, peak_vals[:, np.newaxis], arr)

# Negate rows whose peak value is negative, so every row peaks upwards.
arr = np.where((peak_vals < 0)[:, np.newaxis], -arr, arr)

# One "<number><unit>" token of a time-point label, e.g. "01h", "30m", "1.5h".
_TIMEPOINT_TOKEN_RE = re.compile(r"(\d+(?:\.\d+)?)([hms])")
# A complete label: one or more tokens and nothing else, e.g. "01h30m15s".
_TIMEPOINT_LABEL_RE = re.compile(r"(?:\d+(?:\.\d+)?[hms])+")


def timepoint_conversion_to_numeric(t):
    """Convert a time-point label such as ``01h30m15s`` to hours.

    A label is a sequence of ``<number><unit>`` tokens, where the unit is
    ``h`` (hours), ``m`` (minutes) or ``s`` (seconds), e.g. ``01h``,
    ``01h30m`` or ``02h10m35s``. Numbers may have any number of digits
    (``120h``) and an optional fractional part (``1.5h``). Surrounding
    whitespace is ignored.

    Parameters
    ----------
    t : str
        Time-point label (a column heading of the logFC table).

    Returns
    -------
    int or float
        Time in hours. An ``int`` is returned when the value is whole
        (``"02h"`` -> ``2``), otherwise a ``float`` (``"01h30m"`` -> ``1.5``).

    Raises
    ------
    ValueError
        If ``t`` is not a non-empty sequence of ``<number><unit>`` tokens.
    """
    label = str(t).strip()
    if not _TIMEPOINT_LABEL_RE.fullmatch(label):
        raise ValueError(
            f"cannot parse time point {t!r}; expected e.g. 01h, 01h30m, 01h30m15s"
        )

    hours_per_unit = {"h": 1, "m": 1 / 60.0, "s": 1 / 3600.0}
    hrs = 0.0
    # Accumulate left to right, matching the order components appear in.
    for value, unit in _TIMEPOINT_TOKEN_RE.findall(label):
        hrs += float(value) * hours_per_unit[unit]

    # Whole hours are returned as int so labels like "02h" give 2, not 2.0.
    return int(hrs) if hrs.is_integer() else hrs

# Unless --not_relative_to_00h was given, prepend a 00h time point whose logFC
# is 0 to every row. np.insert keeps arr's dtype and, unlike vstack-ing a
# list of scalars, also works for a table with no rows.
if not args.not_relative_to_00h:
    arr = np.insert(arr, 0, 0, axis=1)

# Min-max scale each row into [0, 1]. keepdims=True keeps the per-row min/max
# as (n_rows, 1) columns so they broadcast across time points (also for an
# empty table). A row with max == min (e.g. all zeros) evaluates 0/0 here and
# becomes NaN.
row_min = arr.min(axis=1, keepdims=True)
row_max = arr.max(axis=1, keepdims=True)
scaled_arr = (arr - row_min) / (row_max - row_min)

# Numeric time axis in hours, one entry per column of arr: the converted
# column headings, plus a leading 0 when the 00h point was prepended.
x = list(map(timepoint_conversion_to_numeric, df.columns))
if not args.not_relative_to_00h:
    x.insert(0, 0)

# 1000 evenly spaced times spanning the axis, used to evaluate the interpolant.
x_grid = np.linspace(min(x), max(x), 1000)

# Compute t_half per row: fit a PCHIP curve through the scaled values, evaluate
# it on x_grid, and take the grid time whose value is closest to 0.5 (half of
# the scaled peak, which is 1).
t_half = []
for row in scaled_arr:
    if not np.all(np.isfinite(row)):
        # Constant rows (0/0 during scaling) or missing logFC values leave
        # NaN/inf here. The interpolant would then be all-NaN and argmin would
        # silently return index 0, i.e. report min(x). Mark as undefined.
        t_half.append(np.nan)
        continue
    curve = PchipInterpolator(x, row)(x_grid)
    t_half.append(x_grid[np.abs(curve - 0.5).argmin()])

# One t_half value per input row, keyed by the input table's index.
t_half_df = pd.DataFrame(index=df.index)
t_half_df['t_half'] = t_half
t_half_df.to_csv(args.out, sep="\t", index=True)