import pandas as pd
import numpy as np
import struct
import random
from sklearn.model_selection import train_test_split

class read_plink(object):
    """Reader for a PLINK binary fileset (``<prefix>.bed/.bim/.fam``).

    Parameters
    ----------
    filename : str
        Path prefix of the fileset, i.e. the path without the
        ``.bed`` / ``.bim`` / ``.fam`` extension.
    **kwargs
        n_SNP : int, optional
            Number of variants. When omitted it is counted from the
            ``.bim`` file.
        n_indi : int, optional
            Number of individuals. When omitted the ``.fam`` file is read,
            which also fills ``li_indi``.

    Attributes
    ----------
    NumSNP : int
        Number of variants.
    n_indi : int
        Number of individuals.
    li_indi : list of str or None
        Second column of the ``.fam`` file, in file order; ``None`` until
        the ``.fam`` file has been read.
    """

    # First two bytes of a PLINK 1 binary .bed file.
    _BED_MAGIC = b"\x6c\x1b"
    # Third header byte for variant-major layout (all samples of the first
    # variant, then all samples of the second, ...), the only layout the
    # decoder below understands.
    _BED_VARIANT_MAJOR = b"\x01"
    # Value stored for each 2-bit genotype code, indexed by the code:
    #   0b00 -> 2, 0b01 -> -1 (missing), 0b10 -> 1, 0b11 -> 0
    _GENOTYPE_LOOKUP = (2, -1, 1, 0)

    def __init__(self, filename, **kwargs):
        self.filename = filename
        # Always defined so construct_input can test it safely.
        self.li_indi = None
        if "n_SNP" in kwargs:
            self.NumSNP = kwargs["n_SNP"]
        else:
            self.read_bim()
        if "n_indi" in kwargs:
            self.n_indi = kwargs["n_indi"]
        else:
            self.read_fam()

    def read_bim(self):
        """Read ``<prefix>.bim`` and return the list of variant identifiers.

        The identifier is the second whitespace-separated field of each
        line. Blank lines are skipped. ``self.NumSNP`` is updated to the
        number of variants found.

        Returns
        -------
        list of str
            Variant identifiers in file order.
        """
        feature_info = []
        with open(self.filename + ".bim", "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                feature_info.append(fields[1])
        self.NumSNP = len(feature_info)
        return feature_info

    def read_fam(self):
        """Read ``<prefix>.fam`` and record the individuals it lists.

        The second whitespace-separated field of each line is stored in
        ``self.li_indi`` (file order) and ``self.n_indi`` is set to the
        number of individuals. Tab- and space-delimited files are both
        accepted; blank lines are skipped.
        """
        li_indi = []
        with open(self.filename + ".fam", "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                li_indi.append(fields[1])
        self.n_indi = len(li_indi)
        self.li_indi = li_indi

    def read_bed2numpy(self):
        """Decode ``<prefix>.bed`` into a genotype matrix.

        Each variant occupies ``ceil(n_indi / 4)`` bytes; every byte packs
        four 2-bit genotype codes, lowest-order bits first. Codes are
        mapped as ``00 -> 2``, ``10 -> 1``, ``11 -> 0`` and ``01 -> -1``
        (missing), i.e. the value is the number of copies of the first
        ``.bim`` allele per the PLINK ``.bed`` specification.

        Returns
        -------
        numpy.ndarray
            ``int8`` array of shape ``(n_indi, NumSNP)``.

        Raises
        ------
        ValueError
            If the file does not start with the PLINK magic number, is not
            in variant-major layout, or is shorter than
            ``n_indi`` / ``NumSNP`` require.
        """
        indi = self.n_indi
        snp = self.NumSNP
        bytes_per_snp = (indi - 1) // 4 + 1
        expected = bytes_per_snp * snp
        path = self.filename + ".bed"

        with open(path, "rb") as f:
            header = f.read(3)
            if header[:2] != self._BED_MAGIC:
                raise ValueError(
                    "%s is not a PLINK binary .bed file (bad magic number)" % path
                )
            if header[2:3] != self._BED_VARIANT_MAJOR:
                raise ValueError(
                    "%s is not in variant-major mode; only that layout is supported"
                    % path
                )
            payload = f.read(expected)

        if len(payload) != expected:
            raise ValueError(
                "%s is truncated: expected %d genotype bytes, found %d"
                % (path, expected, len(payload))
            )

        raw = np.frombuffer(payload, dtype=np.uint8).reshape(snp, bytes_per_snp)
        lookup = np.array(self._GENOTYPE_LOOKUP, dtype=np.int8)
        arr = np.empty((indi, snp), dtype=np.int8)
        # Individual 4*j + slot lives in byte j at bit offset 2*slot, so each
        # slot fills every fourth row. Padding bits in the last byte are never
        # read because only the bytes that hold a real individual are sliced.
        for slot in range(4):
            n_rows = len(range(slot, indi, 4))
            codes = (raw[:, :n_rows] >> (2 * slot)) & 0b11
            arr[slot::4, :] = lookup[codes].T
        return arr

    def construct_input(self, total_label):
        """Align class labels with the individuals of the fileset.

        Parameters
        ----------
        total_label : str
            Path to a CSV file with at least the columns ``indi``
            (individual identifier, matched against the second ``.fam``
            column) and ``class`` (label).

        Returns
        -------
        tuple or None
            ``(recode_label, label_dict, individuals)`` where
            ``recode_label`` is the integer-coded label of every
            individual in ``.fam`` order, ``label_dict`` maps each
            original label to its integer code (codes follow the sorted
            order of the unique labels) and ``individuals`` is a numpy
            array of the identifiers. ``None`` is returned, after printing
            a warning, when any individual has no label.
        """
        df = pd.read_csv(total_label, dtype="str")
        if self.li_indi is None:
            self.read_fam()
        pd_indi = pd.DataFrame(self.li_indi, dtype="str", columns=["indi"])
        # Left merge keeps the .fam order of individuals.
        # NOTE(review): duplicate "indi" rows in the label file would
        # duplicate individuals here - confirm the label file is unique.
        merged = pd.merge(pd_indi, df, how="left", on="indi")
        if merged["class"].isna().any():
            print("warning:there are individuals that don't have corresponding labels")
            return None

        y = merged["class"].tolist()
        label_uni = np.unique(y)
        label_dict = dict(zip(label_uni, np.arange(0, len(label_uni))))
        recode_label = [label_dict[label] for label in y]
        return recode_label, label_dict, np.array(self.li_indi)
