class Cxx(object):
  '''
  Convenience factory for Cxx syntax-tree nodes.

  Every get*/castTo* method instantiates a node class imported from the
  project "Cxx" package (or "ASE.Compiler.Cxx"), fills in its fields from
  plain Python values and returns it. Imports are done lazily inside each
  method, so a method only needs the node classes it actually touches.
  '''
  def __init__(self):
    # Cache of integer Constant nodes keyed by their Python value, see getInteger()
    self.integers = {}

  def clear(self):
    '''Drop the cached type map, the cached NULL variable and all cached integer constants'''
    del self.typeMap
    del self.nullVar
    self.integers.clear()

  def getTypeMap(self):
    '''Return the type map, building it with createTypeMap() on first access'''
    if not hasattr(self, '_typeMap'):
      self._typeMap = self.createTypeMap()
    return self._typeMap
  def setTypeMap(self, typeMap):
    '''Replace the cached type map'''
    self._typeMap = typeMap
  def delTypeMap(self):
    '''Forget the cached type map, it is rebuilt on the next access'''
    if hasattr(self, '_typeMap'):
      del self._typeMap
  typeMap = property(getTypeMap, setTypeMap, delTypeMap, doc = 'A map from typenames to Cxx type objects')

  def getNullVar(self):
    '''Return the variable named NULL, creating it with getVar() on first access'''
    if not hasattr(self, '_nullVar'):
      self._nullVar = self.getVar('NULL')
    return self._nullVar
  def setNullVar(self, nullVar):
    '''Replace the cached NULL variable'''
    self._nullVar = nullVar
  def delNullVar(self):
    '''Forget the cached NULL variable, it is recreated on the next access'''
    if hasattr(self, '_nullVar'):
      del self._nullVar
  nullVar = property(getNullVar, setNullVar, delNullVar, doc = 'The C variable NULL')

  def createTypeMap(self):
    '''
    Build and return the dictionary of predefined types

    Keys are type names such as "int", "const char pointer" or "PetscReal".
    Several keys deliberately share one object: "double1"/"double",
    "float1"/"float", "opaque"/"void pointer", and the const aliases
    "const double1"/"const double" and "const float"/"const float1".
    '''
    from Cxx import Pointer
    from Cxx import Struct
    from Cxx import Type

    def makeBase(identifier, isConst = False):
      '''Create a base Type, the const flag is only assigned for const types'''
      node = Type()
      node.identifier = identifier
      node.baseType = True
      if isConst:
        node.const = True
      return node

    def makeStruct(identifier):
      '''Create a Struct with the given identifier'''
      node = Struct()
      node.identifier = identifier
      return node

    def makePointer(target):
      '''Create a Pointer whose only child is the target type'''
      node = Pointer()
      node.children = [target]
      return node

    typeMap = {}
    typeMap['bool'] = makeBase('bool')
    typeMap['char'] = makeBase('char')
    typeMap['const char'] = makeBase('char', isConst = True)
    typeMap['char pointer'] = makePointer(typeMap['char'])
    typeMap['const char pointer'] = makePointer(typeMap['const char'])
    typeMap['dcomplex'] = makeStruct('complex<double>')
    typeMap['double'] = makeBase('double')
    typeMap['double1'] = typeMap['double']
    typeMap['double2'] = makeBase('double2')
    typeMap['double3'] = makeBase('double3')
    typeMap['const double'] = makeBase('double', isConst = True)
    # Alias so that the const lookup in getType() works for "double1" just as for "double"
    typeMap['const double1'] = typeMap['const double']
    typeMap['const double2'] = makeBase('double2', isConst = True)
    typeMap['const double3'] = makeBase('double3', isConst = True)
    typeMap['double pointer'] = makePointer(typeMap['double'])
    typeMap['double pointer pointer'] = makePointer(typeMap['double pointer'])
    typeMap['fcomplex'] = makeStruct('complex<float>')
    typeMap['float'] = makeBase('float')
    typeMap['float1'] = typeMap['float']
    typeMap['float2'] = makeBase('float2')
    typeMap['float3'] = makeBase('float3')
    typeMap['const float1'] = makeBase('float', isConst = True)
    # Alias so that the const lookup in getType() works for "float" just as for "float1"
    typeMap['const float'] = typeMap['const float1']
    typeMap['const float2'] = makeBase('float2', isConst = True)
    typeMap['const float3'] = makeBase('float3', isConst = True)
    typeMap['int'] = makeBase('int')
    typeMap['const int'] = makeBase('int', isConst = True)
    typeMap['int pointer'] = makePointer(typeMap['int'])
    typeMap['const int pointer'] = makePointer(typeMap['const int'])
    typeMap['long'] = makeBase('long long')
    typeMap['void'] = makeBase('void')
    typeMap['void pointer'] = makePointer(typeMap['void'])
    typeMap['opaque'] = typeMap['void pointer']
    typeMap['void pointer pointer'] = makePointer(typeMap['void pointer'])
    # NOTE(review): 'std::str' is not a standard C++ type name, 'std::string' may have been
    #   intended -- confirm before changing the generated output
    typeMap['string'] = makeBase('std::str')
    # Petsc types
    typeMap['PetscReal'] = makeBase('PetscReal')
    typeMap['const PetscReal'] = makeBase('PetscReal', isConst = True)
    typeMap['PetscScalar'] = makeBase('PetscScalar')
    return typeMap

  def _getConstType(self, name):
    '''
    Return the const variant of the registered type "name"

    A type registered under "const <name>" is returned as is. Otherwise, a
    base type is shallow-copied and the copy is flagged const, leaving the
    shared map entry untouched. For any other kind of type (e.g. a pointer)
    it is ambiguous what should become const, so KeyError is raised.
    '''
    constName = 'const '+name
    if constName in self.typeMap:
      return self.typeMap[constName]
    plainType = self.typeMap[name]
    if not getattr(plainType, 'baseType', False):
      raise KeyError(constName)
    import copy
    constType = copy.copy(plainType)
    constType.const = True
    return constType

  def getType(self, name, numPointers = 0, isConst = 0):
    '''
    Return the C type corresponding to this sidl type

    - Names found in the type map resolve to the registered type, or to its const variant when isConst is true
    - Unknown names produce a fresh Type carrying that identifier
    - The result is wrapped in numPointers levels of Pointer
    Raises KeyError if a const variant is requested for a registered non-base
    type which has no "const <name>" entry.
    '''
    if name in self.typeMap:
      if isConst:
        cType = self._getConstType(name)
      else:
        cType = self.typeMap[name]
    else:
      from Cxx import Type
      cType = Type()
      cType.identifier = name
      if isConst:
        cType.const = isConst
    for level in range(numPointers):
      cType = self.getPointer(cType)
    return cType

  def getNull(self):
    '''Return the NULL variable, raising RuntimeError if it has been set to None'''
    null = self.nullVar
    if null is None:
      raise RuntimeError('Object prematurely cleared')
    return null

  def getInteger(self, num):
    '''Return the Constant for the integer num, the same node is reused for equal values'''
    if num not in self.integers:
      from Cxx import Constant
      integer = Constant()
      integer.value = num
      self.integers[num] = integer
    return self.integers[num]

  def getDouble(self, num):
    '''Return a new Constant holding num, unlike integers these are not cached'''
    from Cxx import Constant
    double = Constant()
    double.value = num
    return double

  def getInclude(self, filename):
    '''Return an Include for filename, names not already of the form <...> are wrapped in double quotes'''
    import ASE.Compiler.Cxx.Include
    include = ASE.Compiler.Cxx.Include.Include()
    if filename[0] == '<' and filename[-1] == '>':
      include.setIdentifier(filename)
    else:
      include.setIdentifier('"' + filename + '"')
    return include

  def getString(self, s):
    '''Return a Constant whose identifier and value are both the string s'''
    from Cxx import Constant
    string = Constant()
    string.identifier = s
    string.value = s
    return string

  def getComment(self, comment):
    '''Return a Comment node holding the given text'''
    from Cxx import Comment
    c = Comment()
    c.comment = comment
    return c

  def getVar(self, name):
    '''Return a Variable with the given identifier'''
    from Cxx import Variable
    var = Variable()
    var.identifier = name
    return var

  def _withChild(self, node, exp):
    '''Make exp the only child of node and return node'''
    node.children = [exp]
    return node

  def _withOperands(self, node, a, b):
    '''Make getValue(a) and getValue(b) the two children of node and return node'''
    node.children = [self.getValue(a), self.getValue(b)]
    return node

  def getAddress(self, exp):
    '''Return the Address of exp, which is processed by getValue()'''
    from Cxx import Address
    # getValue() already converts strings to variables and passes nodes through unchanged
    return self._withChild(Address(), self.getValue(exp))

  def getIndirection(self, exp):
    '''Return the Indirection of exp, which is processed by getValue()'''
    from Cxx import Indirection
    return self._withChild(Indirection(), self.getValue(exp))

  def getMinus(self, exp):
    '''Return a UnaryMinus whose child is exp, which is used as given'''
    from Cxx import UnaryMinus
    return self._withChild(UnaryMinus(), exp)

  def getGroup(self, exp):
    '''Return a Group whose child is exp, which is used as given'''
    from Cxx import Group
    return self._withChild(Group(), exp)

  def getDeclGroup(self, exp):
    '''Return a DeclaratorGroup whose child is exp, which is used as given'''
    from Cxx import DeclaratorGroup
    return self._withChild(DeclaratorGroup(), exp)

  def getPointer(self, exp):
    '''Return a Pointer whose child is exp, which is used as given'''
    from Cxx import Pointer
    return self._withChild(Pointer(), exp)

  def getDecl(self, declarator, comment = None):
    '''
    Make a declaration from a declarator

    If comment is given, it is attached as the only comment of the declaration.
    '''
    from Cxx import Declaration
    decl = Declaration()
    decl.children = [declarator]
    if comment is not None:
      decl.comments = [self.getComment(comment)]
    return decl

  def getInitializer(self, type, value = None, isList = 0):
    '''
    Return an Initializer for a variable of the given type

    - If value is given, it (or each member, if it is a list) becomes the initializer and type is not examined
    - Otherwise a default is chosen from the type: zero for the listed base types, NULL for pointers, a pair of zeros for the known complex types
    - None is returned for a Struct with an unrecognized identifier
    - A base type whose identifier matches none of the lists gets an Initializer on which setChildren() is never called
    The list flag of the initializer is set to isList unless isList is None.
    '''
    from Cxx import Initializer
    from Cxx import Pointer
    from Cxx import Struct
    from Cxx import Type
    init = Initializer()
    if value is not None:
      if not isinstance(value, list):
        value = [value]
      init.setChildren(value)
    elif type.baseType:
      typeName = type.identifier
      # NOTE(review): 'dcomplex', 'fcomplex' and 'string' are type map keys, but no type built by
      #   createTypeMap() has one of them as identifier, so those branches look unreachable
      #   for the predefined types -- confirm
      if typeName in ['bool', 'char', 'int', 'long', 'long long', 'int32_t', 'int64_t']:
        init.setChildren([self.getInteger(0)])
      elif typeName in ['double', 'float']:
        init.setChildren([self.getDouble(0.0)])
      elif typeName in ['dcomplex', 'fcomplex']:
        init.setChildren([self.getDouble(0.0), self.getDouble(0.0)])
      elif typeName in ['string']:
        init.setChildren([self.getNull()])
    elif isinstance(type, Pointer):
      init.setChildren([self.getNull()])
    elif isinstance(type, Struct):
      if type.identifier not in ['SIDL_dcomplex', 'SIDL_fcomplex']:
        return None
      init.setChildren([self.getDouble(0.0), self.getDouble(0.0)])
    elif isinstance(type, Type) and type.identifier in ['Py_complex']:
      init.setChildren([self.getDouble(0.0), self.getDouble(0.0)])
    else:
      init.setChildren([self.getInteger(0)])
    if isList is not None:
      init.list = isList
    return init

  def getDeclarator(self, var, t, isStatic = 0):
    '''
    Return a Declarator for var with type t

    var may be a string, an object with an identifier, or None to leave the
    identifier unset. A string t is resolved with getType().
    '''
    from Cxx import Declarator
    decl = Declarator()
    if isinstance(var, str):
      decl.identifier = var
    elif var is not None:
      decl.identifier = var.identifier
    if isinstance(t, str):
      t = self.getType(t)
    decl.type   = t
    decl.static = isStatic
    return decl

  def getDeclaration(self, var, type, initializer = None, comment = None, isStatic = 0, isForward = 0):
    '''
    Return a Declaration of var

    Without an explicit initializer, a default one from getInitializer() is
    attached unless isForward is true or no default exists.
    '''
    decl = self.getDeclarator(var, type, isStatic)
    if initializer is not None:
      decl.initializer = initializer
    elif not isForward:
      init = self.getInitializer(type)
      if init is not None:
        decl.initializer = init
    return self.getDecl(decl, comment)

  def getParameter(self, name, paramType):
    '''Return a Parameter of type paramType, name may be a string, None, or an object with an identifier'''
    from Cxx import Parameter
    param = Parameter()
    if name is None or isinstance(name, str):
      param.identifier = name
    else:
      param.identifier = name.identifier
    param.type = paramType
    return param

  def getArray(self, var, entryType, size = None, initializer = None, isStatic = 0):
    '''Return a Declaration of the array var, both var and size are processed by getValue()'''
    from Cxx import Array
    arrayDecl = Array()
    arrayDecl.children = [self.getValue(var)]
    arrayDecl.type = entryType
    if size is not None:
      arrayDecl.size = self.getValue(size)
    if initializer is not None:
      arrayDecl.initializer = initializer
    arrayDecl.static = isStatic
    return self.getDecl(arrayDecl)

  def getSizeof(self, type):
    '''Return a Sizeof whose type name is an anonymous declarator of the given type'''
    from Cxx import Sizeof
    sizeof = Sizeof()
    # getDeclarator() resolves a string type name itself
    sizeof.setTypeName(self.getDeclarator(None, type))
    return sizeof

  def getExpStmt(self, exp, comment = None, caseLabel = None):
    '''
    Make a statement from an expression

    An optional comment is attached, and an optional caseLabel is processed by getValue().
    '''
    from Cxx import ExpressionStatement
    stmt = ExpressionStatement()
    stmt.setChildren([exp])
    if comment is not None:
      stmt.setComments([self.getComment(comment)])
    if caseLabel is not None:
      stmt.caseLabel = self.getValue(caseLabel)
    return stmt

  def getNegation(self, exp):
    '''Return a Negation whose child is exp, which is used as given'''
    from Cxx import Negation
    return self._withChild(Negation(), exp)

  def getAddition(self, a, b):
    '''Return an Addition, both operands are processed by getValue()'''
    from Cxx import Addition
    return self._withOperands(Addition(), a, b)

  def getSubtraction(self, a, b):
    '''Return a Subtraction, both operands are processed by getValue()'''
    from Cxx import Subtraction
    return self._withOperands(Subtraction(), a, b)

  def getMultiplication(self, a, b):
    '''Return a Multiplication, both operands are processed by getValue()'''
    from Cxx import Multiplication
    return self._withOperands(Multiplication(), a, b)

  def getDivision(self, a, b):
    '''Return a Division, both operands are processed by getValue()'''
    from Cxx import Division
    return self._withOperands(Division(), a, b)

  def getAdditionAssignment(self, a, b):
    '''Return an AdditionAssignment, both operands are processed by getValue()'''
    from Cxx import AdditionAssignment
    return self._withOperands(AdditionAssignment(), a, b)

  def getSubtractionAssignment(self, a, b):
    '''Return a SubtractionAssignment, both operands are processed by getValue()'''
    from Cxx import SubtractionAssignment
    return self._withOperands(SubtractionAssignment(), a, b)

  def getEquality(self, a, b):
    '''Return an Equality, both operands are processed by getValue()'''
    from Cxx import Equality
    return self._withOperands(Equality(), a, b)

  def getInequality(self, a, b):
    '''Return an Inequality, both operands are processed by getValue()'''
    from Cxx import Inequality
    return self._withOperands(Inequality(), a, b)

  def getLeftShift(self, a, b):
    '''Return a LeftShift, both operands are processed by getValue()'''
    from Cxx import LeftShift
    return self._withOperands(LeftShift(), a, b)

  def getArrayRef(self, arrayName, index):
    '''arrayName is a Variable, and index is processed by getValue()'''
    from Cxx import ArrayReference
    arrayRef = ArrayReference()
    arrayRef.children = [self.getValue(arrayName)]
    arrayRef.setIndex(self.getValue(index))
    return arrayRef

  def getStructRef(self, structName, member, isPointer = 1):
    '''structName is a Variable, and member is a string'''
    from Cxx import StructReference
    structRef = StructReference()
    structRef.children = [self.getValue(structName)]
    structRef.setMember(member)
    structRef.setPointer(isPointer)
    return structRef

  def getAssignment(self, target, value):
    '''Return an Assignment to target, which is used as given, of value, which is processed by getValue()'''
    from Cxx import Assignment
    assignment = Assignment()
    assignment.children = [target, self.getValue(value)]
    return assignment

  def castToType(self, exp, t, numPointers = 0):
    '''Return a Cast of exp, processed by getValue(), to the type t wrapped in numPointers levels of Pointer'''
    from Cxx import Cast
    for level in range(numPointers):
      t = self.getPointer(t)
    cast = Cast()
    cast.children = [self.getValue(exp)]
    cast.setTypeName(self.getDeclarator(None, t))
    return cast

  def castToNamedType(self, exp, typeName, numPointers = 0, isConst = 0, isStruct = 0):
    '''
    Return a Cast of exp to the type called typeName

    A shared type map entry is used only for a non-const, non-struct request.
    In every other case a fresh type object is created, so that setting the
    const flag never modifies an entry of the type map.
    '''
    if isStruct:
      import ASE.Compiler.Cxx.Struct
      cxxType = ASE.Compiler.Cxx.Struct.Struct()
      cxxType.setIdentifier(typeName)
    elif typeName in self.typeMap and not isConst:
      cxxType = self.typeMap[typeName]
    else:
      import ASE.Compiler.Cxx.Type
      cxxType = ASE.Compiler.Cxx.Type.Type()
      cxxType.setIdentifier(typeName)
      cxxType.setBaseType(typeName in self.typeMap)
    if isConst:
      cxxType.setConst(1)
    return self.castToType(exp, cxxType, numPointers)

  def castToVoidPointer(self, exp):
    '''Return a Cast of exp to a pointer to void'''
    return self.castToNamedType(exp, 'void', 1)

  def castToString(self, constString):
    '''Return a Cast of the string constant constString to a pointer to char'''
    return self.castToNamedType(self.getString(constString), 'char', 1)

  def getValue(self, value):
    '''
    Return a constant or variable which best reflect the Python object

    Strings become variables, ints (which includes bools) become cached
    integer constants, floats become double constants, and anything else is
    returned unchanged.
    '''
    if isinstance(value, str):
      return self.getVar(value)
    if isinstance(value, int):
      return self.getInteger(value)
    if isinstance(value, float):
      return self.getDouble(value)
    return value

  def getFunctionCall(self, name, args = ()):
    '''Return a FunctionCall of name, the name and every argument are processed by getValue()'''
    from Cxx import FunctionCall
    funcCall = FunctionCall()
    funcCall.children = [self.getValue(name)]
    funcCall.arguments.extend([self.getValue(arg) for arg in args])
    return funcCall

  def getSimpleLoop(self, loopVar, lowerBound, upperBound, increment = 1, allowEquality = 0, allowInequality = 0, isPrefix = 0, body = None):
    '''
    Return a For loop running loopVar from lowerBound towards upperBound

    - loopVar is either a Declarator, which is used as the loop declaration, or a value for which an "int" declarator is created
    - The loop test is <= if allowEquality, else != if allowInequality, else <
    - An increment of 1 or -1 gives an increment/decrement operator (postfix unless isPrefix), anything else gives "+= increment"
    - body is assigned as the children of the loop, an empty CompoundStatement is used when it is None
    '''
    from Cxx import CompoundStatement
    from Cxx import Declarator
    from Cxx import For

    if isinstance(loopVar, Declarator):
      loopDecl = loopVar
      loopVar  = self.getValue(loopDecl.identifier)
    else:
      loopVar  = self.getValue(loopVar)
      loopDecl = self.getDeclarator(loopVar, 'int')
    # The initializer replaces any that a passed in declarator already had
    loopDecl.initializer = self.getInitializer(None, value = self.getValue(lowerBound))
    if allowEquality:
      from Cxx import LessThanOrEqual
      comp = LessThanOrEqual()
    elif allowInequality:
      from Cxx import Inequality
      comp = Inequality()
    else:
      from Cxx import LessThan
      comp = LessThan()
    comp.children = [loopVar, self.getValue(upperBound)]
    if increment == 1 or increment == -1:
      from Cxx import Increment
      inc  = Increment()
      inc.children = [loopVar]
      if increment == -1:
        inc.setDecrement(1)
      inc.setPostfix(not isPrefix)
    else:
      from Cxx import AdditionAssignment
      inc = AdditionAssignment()
      inc.children = [loopVar, self.getValue(increment)]
    loop = For()
    loop.setInitialization(loopDecl)
    loop.setBranch(comp)
    loop.setIncrement(inc)
    if body is None:
      loop.children = [CompoundStatement()]
    else:
      loop.children = body
    return loop

  def getReturn(self, exp = None, isPetsc = 0, caseLabel = None):
    '''
    Return a return statement

    With isPetsc, this is an expression statement calling PetscFunctionReturn
    on exp, or on the integer 0 when exp is None. Otherwise it is a Return
    node, whose child is getValue(exp) when exp is given. An optional
    caseLabel is processed by getValue().
    '''
    from Cxx import Return
    if isPetsc:
      if exp is None:
        args = [self.getInteger(0)]
      else:
        args = [exp]
      returnStmt = self.getExpStmt(self.getFunctionCall('PetscFunctionReturn', args))
    else:
      returnStmt = Return()
      if exp is not None:
        returnStmt.children = [self.getValue(exp)]
    if caseLabel is not None:
      returnStmt.caseLabel = self.getValue(caseLabel)
    return returnStmt

  def getThrow(self, exp = None):
    '''Return a Throw, with exp as its only child if it is given'''
    from Cxx import Throw
    throw = Throw()
    if exp is not None:
      throw.children = [exp]
    return throw

  def getPetscCheck(self, exp):
    '''Return two statements: the assignment of exp to the variable ierr, and a call of CHKERRQ on ierr'''
    testVar = self.getVar('ierr')
    assignStmt = self.getExpStmt(self.getAssignment(testVar, exp))
    checkStmt  = self.getExpStmt(self.getFunctionCall('CHKERRQ', [testVar]))
    return [assignStmt, checkStmt]

  def getPetscError(self, errorCode, errorMsg, *args):
    '''
    Return a statement calling SETERRQ with the error code and message

    When extra arguments are given they are appended to the call, and their
    number is appended to the function name, e.g. SETERRQ2 for two of them.
    '''
    func = 'SETERRQ'
    if args:
      func += str(len(args))
    return self.getExpStmt(self.getFunctionCall(func, [errorCode, self.getString(errorMsg)]+list(args)))

  def getFunctionHeader(self, name):
    '''Return two Define nodes: an undef of __FUNCT__, and a define of __FUNCT__ as the double quoted name'''
    from Cxx import Define
    undef = Define()
    undef.undef = 1
    undef.identifier = '__FUNCT__'
    define = Define()
    define.identifier = '__FUNCT__'
    define.replacementText = '"'+name+'"'
    return [undef, define]

  def getFunction(self, name, retType, params, decls, stmts):
    '''Return a Function whose body is a CompoundStatement with the given declarations and statements'''
    from Cxx import CompoundStatement
    from Cxx import Function
    func = Function()
    func.children = [self.getValue(name)]
    func.type = retType
    func.parameters = params
    func.body = CompoundStatement()
    func.body.declarations = decls
    func.body.children = stmts
    return func

  def getFunctionPointer(self, name, retType, params, numPointers = 1):
    '''
    Return a Function whose child is a declarator group declaring name as a pointer

    The innermost Pointer has no children assigned, and it is wrapped in
    numPointers-1 further levels of Pointer.
    '''
    from Cxx import Function
    from Cxx import Pointer
    func = Function()
    head = Pointer()
    for level in range(numPointers-1):
      head = self.getPointer(head)
    func.children = [self.getDeclGroup(self.getDeclarator(self.getValue(name), head))]
    func.type = retType
    func.parameters = params
    return func

  def getIf(self, branch, cases):
    '''Return an If with the given branch expression, and cases as its children'''
    from Cxx import If
    ifb = If()
    ifb.branch = branch
    ifb.children = cases
    return ifb
