"""

          ARIA -- Ambiguous Restraints for Iterative Assignment

                 A software for automated NOE assignment

                               Version 2.2


Copyright (C) Benjamin Bardiaux, Michael Habeck, Therese Malliavin,
              Wolfgang Rieping, and Michael Nilges

All rights reserved.


NO WARRANTY. This software package is provided 'as is' without warranty of
any kind, expressed or implied, including, but not limited to the implied
warranties of merchantability and fitness for a particular purpose or
a warranty of non-infringement.

Distribution of substantively modified versions of this module is
prohibited without the explicit permission of the copyright holders.

$Author: bardiaux $
$Revision: 1.3 $
$Date: 2007/01/09 11:01:55 $
"""



from aria import AriaBaseClass
from xmlutils import XMLBasePickler
from tools import as_tuple

import TypeChecking as TCheck
import Assignment

## Values returned by CrossPeak.getAssignmentType().
## The manual/automatic values are aliases of the corresponding
## Assignment-level types; 'semiautomatic' is reported for any peak whose
## assignments are neither all manual nor all automatic.
CROSSPEAK_ASSIGNMENT_TYPE_MANUAL = Assignment.ASSIGNMENT_TYPE_MANUAL
CROSSPEAK_ASSIGNMENT_TYPE_AUTOMATIC = Assignment.ASSIGNMENT_TYPE_AUTOMATIC
CROSSPEAK_ASSIGNMENT_TYPE_SEMIAUTOMATIC = 'semiautomatic'

class CrossPeak(AriaBaseClass):
    """
    A single cross peak.

    Holds the peak number, its volume and intensity (Datum objects), the
    chemical shifts of its two proton and two hetero dimensions, and the
    lists of Assignment objects attached to each of those dimensions.
    """

    def __init__(self, number, volume, intensity):

        """
        Initialization via:
        - number:    the peak number (int, must be >= 0)
        - volume:    the peak's volume as a Datum (value and error), or None
        - intensity: the peak's intensity as a Datum (value and error),
                     or None

        At least one of volume / intensity must be given; a missing one is
        replaced by an empty Datum(None).

        Raises ValueError (via self.error) if both volume and intensity
        are None or if number is negative.
        """

        TCheck.check_int(number)
        TCheck.check_type(volume, 'Datum', TCheck.NONE)
        TCheck.check_type(intensity, 'Datum', TCheck.NONE)

        from Datum import Datum

        if volume is None and intensity is None:
            self.error(ValueError, 'Either volume or intensity must ' + \
                       'be of type Datum')
        
        if number < 0:
            self.error(ValueError, 'Number must be >= 0')

        if volume is None:
            volume = Datum(None)

        if intensity is None:
            intensity = Datum(None)
            
        AriaBaseClass.__init__(self)

        self.__number = number
        self.__intensity = intensity
        self.__volume = volume
        self.__reliable = 0
        self.__updated = 0
        
        self.setDefaultValues()

    def setDefaultValues(self):
        """
        Reset chemical shifts to empty ChemicalShift(None) objects, clear
        all assignment lists, unset spectrum and id, and mark the peak as
        valid.
        """

        from Datum import ChemicalShift

        self.__proton1ChemicalShift = ChemicalShift(None)
        self.__proton2ChemicalShift = ChemicalShift(None)

        self.__hetero1ChemicalShift = ChemicalShift(None)
        self.__hetero2ChemicalShift = ChemicalShift(None)

        self.__spectrum = None

        self.__id = None

        self.__proton1_assignments = []
        self.__hetero1_assignments = []
        self.__proton2_assignments = []
        self.__hetero2_assignments = []

        self.__valid = 1

    def isReliable(self, e = None):
        """
        Combined getter/setter for the 'reliable' flag.

        If e is None: returns the flag, i.e. a true value if the
        cross-peak is to be used in the calculation under any
        circumstances, a false value otherwise.

        If e is an int: sets the flag to (e > 0) and returns None.
        """

        if e is None:
            return self.__reliable

        TCheck.check_int(e)

        self.__reliable = e > 0

    def isAssigned(self):
        """
        Returns 1 if both proton dimensions have at least one assignment,
        0 otherwise.
        """
        if self.getProton1Assignments() and self.getProton2Assignments():
            return 1
        else:
            return 0

    def getAssignmentType(self):
        """
        Classify the peak's assignments over all four dimensions.

        Returns CROSSPEAK_ASSIGNMENT_TYPE_MANUAL if every assignment is
        manual, CROSSPEAK_ASSIGNMENT_TYPE_AUTOMATIC if every assignment is
        automatic, and CROSSPEAK_ASSIGNMENT_TYPE_SEMIAUTOMATIC otherwise.
        """

        types = [a.getType() for a in self.getProton1Assignments() + \
                 self.getProton2Assignments() + self.getHetero1Assignments()+ \
                 self.getHetero2Assignments()]

        # NOTE(review): with no assignments at all, the first test is
        # 0 == 0, so an unassigned peak is reported as manual.
        # Confirm that this is intended.
        if types.count(Assignment.ASSIGNMENT_TYPE_MANUAL) == len(types):

            return CROSSPEAK_ASSIGNMENT_TYPE_MANUAL

        elif types.count(Assignment.ASSIGNMENT_TYPE_AUTOMATIC) == len(types):

            return CROSSPEAK_ASSIGNMENT_TYPE_AUTOMATIC

        else:
            return CROSSPEAK_ASSIGNMENT_TYPE_SEMIAUTOMATIC

    def addProton1Assignment(self, a):
        """Append Assignment 'a' to the proton-1 assignments."""

        TCheck.check_type(a, 'Assignment')
        self.__proton1_assignments.append(a)

    ## BARDIAUX 15/04/05
    def setProton1Assignments(self, assignments):
        """Replace the proton-1 assignments by a copy of 'assignments'."""

        TCheck.check_type(assignments, TCheck.LIST, TCheck.TUPLE)
        TCheck.check_elements(assignments, 'Assignment')

        self.__proton1_assignments = list(assignments)

    def setProton2Assignments(self, assignments):
        """Replace the proton-2 assignments by a copy of 'assignments'."""

        TCheck.check_type(assignments, TCheck.LIST, TCheck.TUPLE)
        TCheck.check_elements(assignments, 'Assignment')

        self.__proton2_assignments = list(assignments)


    def setHetero1Assignments(self, assignments):
        """Replace the hetero-1 assignments by a copy of 'assignments'."""

        TCheck.check_type(assignments, TCheck.LIST, TCheck.TUPLE)
        TCheck.check_elements(assignments, 'Assignment')

        self.__hetero1_assignments = list(assignments)


    def setHetero2Assignments(self, assignments):
        """Replace the hetero-2 assignments by a copy of 'assignments'."""

        TCheck.check_type(assignments, TCheck.LIST, TCheck.TUPLE)
        TCheck.check_elements(assignments, 'Assignment')

        self.__hetero2_assignments = list(assignments)

    ##
        

    def addHetero1Assignment(self, a):
        """Append Assignment 'a' to the hetero-1 assignments."""

        TCheck.check_type(a, 'Assignment')
        self.__hetero1_assignments.append(a)

    def addProton2Assignment(self, a):
        """Append Assignment 'a' to the proton-2 assignments."""

        TCheck.check_type(a, 'Assignment')
        self.__proton2_assignments.append(a)

    def addHetero2Assignment(self, a):
        """Append Assignment 'a' to the hetero-2 assignments."""

        TCheck.check_type(a, 'Assignment')
        self.__hetero2_assignments.append(a)

    ## The get*Assignments methods return the internal lists themselves
    ## (not copies), so callers that mutate them modify the peak.

    def getProton1Assignments(self):
        return self.__proton1_assignments

    def getHetero1Assignments(self):
        return self.__hetero1_assignments

    def getProton2Assignments(self):
        return self.__proton2_assignments

    def getHetero2Assignments(self):
        return self.__hetero2_assignments

    def setId(self, id):
        TCheck.check_int(id)
        self.__id = id

    def getId(self):
        return self.__id

    def setSpectrum(self, s):
        TCheck.check_type(s, 'NOESYSpectrum')
        self.__spectrum = s

    def getSpectrum(self):
        return self.__spectrum

    def getNumber(self):
        """
        Returns the peak's number.
        """
        return self.__number

    def getIntensity(self):
        """
        Returns the peak's intensity and its uncertainty (a Datum).
        """
        return self.__intensity

    def getVolume(self):
        """
        Returns the peak's volume and its uncertainty (a Datum).
        """
        return self.__volume

    ## TODO: why set-methods for chemical shifts?
    ## presumably they will be set only once.

    def setProton1ChemicalShift(self, v):
        TCheck.check_type(v, 'ChemicalShift')
        self.__proton1ChemicalShift = v

    def getProton1ChemicalShift(self):
        return self.__proton1ChemicalShift

    def setHetero1ChemicalShift(self, v):
        TCheck.check_type(v, 'ChemicalShift')
        self.__hetero1ChemicalShift = v

    def getHetero1ChemicalShift(self):
        return self.__hetero1ChemicalShift

    def setProton2ChemicalShift(self, v):
        TCheck.check_type(v, 'ChemicalShift')
        self.__proton2ChemicalShift = v

    def getProton2ChemicalShift(self):
        return self.__proton2ChemicalShift

    def setHetero2ChemicalShift(self, v):
        TCheck.check_type(v, 'ChemicalShift')
        self.__hetero2ChemicalShift = v

    def getHetero2ChemicalShift(self):
        return self.__hetero2ChemicalShift

    def is_valid(self, valid = None):
        """
        Combined getter/setter for the validity flag.

        If valid is None: returns True if the stored flag equals 1.
        Otherwise: stores the int 'valid' and returns None.
        """

        TCheck.check_type(valid, TCheck.NONE, TCheck.INT)
        
        if valid is not None:
            self.__valid = valid
        else:
            return self.__valid == 1
        
    ## BARDIAUX 25/04/05
    def getDimension(self,spinsystem):
        """
        Input : spinsystem
        Return : Dimension of the spinsystem in this Xpk

        The mean of the spin system's known (non-None) chemical shifts is
        compared with the peak's two proton shifts. The result is 1 if the
        mean is strictly closer to the proton-1 shift, otherwise 2.
        """

        from N import array, sum

        p1chem = self.getProton1ChemicalShift().getValue()
        p2chem = self.getProton2ChemicalShift().getValue()

        # NOTE(review): the code below fails if the spin system has no
        # known shifts (division by zero) or if a proton shift is None.
        sschem = [a.getValue() for a in spinsystem.getChemicalShifts() \
                  if a.getValue() is not None]
        
        sschem = sum(array(sschem)) / len(sschem)

        if abs((sschem - p1chem)) <  abs((sschem - p2chem)):
            dim = 1
        else:
            dim = 2

        return dim
    
    def __str__(self):
        
        p1a = self.getProton1Assignments()
        h1a = self.getHetero1Assignments()
        p2a = self.getProton2Assignments()
        h2a = self.getHetero2Assignments()

        # every field is separated by ', ' (the hetero1/proton2 pair used
        # to be missing its comma)
        string = 'CrossPeak(id=%s, number=%d, volume=%s, intensity=%s, ' + \
                 'reliable=%s, proton1ppm=%s, hetero1ppm=%s, ' + \
                 'proton2ppm=%s, hetero2ppm=%s, ' + \
                 'proton1assignments=%s, hetero1assignments=%s, ' + \
                 'proton2assignments=%s, hetero2assignments=%s)'

        return string % (str(self.getId()),
                         self.getNumber(),
                         str(self.getVolume()), 
                         str(self.getIntensity()),
                         str(self.isReliable()),
                         str(self.getProton1ChemicalShift()),
                         str(self.getHetero1ChemicalShift()),
                         str(self.getProton2ChemicalShift()),
                         str(self.getHetero2ChemicalShift()),
                         str(p1a), str(h1a), str(p2a), str(h2a))

    __repr__ = __str__

class CrossPeakXMLPickler(XMLBasePickler):
    """
    Converts CrossPeak objects to XMLElement trees and back.
    """

    # tag order used for the XMLElement written by _xml_state
    order = ('number', 'reliable', 'volume', 'intensity', 'proton1',
             'proton2','hetero1', 'hetero2')

    def __init__(self):
        XMLBasePickler.__init__(self)
        self._set_name('Cross peak')

    def _xml_state(self, peak):
        """
        Build an XMLElement for 'peak'.

        The element carries number, reliable (YES/NO), volume and
        intensity. For each of proton1/proton2/hetero1/hetero2 it also
        carries a sub-element with the dimension's chemical shift and,
        only if non-empty, its assignments.
        """

        from xmlutils import XMLElement
        from aria import YES, NO

        e = XMLElement(tag_order = self.order)
        e.number = peak.getNumber()

        if peak.isReliable():
            e.reliable = YES
        else:
            e.reliable = NO
            
        e.volume = peak.getVolume()
        e.intensity = peak.getIntensity()

        d = {'proton1': {'shift': peak.getProton1ChemicalShift,
                         'assignment': peak.getProton1Assignments},
             'proton2': {'shift': peak.getProton2ChemicalShift,
                         'assignment': peak.getProton2Assignments},
             'hetero1': {'shift': peak.getHetero1ChemicalShift,
                         'assignment': peak.getHetero1Assignments},
             'hetero2': {'shift': peak.getHetero2ChemicalShift,
                         'assignment': peak.getHetero2Assignments}}

        for dim, attr in d.items():

            dd = XMLElement()
            dd.shift = attr['shift']()
                
            assignments = attr['assignment']()
            if assignments:
                dd.assignment = assignments

            setattr(e, dim, dd)
            
        return e

    def _datum_from_pair(self, pair):
        """
        Build a Datum from a (value, error) pair read from XML.

        A negative value is replaced by its absolute value; a None value
        is passed through unchanged.
        """

        from Datum import Datum

        value, error = pair

        if value is not None and value < 0.:
            value = abs(value)

        return Datum(value, error)

    def load_from_element(self, e):
        """
        Build a CrossPeak from the XMLElement 'e'.

        Raises ValueError (via self.error) if 'reliable' is neither YES
        nor NO. Dimension sub-elements that are missing, or that lack a
        shift or assignment entry, are skipped.
        """

        from aria import YES, NO
        
        number = int(e.number)
        volume = self._datum_from_pair(e.volume)
        intensity = self._datum_from_pair(e.intensity)

        p = CrossPeak(number, volume, intensity)
        reliable = e.reliable

        if reliable not in (YES, NO):
            s = 'Crosspeak: attribute "reliable" must be either ' + \
                '"%s" or "%s". "%s" given.'
            self.error(ValueError, s % (str(YES), str(NO), reliable))

        p.isReliable(reliable == YES)

        d = {'proton1': {'shift': p.setProton1ChemicalShift,
                         'assignment': p.addProton1Assignment},
             'proton2': {'shift': p.setProton2ChemicalShift,
                         'assignment': p.addProton2Assignment},
             'hetero1': {'shift': p.setHetero1ChemicalShift,
                         'assignment': p.addHetero1Assignment},
             'hetero2': {'shift': p.setHetero2ChemicalShift,
                         'assignment': p.addHetero2Assignment}}

        for dim, attr in d.items():

            if not hasattr(e, dim): continue

            dd = getattr(e, dim)

            if hasattr(dd, 'shift'):
                attr['shift'](dd.shift)

            if hasattr(dd, 'assignment'):
                # as_tuple: a single assignment and a sequence of
                # assignments are handled the same way
                add_assignment = attr['assignment']
                for a in as_tuple(dd.assignment):
                    add_assignment(a)

        return p

## Expose the pickler's _xml_state on CrossPeak. The attribute is a method
## already bound to a single CrossPeakXMLPickler instance created here, so
## it takes the peak to serialize as its explicit argument.
CrossPeak._xml_state = CrossPeakXMLPickler()._xml_state
    
