function buildTRNs_mLassoStARS(instabOutMat,tfaMat,priorMergedTfsFile,...
    meanEdgesPerGene,targInstability,instabSource,subsampHistPdf,trnOutMat,...
    outNetFileSparse)
%% buildTRNs_mLassoStARS(instabOutMat,tfaMat,priorMergedTfsFile,...
%    meanEdgesPerGene,targInstability,instabSource,subsampHistPdf,...
%    trnOutMat,outNetFileSparse)
%% GOAL: Rank TF-gene interactions according to stability (frequency 
%   of nonzero edges across subsamples + |partial correlation (TF,gene)|).
%   If outNetFileSparse (file name) is supplied, this fxn will output the TRN 
%   in "sparse" format (TF, target gene, edge weight + additional edge 
%   metadata for viewing with jp_gene_viz) as well as a .mat summary file
%% Reference:
% Miraldi et al. "Leveraging chromatin accessibility data for 
%   transcriptional regulatory network inference in T Helper 17 Cells"
%% Author: Emily R. Miraldi, Ph.D., Divisions of Immunobiology and Biomedical
%   Informatics, Cincinnati Children's Hospital
%% INPUTS:
% instabOutMat -- contains network- and gene-level instabilities,
%   lambdaRange, number of nonzero subsamples per edge (e.g., as generated
%   by estimateInstabilitiesTRN.m or estimateInstabilitiesTRNbStARS.m)
% tfaMat -- a .mat file containing the prior of TF-gene interactions as
%   well as TFA (prior-based and TF mRNA), e.g., as generated by 
%   integratePrior_estTFA.m
% priorMergedTfsFile -- if degenerate TFs were merged (e.g., by
%   mergeDegeneratePriorTFs.py in priorParsingFunctions) to enable
%   prior-based TFA calculation, supply the two-column, tab-delimited file,
%   where column 1 = merged TF names as in the merged TF prior, and column
%   2 = individual TF names, separated by ", " e.g., as in
%   Th17_example/inputs/priors/ATAC_allTh_mergedTfs.txt:
%         Mesp1_Mesp2     Mesp1, Mesp2
%         Npas1_Npas3...  Npas1, Npas3, Sim1, Sim2
%         Nr1h2_Nr1h3     Nr1h2, Nr1h3   
%   NOTE: provide an empty string '' if there were no degenerate TFs in the
%   prior, this file does not exist, or TF mRNA was used for TFA
% meanEdgesPerGene -- used to calculate quantiles, total number of edges =
%   number of gene models X meanEdgesPerGene
% targInstability -- instability cutoff of interest, belongs to range 
%   (0,.5]
% instabSource -- source of instability estimates, two options:
%   'Network'  --> for network-wide stability estimates
%   'Gene' --> stability based on each gene model separately
% subsampHistPdf -- a filename to be used for generating a histogram of
%   nonzero edge subsample frequencies at the target instability cutoff.  
%   NOTE: empty string '' signals not to create this output
% trnOutMat -- name for output .mat containing ranks, partial correlation,
%   etc., for downstream analysis (e.g., gene expression prediction)
% outNetFileSparse -- name for tab-delimited network file in "sparse" format (i.e., 
%   first three columns are TF, gene, and signed confidence + metadata,
%   including metadata to modulate edge color and other features in
%   visualization tool jp_gene_viz
%   (https://github.com/simonsfoundation/jp_gene_viz)), NOTE: empty string
%   '' signals not to create this output
%% OUTPUTS:
%% ${outDir}/Results_lassoStARS/${quantNetFolderName} outputs:
% trnOutMat -- contains ranked lists of network edges, stabilities,
%   whether they were in the input prior, etc.
% outNetFileSparse -- (optional) 3-column network file format for visualization in 
%       jp_gene_viz, limit models to size "meanEdgesPerGene"
%        0.  Edge confidence is a quantile, where total edges is set to
%              meanEdgesPerGene * total Gene Models
%        1.  Edge thickness (in output sparse network) is proportional to 
%              edge stability: 2*(.5-instability) E [0,1] 
%        2.  Edges signs are calculated based on partial correlation
% subsampHistPdf -- (optional) histogram as described above

% The two .mat files inject the workspace variables used below without
% being defined in this function: ssMatrix, geneInstabilities,
% netInstabilities, totSS, predictorMat, responseMat, priorMat, targGenes,
% allPredictors and pRegsNoTfa. (Which file supplies which variable is not
% visible from this function.)
load(tfaMat)
load(instabOutMat)

% ssMatrix is indexed (lambda, gene, TF) below
[totLambdas, totNetGenes, totNetTfs] = size(ssMatrix);
ssOfInt = zeros(totNetGenes,totNetTfs);  

%% transform StARS instabilities into stabilities 
if any(strcmp(instabSource,'Gene'))
    disp('Per-gene instabilities detected.')
    % have to find per-gene lambda corresponding to instability
    totMins = 0;  % genes whose closest instability sits at the first lambda
    totMaxes = 0; % ... or at the last lambda (either suggests a wider lambda range is needed)
    for targ = 1:totNetGenes
        currInstabs = geneInstabilities(targ,:);
        devs = abs(currInstabs - targInstability);
        minInds = find(devs == min(devs));
        % among ties, take the last (largest-index) lambda closest to targInstability
        minInd = minInds(end);
        ssOfInt(targ,:) = reshape(ssMatrix(minInd,targ,:),1,totNetTfs);
        if minInd == 1
            totMins = totMins + 1;
        elseif minInd == totLambdas
            totMaxes = totMaxes + 1;
        end
    end
    % report both boundary conditions independently, each with its own count
    if totMins > 0
        disp(['Target instability reached at minimum lambda for ' num2str(totMins) ' gene(s), cut = ' num2str(targInstability) '.'])
    end
    if totMaxes > 0
        disp(['Target instability reached at maximum lambda for ' num2str(totMaxes) ' gene(s), cut = ' num2str(targInstability) '.'])
    end
    % NOTE: the subsample histogram is drawn once, further below, for both
    % instability sources (after Inf entries have been zeroed).
elseif any(strcmp(instabSource,'Network'))
    disp('Network instabilities detected.')
    % find the single lambda corresponding to the cutoff
    devs = abs(netInstabilities - targInstability);
    minInds = find(devs == min(devs));
    % among ties, take the last (largest-index) lambda closest to targInstability
    minInd = minInds(end);
    if minInd == 1
        disp(['Minimum lambda was used for maximum instability ' num2str(netInstabilities(minInd)) ' to reach target cut = ' num2str(targInstability) '.'])
    elseif minInd == totLambdas
        disp(['Maximum lambda was used for minimum instability ' num2str(netInstabilities(minInd)) ' to reach target cut = ' num2str(targInstability) '.'])
    end
    ssOfInt(:,:) = reshape(ssMatrix(minInd,:,:),totNetGenes,totNetTfs);
else
    error('instabSource not recognized, should be either Gene or Network.')
end
% ssMatrix has infinity entries to mark illegal TF-gene interactions
% (e.g., TF mRNA TFA cannot be used to predict TF gene expression)
ssOfInt(isinf(ssOfInt)) = 0;

%% histogram of nonzero-subsample counts at the chosen cutoff
figure(1), clf
subsampCounts = hist(ssOfInt(:),0:totSS);
bar(1:totSS,subsampCounts(2:end))   % skip the zero-count bin
yMax = subsampCounts(2);
if yMax <= 0
    % axis() needs a strictly positive upper limit; fall back to the
    % tallest plotted bar, or 1 if every bar is empty
    plottedCounts = subsampCounts(2:end);
    yMax = max([plottedCounts(:); 1]);
end
hold on
% inst = 2 * p * (1-p), solve for p to get minimum # of nonzero subsamples
% for edge to be in network at the instability cutoff
ssIn = totSS*(.5+ sqrt(.25-targInstability/2));
estTfsPerGene = nnz(ssOfInt(:)>=ssIn)/totNetGenes; % estimated model size at this cutoff
plot(ssIn*[1 1], [0 yMax], 'r:','LineWidth',1.5)
text(.75*totSS,.75*yMax,['~' roundstring1(estTfsPerGene) ' TFs/gene'], 'FontSize',14)
xlabel(['Number of nonzero subsamples at instability ' num2str(targInstability) ', ' instabSource],'FontSize',14)
ylabel('Counts','FontSize',14)
[~, infFileBase] = fileparts(trnOutMat);
title(strrep(infFileBase,'_',' '),'FontSize',14)
axis([.5 totSS+.5 0 yMax]), grid on, grid minor
set(gca,'FontSize',12)
shg    

if ~isempty(subsampHistPdf)
    figName = subsampHistPdf;
    saveas(gcf,figName,'fig')
    fp = fillPage(gcf, 'margins', [0 0 0 0], 'papersize', [7 6]);
    print('-painters','-dpdf','-r150',[figName '.pdf'])
    disp(figName)
end

% binary indicator of prior support for each (gene, TF) pair
inPriorMat = sign(abs(priorMat));  

if estTfsPerGene > meanEdgesPerGene
    disp(['Size of the model is ~' roundstring1(estTfsPerGene) ' TFs/gene. Please adjust meanEdgesPerGene']);
end

%% Rank edges based on stability at instability cutoff
disp('Ranking Edges')
totInts = totNetGenes * totNetTfs;   
% name vectors laid out to match the column-major order of ssOfInt(:)
targs0 = repmat(targGenes',totNetTfs,1);
regs1 = repmat(allPredictors',totNetGenes,1);
regs0 = reshape(regs1,totInts,1);
%% only keep nonzero & finite edge subsample counts
rankTmp = ssOfInt(:);
keepInds = find(rankTmp ~= 0 & ~isinf(rankTmp));
[rankings,inds] = sort(rankTmp(keepInds),'descend');
regs = reshape(regs0(keepInds(inds)),[],1);
targs = reshape(targs0(keepInds(inds)),[],1);
totInfInts = length(rankings);    

%% convert the stabilities to quantiles according to meanEdgesPerGene
disp(['Calculating quantiles, assuming mean of ' num2str(meanEdgesPerGene) ' TFs/gene.'])
totQuantEdges = length(unique(targs))*meanEdgesPerGene;
quantiles = zeros(totQuantEdges,1);
if totInfInts > totQuantEdges
    ranks4quant = rankings(1:totQuantEdges); % note there might be stability
    % ties at the end of the ranks4quant matrix
    disp(['Total networks edges (' num2str(totInfInts) ') > meanEdgesPerGene (' num2str(meanEdgesPerGene) ', ' num2str(totQuantEdges) ').']) 
else
    ranks4quant = zeros(totQuantEdges,1);
    ranks4quant(1:totInfInts) = rankings;
    disp(['Total networks edges (' num2str(totInfInts) ') < meanEdgesPerGene (' num2str(meanEdgesPerGene) ', ' num2str(totQuantEdges) ').']) 
end
uniRanks = sort(setdiff(unique(ranks4quant),0),'descend');
totVals = 0;
for rind = 1:length(uniRanks)
    % every edge sharing a stability value shares a quantile
    rankInds = find(ranks4quant==uniRanks(rind));
    totVals = totVals + length(rankInds);
    quantiles(rankInds) = 1 - totVals/totQuantEdges;
end
% from here on totQuantEdges counts only edges with a nonzero quantile
totQuantEdges = nnz(quantiles);            

%% take what's in the meanEdgesPerGene network and get partial correlations
allCoefs = zeros(totNetGenes,totNetTfs);
allQuants = zeros(totNetGenes,totNetTfs);
allStabsTest = ssOfInt;
keptTargs = targs(1:totQuantEdges);
uniTargs = unique(keptTargs);
totUniTargs = length(uniTargs);
tfsPerGene = zeros(totUniTargs,1);
for targ = 1:totUniTargs
    currTarg = uniTargs{targ};
    targRankInds = find(ismember(keptTargs,currTarg));
    currRegs = regs(targRankInds);
    targInd = find(ismember(targGenes,currTarg));
    tfsPerGene(targ) = length(targRankInds);
    [~, regressIndsMat, rankVecInds] = intersect(allPredictors,currRegs);         
    currTargVals = responseMat(targInd,:)';
    currPredVals = predictorMat(regressIndsMat,:)';
    prho = partialcorri(currTargVals,currPredVals,'Rows','complete');
    if ~any(isnan(prho))  % make sure there weren't too many edges
        allCoefs(targInd,regressIndsMat) = prho;
        % |pcorr| gives finer resolution within stability ties; rounded
        % to limit the number of decimals
        pcorrBonus = round(abs(prho(:)),2);
    else
        disp([currTarg ' pcorr was singular, # TFs = ' num2str(length(regressIndsMat))])
        % keep NaN out of the stabilities: a NaN would sort to the top of
        % the 'descend' ranking and break the quantile refinement below
        pcorrBonus = zeros(numel(prho),1);
    end
    allQuants(targInd,regressIndsMat) = quantiles(targRankInds(rankVecInds));
    allStabsTest(targInd,regressIndsMat) = rankings(targRankInds(rankVecInds)) + pcorrBonus;  
end
% save stabilities before de-merging TFs -- needed for
% R^2_pred, LO analysis
allStabsMergedTFs = allStabsTest;

mergeTfLocVec = zeros(totNetTfs,1); % for keeping track of merged TFs (needed for partial correlation calculation)

if ~isempty(priorMergedTfsFile) % there could be merged TFs
    disp('Found Merged Prior')
    fid = fopen(priorMergedTfsFile,'r');
    if fid == -1
        error(['Could not open priorMergedTfsFile: ' priorMergedTfsFile])
    end
    C = textscan(fid,'%s%s','Delimiter','\t','Headerlines',0);
    fclose(fid);
    tmergePredTfs = C{1};   % merged TF names
    tmergeVals = C{2};      % ", "-separated individual TF names
    % convert the raw regulators into TFs, if possible
    % will basically duplicate edges for merged TFs so that they
    % will overlap with the GS
    totMerged = length(tmergeVals);
    rmInds = [];            % merged TFs to remove from regulators
    addRegs = cell(0,1);    % individual TF names to append
    addInts = [];
    addCoefs = [];
    addPMat = [];
    addPredMat = [];
    addLoc = [];
    addQuants = [];
    for mind = 1:totMerged
        mTf = tmergePredTfs{mind};
        inputLocs = find(ismember(allPredictors,{mTf}));
        inputLocs = inputLocs(:);
        totMInts = length(inputLocs); % number of predictor entries matching the merged TF
        rmInds = [rmInds; inputLocs];
        if totMInts > 0
            indTfs = intersect(strsplit(tmergeVals{mind},', ')',pRegsNoTfa); % intersect ensures that TF was a potential regulator (e.g., based on gene expression)
            disp([mTf ' expanded to ' strjoin(indTfs,', ') '.'])
            for indt = 1:length(indTfs)
                % every individual TF inherits the merged TF's columns/rows
                addRegs = [addRegs; repmat(indTfs(indt),totMInts,1)];
                addInts = [addInts, allStabsTest(:,inputLocs)];
                addPMat = [addPMat, inPriorMat(:,inputLocs)];
                addQuants = [addQuants, allQuants(:,inputLocs)];
                addCoefs = [addCoefs, allCoefs(:,inputLocs)];
                addPredMat = [addPredMat; predictorMat(inputLocs,:)];
                % one entry per appended predictor keeps mergeTfLocVec
                % aligned with allPredictors
                addLoc = [addLoc; repmat(mind,totMInts,1)];
            end                
        end
    end    
    disp(['Total of ' num2str(length(rmInds)) ' TFs expanded.'])
    keepInds = setdiff(1:totNetTfs,rmInds);
    % remove merged TFs and add individual TFs
    allPredictors = [reshape(allPredictors(keepInds),[],1); addRegs];
    allStabsTest = [allStabsTest(:,keepInds) addInts];
    allCoefs = [allCoefs(:,keepInds) addCoefs];
    allQuants = [allQuants(:,keepInds) addQuants];
    inPriorMat = [inPriorMat(:,keepInds) addPMat];
    predictorMat = [predictorMat(keepInds,:); addPredMat]; 
    mergeTfLocVec = [mergeTfLocVec(keepInds); addLoc];
else
    disp('No merged TFs file found.')        
end

%% re-rank based on possibly de-merged TFs
rankings = allStabsTest(:);            
coefVec = allCoefs(:);
quantiles = allQuants(:);
inPriorVec = inPriorMat(:);
totNetTfs = length(allPredictors);

totInts = totNetGenes * totNetTfs;   
targs = repmat(targGenes',totNetTfs,1);
regs1 = repmat(allPredictors',totNetGenes,1);
regs = reshape(regs1,totInts,1);

%% only keep nonzero rankings
keepRankings = find(rankings);    
[~, inds] = sort(rankings(keepRankings),'descend');
sortedKeep = keepRankings(inds);
% update info sources
rankings = rankings(sortedKeep);
coefVec = coefVec(sortedKeep);
quantiles = quantiles(sortedKeep);
inPriorVec = inPriorVec(sortedKeep);    
regs = reshape(regs(sortedKeep),[],1);
targs = reshape(targs(sortedKeep),[],1);
totInfInts = length(rankings);

%% update quantile ranks to take into account instability + |partial correlation|
% i.e., further refinement: within each quantile tier, spread the distinct
% stability values over the interval (currQuant, lastQuant]
lastQuant = 1;
uniQuants = sort(unique(quantiles),'descend');
quantilesRefined = zeros(size(quantiles));
for qind = 1:length(uniQuants)
    currQuant = uniQuants(qind);
    currInds = find(quantiles == currQuant);
    currStabs = rankings(currInds);
    uniStabs = sort(unique(currStabs),'descend');
    totStabs = length(uniStabs);
    for sind = 1:totStabs
        sInds = find(currStabs==uniStabs(sind));
        quantilesRefined(currInds(sInds)) = currQuant + (1-(sind-1)/totStabs)*(lastQuant-currQuant);
    end
    lastQuant = currQuant;
end
        
if ~isempty(outNetFileSparse)
    %% Network Edge colors, for jp_gene_viz
    medBlue = [0,85,255];
    medRed = [228,26,28];
    lightGrey = [217,217,217];

    %% output meanEdgesPerGene "Quantile" network in sparse network format
    fout = fopen(outNetFileSparse,'w');
    if fout == -1
        error(['Could not open outNetFileSparse for writing: ' outNetFileSparse])
    end
    fprintf(fout,'TF\tTarget\tSignedQuantile\tNonzeroSubsamples(Stability)\tpCorr\tstroke\tstroke-width\tstroke-dasharray\n');
    % For network visualization in jp_gene_viz (on Github):
    % stroke -- edge color as rgb(R,G,B): blends light grey toward red
    %   (positive pcorr) or blue (negative pcorr) in proportion to |pcorr|
    % stroke-width will go from [1,2] and will be proportional to stability
    % stroke-dasharray will incorporate prior information:
    % None --> will be solid and is for Prior-supported edges
    % 2,2 --> will be for edges Not supported by the prior    
    minRank = min(rankings); maxRank = max(rankings); rrange = maxRank-minRank;

    for ii = 1:min(totInfInts,totQuantEdges)
        if rrange > 0
            strokeWidth = 1 + (rankings(ii)-minRank)/rrange;
        else
            strokeWidth = 1; % all edges equally stable; avoid 0/0 = NaN
        end
        currPrho = coefVec(ii);
        if currPrho >= 0
            strokeVals = cellstr(num2str(round([currPrho*medRed + (1-currPrho)*lightGrey]')));
        else            
            strokeVals = cellstr(num2str(round([-currPrho*medBlue + (1+currPrho)*lightGrey]')));
        end
        stroke = ['rgb(' strjoin(strokeVals,',') ')'];
        if inPriorVec(ii) % solid line for thing in the prior (jp_gene_viz)
            dashArray = 'None';
        else % dotted line for things not in the prior
            dashArray = '2,2';
        end
        % names are passed as %s arguments (never as part of the format
        % string) so '%' or '\' inside a TF/gene name is written literally
        fprintf(fout,'%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n',...
            regs{ii}, targs{ii},...
            num2str(sign(currPrho)*quantilesRefined(ii)),...
            num2str(rankings(ii)),...
            roundstring3(coefVec(ii)), stroke,...
            num2str(strokeWidth), dashArray);
    end
    fclose(fout);
    disp(outNetFileSparse)

end

%% save results
outMat = [strrep(trnOutMat,'.mat','') '.mat'];
save(outMat,... 
    'predictorMat',... % include predictorMat
    'responseMat',... % responseMat so partial correlations can be calculated anew later, if needed
    'mergeTfLocVec',...% keep track of TFs that have been merged and share predictor profile
    'allStabsTest',...
    'allCoefs',...
    'allQuants',...
    'inPriorMat',...
    'targGenes',...
    'allPredictors',...
    'allStabsMergedTFs',...
    'regs',...
    'targs',...
    'rankings',...
    'coefVec',...
    'quantiles',...
    'quantilesRefined',...
    'inPriorVec',...
    'instabSource')                
disp(outMat)

