#############################################################
#  aria_consensus.py
#  
#  Benjamin Bardiaux 
#  bardiaux@pasteur.fr 
#  Institut Pasteur, PARIS 
#  
#  copyright 2020 
##############################################################

# REFERENCES
# 1) Buchner L, Guntert P. Increased reliability of nuclear magnetic resonance protein structures by consensus structure bundles.
#   Structure. 2015 Feb 3;23(2):425-34. doi: 10.1016/j.str.2014.11.014. Epub 2015 Jan 8. PMID: 25579816.
#
# 2) Belyy A, Lindemann F,  Roderer D, Funk J, Bardiaux B, Protze J, Bieling P, Oschkinat O, Raunser S. 2022. To be published.

import os, sys
aria_path = os.path.join(os.environ['ARIA2'], 'src/py')
sys.path.insert(0, aria_path)

from aria.tools import Load
from numpy import argmin
import re, getopt, random
from aria.ariabase import *
import aria.AriaXML as AriaXML

# Prefix of the run names: update_project() sets the run name to
# "<runkey><n>", and the merge step reads the run directories "run<runkey><n>"
# (e.g. "runc0", "runc1", ...).
runkey = "c"

class RestraintList:
    """Collection of restraints rendered as a distance_restraint_list XML document.

    ``restraints`` holds the items to serialise; each one is written with
    str() and the items are separated by newlines.
    """

    def __init__(self):
        self.restraints = []

    def __str__(self):
        body = "\n".join(map(str, self.restraints))
        document = ('<!DOCTYPE distance_restraint_list SYSTEM "distance_restraint1.0.dtd">\n'
                    ' <distance_restraint_list>\n'
                    ' %s\n'
                    '</distance_restraint_list>')
        return document % body

    __repr__ = __str__
 
class Restraint:
    """A distance restraint rendered as an ARIA XML <restraint> element.

    Args:
        d: restraint distance.
        l: lower distance bound.
        u: upper distance bound.
        weight: restraint weight written in the XML (default 1.0).
        active: value of the ``active`` attribute (formatted with %d).
        reliable: value of the ``reliable`` attribute (formatted with %s).
        list_name: value of the ``list_name`` attribute.

    The caller fills ``contributions`` with Contribution-like objects, i.e.
    anything that provides ``setWeight()`` and ``__str__``.
    """

    # Class-wide counter giving restraints sequential ids.
    __counter = 0

    def __init__(self, d, l , u, weight = 1., active = 1, reliable = 0, list_name = "general"):

        self.id = self.__class__.__counter

        self.weight = weight
        self.d = d
        self.u = u
        self.l = l
        self.active = active
        self.reliable = reliable
        self.list_name = list_name

        self.contributions = []

        self.__class__.__counter += 1

    def __str__(self):

        # A restraint without contributions is not written at all.
        if len(self.contributions) == 0:
            return ""

        # All contributions share the restraint equally. NOTE: this updates
        # the weight stored on each contribution object in place.
        w = 1.0 / float(len(self.contributions))
        for c in self.contributions:
            c.setWeight(w)

        contributions_xml = "\n".join([str(x) for x in self.contributions])

        return """  <restraint id="%d" weight="%.3f" distance="%.3f" lower_bound="%.3f" upper_bound="%.3f" active="%d" reliable="%s" list_name="%s">
%s
  </restraint>""" % (self.id, self.weight, self.d, self.l, self.u, self.active, self.reliable, self.list_name, contributions_xml)

    __repr__ = __str__

class Contribution:
    """One contribution (assignment) of a restraint, made of spin pairs.

    Args:
        spin_pairs: sequence of 2-item (atom, atom) pairs. Each atom is
            written with str(), e.g. an Atom instance.
        weight: value of the ``weight`` attribute (Restraint.__str__
            overwrites it through setWeight()).
        ctype: value of the ``type`` attribute, e.g. the value of an ARIA
            contribution's getType().
    """

    # Class-wide counter giving contributions sequential ids.
    __counter = 0

    def __init__(self, spin_pairs, weight=1.0, ctype="normal"):

        self.id = self.__class__.__counter
        self.spin_pairs = spin_pairs
        self.ctype = ctype
        self.weight = weight

        self.__class__.__counter += 1

    def setWeight(self, w):
        """Set the weight written in the XML output."""
        self.weight = w

    def str_spinpairs(self):
        """Return the <spin_pair> XML elements of this contribution (one per pair)."""

        template = """
    <spin_pair>
      %s
      %s
    </spin_pair>"""

        # join() instead of repeated string concatenation; tuple() makes the
        # %-formatting also accept pairs given as lists.
        return "".join([template % tuple(pair) for pair in self.spin_pairs])

    def __str__(self):

        return """   <contribution id="%d" weight="%.3f" type="%s">%s
   </contribution>""" % (self.id, self.weight, self.ctype, self.str_spinpairs())

    __repr__ = __str__

class Atom:
    """Lightweight atom record rendered as an ARIA XML <atom/> element.

    Args:
        name: atom name.
        resid: residue number (formatted with %d).
        segid: segment id (right-aligned in a 4-character field).
    """

    def __init__(self, name, resid, segid):
        self.name = name
        self.resid = resid
        self.segid = segid

    def __str__(self):
        values = (self.segid, self.resid, self.name)
        return '<atom segid="%4s" residue="%d" name="%s"/>' % values

    __repr__ = __str__


def convert_contrib_list(contrib_list):
    """Convert ARIA contribution objects into Contribution/Atom records.

    Args:
        contrib_list: iterable of ARIA contributions (objects providing
            getSpinPairs() and getType()).

    Returns:
        list of Contribution, one per input contribution, in input order.
        Each spin pair becomes a tuple of Atom instances.
    """

    def _to_atom(aria_atom):
        # Keep only what Atom writes out: name, residue number, segid.
        return Atom(aria_atom.getName(), aria_atom.getResidue().getNumber(), aria_atom.getSegid())

    converted = []
    for aria_contrib in contrib_list:
        pairs = [tuple([_to_atom(x) for x in pair.getAtoms()])
                 for pair in aria_contrib.getSpinPairs()]
        converted.append(Contribution(pairs, ctype=aria_contrib.getType()))

    return converted

def get_active_peaks(f):
    """Load an ARIA peak pickle and index its peaks.

    Note: despite the name, inactive peaks are not filtered out here.
    Activity is checked later, in make_consensus().

    Args:
        f: path of the pickle file, passed to aria.tools.Load.

    Returns:
        (aria_peaks, spectra, peaks_index), where
        - aria_peaks is the loaded peak sequence,
        - spectra maps spectrum name -> {reference peak number: ARIA peak},
        - peaks_index maps ARIA peak id -> ARIA peak.
    """
    aria_peaks = Load(f)

    by_spectrum = {}
    by_id = {}

    for peak in aria_peaks:
        ref = peak.getReferencePeak()
        spectrum_name = ref.getSpectrum().getName()
        number = ref.getNumber()

        by_spectrum.setdefault(spectrum_name, {})[number] = peak
        by_id[peak.getId()] = peak

    return aria_peaks, by_spectrum, by_id

def get_merged_peaks(merged_file, peaks_index):
    """Read a run's ".merged" file and map each listed peak to an ARIA peak.

    Each line matching the pattern below is parsed as six fields:

        <spectrum> <number> <id> <target spectrum> <target number> <target id>

    Lines that do not match (e.g. comments) are ignored. A listed peak is
    mapped to peaks_index[<target id>] when that id is known. Otherwise it is
    resolved through the entry already recorded for
    (<target spectrum>, <target number>). Such chains may have any length
    and may appear in any order in the file.

    Args:
        merged_file: path of the file (the run's "noe_restraints.merged").
        peaks_index: dict mapping ARIA peak id -> ARIA peak, as returned by
            get_active_peaks().

    Returns:
        dict: spectrum name -> {peak number: ARIA peak}.

    Raises:
        KeyError: if some entries cannot be resolved.
    """

    spectra = {}

    # The (.*) groups capture the spectrum names; they are stripped below.
    reg = re.compile(r"\s*(.*)\s+(\d+)\s+(\d+)\s+(.*)\s+(\d+)\s+(\d+)")

    missing = []

    with open(merged_file) as handle:
        for line in handle:

            o = reg.match(line)
            if o is None:
                continue

            groups = o.groups()
            m_spec = groups[0].strip()
            m_no = int(groups[1])
            m_id = int(groups[2])

            p_spec = groups[3].strip()
            p_no = int(groups[4])
            p_id = int(groups[5])

            spectra.setdefault(m_spec, {})
            try:
                spectra[m_spec][m_no] = peaks_index[p_id]
            except KeyError:
                # Target id is not in peaks_index: resolve it below through
                # the other entries of this file.
                missing.append((m_spec, m_no, m_id, p_spec, p_no, p_id))

    # Repeat until a pass resolves nothing, so chains of any depth/order work.
    while missing:
        unresolved = []
        for entry in missing:
            m_spec, m_no, m_id, p_spec, p_no, p_id = entry
            try:
                spectra[m_spec][m_no] = spectra[p_spec][p_no]
            except KeyError:
                unresolved.append(entry)

        if len(unresolved) == len(missing):
            m_spec, m_no, m_id, p_spec, p_no, p_id = unresolved[0]
            raise KeyError("cannot resolve merged peak %s/%d (target %s/%d, id %d); "
                           "%d unresolved entries in %s"
                           % (m_spec, m_no, p_spec, p_no, p_id, len(unresolved), merged_file))

        missing = unresolved

    return spectra

def join_peaks(peaks_d, peaks_merged_d):
    """Merge per-spectrum peak dicts from peaks_merged_d into peaks_d, in place.

    Both arguments map spectrum name -> {peak number: peak}. For each
    spectrum in peaks_merged_d, its entries are added to (and override) the
    corresponding dict in peaks_d. The spectrum's dict is created in peaks_d
    if it is missing.

    Returns:
        None (peaks_d is modified in place).
    """
    # Iterate the merged dict directly. Concatenating dict.keys() fails on
    # Python 3, and indexing peaks_d[spec] failed for merged-only spectra.
    for spec, merged in peaks_merged_d.items():
        peaks_d.setdefault(spec, {}).update(merged)
    
def uniq_id_contrib(c):
    """Return an order-independent key identifying contribution c.

    The key is a sorted list holding, for each spin pair of c, the sorted
    tuple of its atom ids. Contributions with the same spin pairs therefore
    give equal keys, whatever the order of atoms or spin pairs.
    """
    return sorted(
        tuple(sorted(atom.getId() for atom in pair.getAtoms()))
        for pair in c.getSpinPairs()
    )
        

def merge_assignments(contributions_list):
    """Return the union of contributions over several calculations, without duplicates.

    Args:
        contributions_list: list of lists. The outer list runs over
            calculations, each inner list holds contributions from one
            calculation.

    Returns:
        list of contributions in first-seen order. A contribution is dropped
        when an earlier one has the same uniq_id_contrib() key.
    """
    final_contributions = []
    seen = set()

    for calc_contributions in contributions_list:
        for c in calc_contributions:
            # uniq_id_contrib() returns a list of tuples. Turn it into a tuple
            # so it can live in a set: O(1) lookups instead of scanning a list.
            # NOTE(review): assumes atom ids are hashable (they are sorted and
            # compared already).
            key = tuple(uniq_id_contrib(c))
            if key not in seen:
                seen.add(key)
                final_contributions.append(c)

    return final_contributions
    

def make_consensus(peaks_list, merged_peaks_list, cutoff=None, name='consensus'):
    """Build consensus distance restraints from several ARIA calculations.

    A cross-peak of the first calculation is kept when it is active, with at
    least one active contribution, in at least ``cutoff * len(peaks_list)``
    calculations. For a kept peak:
    - its contributions are the union (duplicates removed) of the
      contributions of all calculations;
    - its distance and bounds come from the calculation in which the peak
      has the shortest distance.

    Args:
        peaks_list: one dict per calculation, spectrum name ->
            {peak number: ARIA peak} (the ``spectra`` dict of
            get_active_peaks()). Every peak of the first calculation must
            exist in all others, otherwise KeyError is raised.
        merged_peaks_list: one dict per calculation with the same layout
            (from get_merged_peaks()). It is merged into peaks_list in place.
        cutoff: dict of per-spectrum cut-offs (fractions of calculations).
            Spectra without an entry use cutoff['nmr']. Defaults to
            {'nmr': 0.6}.
        name: list_name written on every restraint.

    Returns:
        RestraintList with one Restraint per consensus cross-peak.
    """
    # Avoid a shared mutable default argument.
    if cutoff is None:
        cutoff = {'nmr': 0.6}

    for p, m in zip(peaks_list, merged_peaks_list):
        join_peaks(p, m)  # in-place update of each calculation's peaks

    new_restraints = []

    ncalc = len(peaks_list)

    for spec, ap_dict in peaks_list[0].items():

        _cutoff = cutoff.get(spec)
        if _cutoff is None:
            _cutoff = cutoff['nmr']

        # Minimum number of calculations in which the peak must be assigned.
        cutoff_r = ncalc * _cutoff

        for cp_no in ap_dict:

            # Active contributions of this peak in each calculation; an
            # inactive peak counts as unassigned.
            calc_contributions = []
            for p in peaks_list:
                contributions = p[spec][cp_no].getActiveContributions()
                if not p[spec][cp_no].isActive():
                    contributions = []
                calc_contributions.append(contributions)

            n_assigned = sum([len(x) > 0 for x in calc_contributions])

            if n_assigned >= cutoff_r:
                consensus = merge_assignments(calc_contributions)

                # Use the ARIA peak with the shortest distance. This must use
                # the peaks_list argument, not a global variable.
                dist = [p[spec][cp_no].getDistance() for p in peaks_list]
                best_ap = peaks_list[argmin(dist)][spec][cp_no]

                r = Restraint(best_ap.getDistance(), best_ap.getLowerBound(), best_ap.getUpperBound(), list_name="%s" % name)
                r.contributions = convert_contrib_list(consensus)
                new_restraints.append(r)

    r_list = RestraintList()
    r_list.restraints = new_restraints

    return r_list

def write_consensus_restraints(restraints, outfile):
    """Write str(restraints) (e.g. a RestraintList) to outfile and report it.

    Args:
        restraints: object to serialise; its str() is written as-is.
        outfile: path of the output file (overwritten).
    """
    # "with" closes the file even if the write fails.
    with open(outfile, 'w') as outf:
        outf.write(str(restraints))

    print("Consensus restraints %s written." % outfile)
    
def read_project(file):
    """Load an ARIA project from the XML file `file`.

    Uses AriaXMLPickler.load_relaxed() and returns the loaded project object.
    """
    return AriaXML.AriaXMLPickler().load_relaxed(file)

def update_project(project, n, outfile):
    """Configure `project` as run number n and dump it to outfile.

    The project object is modified in place:
    - run name set to "<runkey><n>";
    - CCPN exports switched off;
    - pickle and text restraint-list outputs switched on;
    - new random MD seed.
    It is then written with AriaXMLPickler.

    Args:
        project: ARIA project, as returned by read_project().
        n: run index.
        outfile: path of the XML project file to write.

    Returns:
        True.
    """
    project_settings = project.getSettings()
    project_settings["run"] = "%s%d" % (runkey, n)

    ccpn_settings = project.getReporter()['ccpn']
    ccpn_settings['export_noe_restraint_list'] = 'no'
    ccpn_settings['export_structures'] = 'no'
    ccpn_settings['export_assignments'] = 'no'

    # The merge step reads both noe_restraints.pickle and noe_restraints.merged.
    # Previously pickle_output was set twice and text_output never (copy-paste
    # slip). NOTE(review): the .merged file is presumably part of the text
    # output; confirm.
    reporter_settings = project.getReporter()['noe_restraint_list']
    reporter_settings['pickle_output'] = 'yes'
    reporter_settings['text_output'] = 'yes'

    # Each run gets a different random seed so the runs are independent.
    md = project.getStructureEngine().getMDParameters()
    md['random_seed'] = random.randint(1, 9999999)

    pickler = AriaXML.AriaXMLPickler()
    pickler.dump(project, outfile)

    return True


def usage():
    """Print the command-line help and exit with status 2."""

    # Runs are numbered 0..n-1 (see the --generate loop), not 1..n.
    USAGE = """Usage (requires the ARIA2 environment variable to be set)

%s --generate|--merge [--n 20] [--cutoff 0.6] aria_project.xml

    --generate         Generate ARIA XML projects (using aria_project.xml) as
                       reference with different random seeds numbered from 0 to <n>-1
                       (default action)

    --merge            Generate consensus restraints (XML format) from <n> runs

    --n <integer>      Number of runs to generate/merge (default: 20)

    --cutoff <float>   Cut-off for keeping consensus cross-peaks: minimum fraction
                       of runs in which a peak must be active and assigned (default: 0.6)
""" % sys.argv[0]

    print(USAGE)
    sys.exit(2)
    


if __name__ == "__main__":

    try:
        opts, args = getopt.getopt(sys.argv[1:], "", ['generate', 'merge', 'n=', 'cutoff='])
    except getopt.GetoptError as err:
        print("Error: %s\n" % err)
        usage()

    action = 'generate'

    nruns = 20
    cutoff = 0.6

    for o, a in opts:
        if o == '--generate':
            action = 'generate'
        elif o == '--merge':
            action = 'merge'
        elif o == '--n':
            try:
                nruns = int(a)
            except ValueError:
                usage()
        elif o == '--cutoff':
            try:
                cutoff = float(a)
            except ValueError:
                usage()
        else:
            usage()

    # At least one run is needed: make_consensus() reads the first run's peaks.
    if not args or nruns < 1:
        usage()

    project_file = args[0]
    # Output names are derived from the project file without its extension,
    # e.g. "project.xml" -> "project_c0.xml", "project_consensus.xml".
    project_root = os.path.splitext(project_file)[0]

    if action == 'generate':

        baseproject = read_project(project_file)

        # The same project object is re-configured and dumped for each run.
        for a in range(0, nruns):
            out = project_root + "_%s%d.xml" % (runkey, a)
            update_project(baseproject, a, out)

        print("Project files %s_%s[0-%d].xml written." % (project_root, runkey, nruns - 1))

    elif action == 'merge':

        baseproject = read_project(project_file)
        project_settings = baseproject.getSettings()
        working_dir = project_settings["working_directory"]

        peaks = []
        merged_peaks = []

        for run in range(0, nruns):

            print('Reading run%s' % run)
            # NOTE(review): the iteration directory is hard-coded to "it8",
            # i.e. this assumes the last ARIA iteration is number 8. Confirm
            # for protocols with a different number of iterations.
            rund = os.path.join(working_dir, "run%s%s" % (runkey, run), 'structures', 'it8')

            pickle_file = os.path.join(rund, 'noe_restraints.pickle')
            merged_file = os.path.join(rund, 'noe_restraints.merged')

            _aria_peaks, _peaks, _peaks_index = get_active_peaks(pickle_file)
            _peaks_merged = get_merged_peaks(merged_file, _peaks_index)

            peaks.append(_peaks)
            merged_peaks.append(_peaks_merged)

        # Build the consensus restraints using the cut-off.
        restraint_list = make_consensus(peaks, merged_peaks, cutoff={'nmr': cutoff}, name='consensus')

        outfile = 'consensus_%.1f.xml' % cutoff
        write_consensus_restraints(restraint_list, outfile)

        # Write the consensus project: delete the spectra and add the
        # consensus restraints as an ambiguous distance restraint list.
        import aria.DataContainer as DC
        for spec in baseproject.getData(DC.DATA_SPECTRUM):
            baseproject.delData(spec)

        # Enabled restraint lists already in the project (reported below).
        unambig = [d for d in baseproject.getData(DC.DATA_UNAMBIGUOUS) if d['enabled'] == 'yes']
        ambig = [d for d in baseproject.getData(DC.DATA_AMBIGUOUS) if d['enabled'] == 'yes']

        project_settings['run'] = "cons"

        data = DC.AmbiguousDistanceData()
        data.reset()

        # Restraints are used as given: no calibration, not added to the
        # network, no network anchoring.
        data['format'] = 'xml'
        data['filename'] = os.path.abspath(outfile)
        data['calibrate'] = 'no'
        data['enabled'] = 'yes'
        data['add_to_network'] = 'no'
        data['run_network_anchoring'] = 'no'
        data['filter_contributions'] = 'yes'

        baseproject.addData(data)

        # Delete all iterations except the last one and renumber it 0.
        # NOTE(review): it_settings is indexed with nit-1 after deletions; this
        # assumes delIterationSettings() does not renumber the remaining
        # entries. Confirm.
        protocol = baseproject.getProtocol()
        it_settings = protocol.getSettings()['iteration_settings']
        nit = len(it_settings)
        if nit > 1:
            for it in range(0, nit - 1):
                protocol.getSettings().delIterationSettings(it_settings[it])
            last_it = it_settings[nit - 1]
            last_it['number'] = 0

        outfile = project_root + "_consensus.xml"

        pickler = AriaXML.AriaXMLPickler()
        pickler.dump(baseproject, outfile)
        print("Consensus ARIA project %s written." % outfile)
        if unambig or ambig:
            print("[WARNING] You have unambiguous/ambiguous distance restraints data in your project file.")
            print("[WARNING] Make sure to disabled them if they were used to generate the consensus restraints.")

    else:
        usage()
    

