# Reference: J. M. Keller, M. R. Gray, and J. A. Givens, Jr., "A Fuzzy K-Nearest Neighbor Algorithm," IEEE Transactions on Systems, Man, and Cybernetics, vol. SMC-15, no. 4, pp. 580-585, 1985.
from collections import Counter

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

class FuzzyKNN(BaseEstimator, ClassifierMixin):
    """Fuzzy k-nearest-neighbour classifier (Keller, Gray & Givens).

    During ``fit`` every training sample receives a fuzzy membership in each
    class, derived from the labels of its ``k`` nearest training neighbours.
    For a query point, the score of each class is the inverse-distance
    weighted average of its ``k`` nearest neighbours' memberships in that
    class; the class with the largest score is predicted.

    Parameters
    ----------
    k : int, default=3
        Number of neighbours, used both for the training memberships and for
        voting. Must satisfy ``1 <= k < n_samples``.
    plot : bool, default=False
        Only type-checked; it is not used anywhere in this class.
    m : float, default=2
        Fuzzifier controlling the distance weights
        ``1 / (d + eps) ** (2 / (m - 1))``. Must be greater than 1.
        The default reproduces the previously hard-coded value.
    """

    # Added to every neighbour distance so that an exact match (distance 0)
    # does not cause a division by zero in the inverse-distance weights.
    _EPS = np.exp(-34)

    def __init__(self, k=3, plot=False, m=2):
        self.k = k
        self.plot = plot
        # m is the fuzzifier (fuzzy weight adjustment factor).
        self.m = m

    def fit(self, X, y=None):
        """Store the training data and precompute per-sample class memberships.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)
            Required despite the ``None`` default (kept for sklearn API
            compatibility); an ``Exception`` is raised when it is missing.

        Returns
        -------
        self
        """
        # Accept lists as well as arrays: the rest of the class relies on
        # NumPy fancy indexing and broadcasting.
        X = np.asarray(X)
        if y is not None:
            y = np.asarray(y)
        self._check_params(X, y)
        self.X_train = X
        self.y_train = y

        self.xdim = X.shape[1]
        self.n = len(y)

        # np.unique already returns the labels sorted.
        self.classes = np.unique(y)

        self.memberships = self._compute_memberships()

        self.fitted_ = True
        return self

    def predict(self, X):
        """Return the predicted label of every row of ``X`` as a list.

        The class scores computed along the way are also stored in
        ``self.y_prob`` (shape ``(n_samples, n_classes)``), as before.

        Raises
        ------
        Exception
            If called before ``fit()``.
        """
        self._check_is_fitted('predict')
        votes = self._predict_votes(X)
        # Kept for backward compatibility: callers may read the scores of the
        # most recent predict() call from this attribute.
        self.y_prob = votes
        # argmax returns the first maximum, so ties go to the smallest label
        # (self.classes is sorted).
        return list(self.classes[np.argmax(votes, axis=1)])

    def predict_proba(self, X):
        """Return fuzzy class scores for ``X``.

        Returns
        -------
        ndarray of shape (n_samples, n_classes)
            Columns follow ``self.classes``. Each row is a weighted average of
            membership vectors that each sum to 1, so each row sums to 1.

        Raises
        ------
        Exception
            If called before ``fit()``.
        """
        self._check_is_fitted('predict_proba')
        return self._predict_votes(X)

    def score(self, X, y):
        """Return ``(accuracy, predictions)`` for the samples ``X`` with true labels ``y``.

        NOTE: unlike ``ClassifierMixin.score`` this returns a tuple rather
        than a bare float, so sklearn helpers that expect a float (e.g.
        ``cross_val_score``) cannot use it directly.

        Raises
        ------
        Exception
            If called before ``fit()`` or if ``X`` and ``y`` differ in length.
        """
        self._check_is_fitted('score')
        predictions = self.predict(X)
        if len(predictions) != len(y):
            raise Exception('"X" and "y" should have the same number of samples')
        n_correct = sum(1 for pred, true in zip(predictions, y) if pred == true)
        accuracy = n_correct / len(y)
        return accuracy, predictions

    def _check_is_fitted(self, caller):
        """Raise an ``Exception`` naming *caller* if ``fit()`` has not run yet."""
        # ``fitted_`` is only created by fit(), so test for its presence
        # instead of comparing a possibly missing attribute to None.
        if not getattr(self, 'fitted_', False):
            raise Exception(f'{caller}() called before fit()')

    def _predict_votes(self, X):
        """Return the class-score matrix of shape ``(len(X), n_classes)`` for ``X``."""
        X = np.asarray(X)
        votes = [self._class_votes(x) for x in X]
        # reshape keeps a 2-D result even when X has no rows.
        return np.array(votes).reshape(len(X), len(self.classes))

    def _class_votes(self, x):
        """Return the score of each class in ``self.classes`` for one query ``x``.

        score_c = sum_j u_c(x_j) * w_j / sum_j w_j, where
        w_j = 1 / (d_j + eps) ** (2 / (m - 1)) and x_j ranges over the ``k``
        nearest training samples of ``x`` at distances d_j.
        """
        neighbors_index, neighbors_dist = self._find_k_nearest_neighbors(self.X_train, x)
        weights = 1.0 / (neighbors_dist + self._EPS) ** (2 / (self.m - 1))
        # Row j holds neighbour j's memberships, columns ordered as self.classes.
        neighbor_memberships = np.array(
            [[self.memberships[j][c] for c in self.classes] for j in neighbors_index]
        )
        return weights @ neighbor_memberships / weights.sum()

    def _find_k_nearest_neighbors(self, X, x):
        """Return ``(indices, distances)`` of the ``self.k`` rows of ``X`` nearest to ``x``.

        Distances are Euclidean and ordered from nearest to farthest. A stable
        sort keeps tie-breaking deterministic (lower index first).
        """
        distance = np.linalg.norm(np.asarray(X) - x, axis=1)
        neighbors_index = np.argsort(distance, kind='stable')[:self.k]
        return neighbors_index, distance[neighbors_index]

    def _get_counts(self, neighbors_index):
        """Return a ``Counter`` mapping each label to its count among the given training indices."""
        return Counter(self.y_train[neighbors_index])

    def _compute_memberships(self):
        """Return one ``{class: membership}`` dict per training sample.

        For a sample with label ``y_i`` of which ``n_c`` of its ``k`` nearest
        training neighbours belong to class ``c``:

            u_c = 0.51 + 0.49 * n_c / k    if c == y_i
            u_c =        0.49 * n_c / k    otherwise

        Since the ``n_c`` sum to ``k``, each sample's memberships sum to 1.
        """
        memberships = []
        for x, label in zip(self.X_train, self.y_train):
            # NOTE(review): the neighbour search runs over the whole training
            # set, so x itself (distance 0) is normally one of its own k
            # neighbours. Confirm this is the intended initialisation.
            neighbors_index, _ = self._find_k_nearest_neighbors(self.X_train, x)
            counts = self._get_counts(neighbors_index)
            membership = {}
            for c in self.classes:
                # A class absent from the neighbourhood has count 0. The
                # sample's own class still gets the 0.51 base membership.
                uci = 0.49 * (counts.get(c, 0) / self.k)
                if c == label:
                    uci += 0.51
                membership[c] = uci
            memberships.append(membership)
        return memberships

    def _check_params(self, X, y):
        """Validate hyper-parameters and training data; raise ``Exception`` on the first problem."""
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(self.k, bool) or not isinstance(self.k, (int, np.integer)):
            raise Exception('"k" should have type int')
        if y is None:
            raise Exception('"y" is required')
        if self.k < 1:
            raise Exception('"k" should be at least 1')
        elif self.k >= len(y):
            raise Exception('"k" should be less than no of feature sets')
        # k is deliberately not required to be odd.

        if not isinstance(self.plot, bool):
            raise Exception('"plot" should have type bool')

        # `not m > 1` also rejects NaN.
        if (isinstance(self.m, bool)
                or not isinstance(self.m, (int, float, np.integer, np.floating))
                or not self.m > 1):
            raise Exception('"m" should be a number greater than 1')

        if np.ndim(X) != 2:
            raise Exception('"X" should be a 2-D array')
        if len(X) != len(y):
            raise Exception('"X" and "y" should have the same number of samples')