import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoNCE(nn.Module):
    """InfoNCE / NT-Xent contrastive loss over two batches of paired views.

    Row ``i`` of ``feat1`` and row ``i`` of ``feat2`` form a positive pair.
    Every other row of the concatenated ``2N``-row batch is a negative for
    both of them; the self-similarity on the diagonal is excluded.

    Args:
        batch_size: Expected number of rows ``N`` in each view. It is used to
            pre-build the negatives mask buffer. Batches of a different size
            (e.g. a final partial batch) are still handled by building a
            matching mask on the fly.
        temperature: Divisor applied to the cosine similarities before the
            softmax.
    """

    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature))
        # Float mask, 0 on the diagonal and 1 elsewhere; kept as a buffer so it
        # follows the module across devices and stays in the state_dict.
        self.register_buffer("negatives_mask", self._make_negatives_mask(batch_size))

    @staticmethod
    def _make_negatives_mask(batch_size, device=None):
        """Return a float ``(2N, 2N)`` mask that is 0 on the diagonal, 1 elsewhere."""
        eye = torch.eye(batch_size * 2, dtype=torch.bool, device=device)
        return (~eye).float()

    def forward(self, feat1, feat2):
        """Compute the mean contrastive loss over all ``2N`` anchors.

        Args:
            feat1: First view; row ``i`` pairs with row ``i`` of ``feat2``.
            feat2: Second view; must have the same number of rows as ``feat1``.

        Returns:
            0-d tensor holding the loss averaged over the ``2N`` anchors.

        Raises:
            ValueError: if ``feat1`` and ``feat2`` have different row counts.
        """
        n = feat1.size(0)
        if feat2.size(0) != n:
            raise ValueError(
                f"feat1 and feat2 must have the same batch size, got {n} and {feat2.size(0)}"
            )

        representations = F.normalize(torch.cat([feat1, feat2], dim=0), dim=1, p=2)
        # Rows are unit-norm, so the dot product equals the cosine similarity.
        # This avoids the (2N, 2N, D) broadcast that pairwise
        # F.cosine_similarity would materialise.
        similarity_matrix = representations @ representations.t()
        logits = similarity_matrix / self.temperature

        # Positive for anchor i is at offset +N (first view) or -N (second view).
        sim_ij = torch.diag(logits, n)
        sim_ji = torch.diag(logits, -n)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        if n == self.batch_size:
            mask = self.negatives_mask
        else:
            mask = self._make_negatives_mask(n, device=logits.device)

        # -log(exp(pos) / sum_{k != i} exp(logit_ik)), computed in log space so
        # that small temperatures cannot overflow exp().
        log_denominator = torch.logsumexp(
            logits.masked_fill(mask == 0, float("-inf")), dim=1
        )
        return (log_denominator - positives).mean()

class L1Regularization(nn.Module):
    """Weighted penalty on the mean absolute value of each model parameter.

    For every parameter tensor ``p`` yielded by ``model.parameters()`` the
    term ``mean(|p|) * weight_decay`` is computed, and the terms are summed.

    Args:
        weight_decay: Multiplier applied to each per-tensor term.
    """

    def __init__(self, weight_decay=0.1):
        super().__init__()
        self.weight_decay = weight_decay

    def forward(self, model):
        """Return the summed penalty (the int ``0`` if ``model`` has no parameters)."""
        total = 0
        for weight in model.parameters():
            total = total + torch.abs(weight).mean() * self.weight_decay
        return total

def non_corr(x):
    """Penalise linear dependence between the columns of ``x``.

    Computes the sample covariance matrix of the columns of ``x`` (rows are
    treated as observations, with an unbiased ``1 / (N - 1)`` factor). It then
    returns the mean absolute value of the strictly upper triangle. The mean
    is taken over all ``D * D`` entries, including the zeroed diagonal and
    lower triangle.

    Args:
        x: 2-D tensor ``(N, D)`` with ``N`` observations of ``D`` features.
            ``N`` must be at least 2, otherwise the ``1 / (N - 1)`` factor
            raises ``ZeroDivisionError``.

    Returns:
        0-d tensor. A single-feature input (``D == 1``) yields 0.
    """
    def cov(m):
        # m is (D, N): one row per feature, one column per observation.
        fact = 1.0 / (m.size(1) - 1)
        m = m - torch.mean(m, dim=1, keepdim=True)
        # No .squeeze(): for D == 1 it would collapse the (1, 1) matrix to a
        # 0-d tensor, which torch.triu rejects.
        return fact * m.matmul(m.t())

    return torch.mean(torch.abs(torch.triu(cov(x.t()), diagonal=1)))

def zero_center(x):
    """Return the mean absolute value of ``x`` as a 0-d tensor.

    The result is zero only when every element of ``x`` is zero.
    """
    magnitudes = x.abs()
    return magnitudes.mean()

def max_var(x):
    """Return the reciprocal of the mean absolute deviation of ``x`` along dim 0.

    Each element is compared to the mean of its slice along dim 0 (the
    per-column mean for a 2-D input). The absolute deviations are averaged
    over every element, and ``1 / average`` is returned. Larger spread
    therefore gives a smaller value. If every slice along dim 0 is constant,
    the average is 0 and the result is ``inf``.
    """
    deviations = x - x.mean(dim=0)
    mean_abs_dev = torch.mean(deviations.abs())
    return 1 / mean_abs_dev