import graph

class Element(graph.Element):
  '''
  Common base for expression-graph edges and vertices.

  Adds a stored value (self.value) with typed accessors on top of
  graph.Element. Edge and Vertex each derive from this class *and* from the
  matching graph class, which is why multiple inheritance is needed.
  '''
  def __init__(self):
    graph.Element.__init__(self)
    self.value = None

  def _convertValue(self, convert):
    '''
    Apply convert to every entry of self.value and return the converted
    entries reshaped to the original shape of self.value.
    '''
    # NOTE: this module never imported Numeric, so the getters used to fail
    # with NameError. Prefer Numeric when installed; numpy (its successor)
    # exposes the same shape/ravel/reshape functions.
    try:
      import Numeric as numeric
    except ImportError:
      import numpy as numeric
    shape = numeric.shape(self.value)
    # Build a real list: on Python 3 map() is a lazy iterator, which reshape
    # cannot consume.
    converted = [convert(v) for v in numeric.ravel(self.value)]
    return numeric.reshape(converted, shape)

  def getValue_bool(self):
    '''
    Retrieve the node value with every entry converted by bool()
    '''
    return self._convertValue(bool)

  def getValue_char(self):
    '''
    Retrieve the node value with every entry converted by chr(ord(v))
    '''
    return self._convertValue(lambda v: chr(ord(v)))

  def getValue_dcomplex(self):
    '''
    Retrieve the node value with every entry converted by complex()
    '''
    return self._convertValue(complex)

  def getValue_double(self):
    '''
    Retrieve the node value with every entry converted by float()
    '''
    return self._convertValue(float)

  def getValue_int(self):
    '''
    Retrieve the node value with every entry converted by int()
    '''
    return self._convertValue(int)

  def getValue_string(self):
    '''
    Retrieve the node value with every entry converted by str()
    '''
    return self._convertValue(str)

  def setValue_bool(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def setValue_char(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def setValue_dcomplex(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def setValue_double(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def setValue_int(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def setValue_string(self, value):
    '''
    Set the node value (stored as given, without conversion)
    '''
    self.value = value

  def accept(self, visitor):
    '''
    Dispatch to visitor.visitEdge(self).
    Edge defines no accept() of its own and relies on this; Vertex overrides it.
    '''
    visitor.visitEdge(self)

  def clone(self):
    '''
    Return ASE.Loader.Loader.createClass('expression.Edge').
    NOTE(review): this always requests 'expression.Edge' regardless of the
    concrete type of self and does not copy self.value — confirm intended.
    '''
    import ASE.Loader
    return ASE.Loader.Loader.createClass('expression.Edge')

class Edge(graph.Edge, Element):
  '''
  Expression-graph edge: a graph.Edge that also carries an Element value.
  Visitors reach it through the accept() inherited from Element (visitEdge).
  '''
  def __init__(self):
    # Both bases are initialised explicitly, in this order; note that
    # Element.__init__ in turn calls graph.Element.__init__.
    graph.Edge.__init__(self)
    Element.__init__(self)

class Vertex(graph.Vertex, Element):
  '''
  Expression-graph vertex.

  Keeps its incoming and outgoing edges in insertion order, plus bookkeeping
  fields (discoveryNumber, finishNumber, level) initialised to zero.
  '''
  def __init__(self):
    graph.Vertex.__init__(self)
    Element.__init__(self)
    # Incident edges
    self.inEdges = []
    self.outEdges = []
    # Traversal bookkeeping
    self.discoveryNumber = 0
    self.finishNumber = 0
    self.level = 0

  # ---- incoming edges (this vertex is the target) ----

  def getInEdges(self):
    '''Return the list of edges whose target is this vertex'''
    return self.inEdges

  def setInEdges(self, edges):
    '''Replace the incoming edges with a fresh list copied from edges'''
    self.inEdges = list(edges)

  def addInEdge(self, edge):
    '''Register edge as incoming, retargeting it here if needed; duplicates are ignored'''
    if edge in self.inEdges:
      return
    if not (edge.getTarget() == self):
      edge.setTarget(self)
    self.inEdges.append(edge)

  def removeInEdge(self, edge):
    '''Drop edge from the incoming edges; absent edges are ignored'''
    if edge not in self.inEdges:
      return
    self.inEdges.remove(edge)

  # ---- outgoing edges (this vertex is the source) ----

  def getOutEdges(self):
    '''Return the list of edges whose source is this vertex'''
    return self.outEdges

  def setOutEdges(self, edges):
    '''Replace the outgoing edges with a fresh list copied from edges'''
    self.outEdges = list(edges)

  def addOutEdge(self, edge):
    '''Register edge as outgoing, re-sourcing it here if needed; duplicates are ignored'''
    if edge in self.outEdges:
      return
    if not (edge.getSource() == self):
      edge.setSource(self)
    self.outEdges.append(edge)

  def removeOutEdge(self, edge):
    '''Drop edge from the outgoing edges; absent edges are ignored'''
    if edge not in self.outEdges:
      return
    self.outEdges.remove(edge)

  def accept(self, visitor):
    '''Double-dispatch to visitor.visitVertex(self)'''
    visitor.visitVertex(self)

class AbsoluteValue(Vertex):
  '''
  Expression-graph vertex for an absolute-value operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitAbsoluteValue(self)'''
    visitor.visitAbsoluteValue(self)

class Constant(Vertex):
  '''
  Expression-graph vertex for a constant operand (its number lives in the
  Element value). Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitConstant(self)'''
    visitor.visitConstant(self)

class DiscreteVariable(Vertex):
  '''
  Variable vertex that additionally records how it is discretized:
  the quadrature (point count, point/weight array names, loop variable) and
  the basis (function count, value/derivative array names, loop variable).
  '''
  def __init__(self):
    Vertex.__init__(self)
    self.rank = 0
    # Quadrature description
    self.numQuadraturePoints = 0
    self.pointArray = ''
    self.weightArray = ''
    self.quadratureVariable = 'p'
    # Basis description
    self.numBasisFunctions = 0
    self.basisArray = ''
    self.basisDerivativeArray = ''
    self.basisVariable = ''

  # ---- tensor rank ----

  def getRank(self):
    '''Return the tensor rank of this variable'''
    return self.rank

  def setRank(self, rank):
    '''Store the tensor rank of this variable'''
    self.rank = rank

  # ---- quadrature ----

  def getNumQuadraturePoints(self):
    '''Return how many quadrature points the discretization uses'''
    return self.numQuadraturePoints

  def setNumQuadraturePoints(self, numQuadraturePoints):
    '''Store how many quadrature points the discretization uses'''
    self.numQuadraturePoints = numQuadraturePoints

  def getPointArray(self):
    '''Return the name of the quadrature-point array'''
    return self.pointArray

  def setPointArray(self, arrayName):
    '''Store the name of the quadrature-point array'''
    self.pointArray = arrayName

  def getWeightArray(self):
    '''Return the name of the quadrature-weight array'''
    return self.weightArray

  def setWeightArray(self, arrayName):
    '''Store the name of the quadrature-weight array'''
    self.weightArray = arrayName

  def getQuadratureVariable(self):
    '''Return the loop-variable name that indexes quadrature points'''
    return self.quadratureVariable

  def setQuadratureVariable(self, variableName):
    '''Store the loop-variable name that indexes quadrature points'''
    self.quadratureVariable = variableName

  # ---- basis ----

  def getNumBasisFunctions(self):
    '''Return how many basis functions the discretization uses'''
    return self.numBasisFunctions

  def setNumBasisFunctions(self, numBasisFunctions):
    '''Store how many basis functions the discretization uses'''
    self.numBasisFunctions = numBasisFunctions

  def getBasisArray(self):
    '''Return the name of the array of basis values at the quadrature points'''
    return self.basisArray

  def setBasisArray(self, arrayName):
    '''Store the name of the array of basis values at the quadrature points'''
    self.basisArray = arrayName

  def getBasisDerivativeArray(self):
    '''Return the name of the array of basis derivatives at the quadrature points'''
    return self.basisDerivativeArray

  def setBasisDerivativeArray(self, arrayName):
    '''Store the name of the array of basis derivatives at the quadrature points'''
    self.basisDerivativeArray = arrayName

  def getBasisVariable(self):
    '''Return the loop-variable name that indexes basis functions'''
    return self.basisVariable

  def setBasisVariable(self, variableName):
    '''Store the loop-variable name that indexes basis functions'''
    self.basisVariable = variableName

  def accept(self, visitor):
    '''Double-dispatch to visitor.visitDiscreteVariable(self)'''
    visitor.visitDiscreteVariable(self)

class Negation(Vertex):
  '''
  Expression-graph vertex for a negation operation.
  '''
  def __init__(self, IORself=None):
    '''
    Initialise the vertex.

    IORself is unused and now optional. Previously it was a required
    positional argument, so Negation() raised TypeError, unlike every other
    vertex class. Callers that still pass it keep working.
    '''
    Vertex.__init__(self)

  def accept(self, visitor):
    '''
    Double-dispatch to visitor.visitNegation(self)
    '''
    visitor.visitNegation(self)

class Variable(Vertex):
  '''
  Expression-graph vertex for a named variable operand with a tensor rank
  (0 by default).
  '''
  def __init__(self):
    Vertex.__init__(self)
    self.rank = 0  # tensor rank; 0 until setRank() is called

  def getRank(self):
    '''Return the tensor rank of this variable'''
    return self.rank

  def setRank(self, rank):
    '''Store the tensor rank of this variable'''
    self.rank = rank

  def accept(self, visitor):
    '''Double-dispatch to visitor.visitVariable(self)'''
    visitor.visitVariable(self)

class Parser:
  '''
  Wrapper around the parser object returned by exprparse.setup().
  '''
  def __init__(self):
    import exprparse
    self.parser = exprparse.setup()

  def parse(self, buf):
    '''
    Parse the buffer and return an AST object.

    The result is also stored in self.graph. restart() is called on the
    underlying parser even when parse() raises, so the parser is reset
    before the next call. The exception still propagates.
    '''
    try:
      self.graph = self.parser.parse(buf)
    finally:
      self.parser.restart()
    return self.graph

  def parseFile(self, filename):
    '''
    Parse the file and return an AST object.

    Uses open() inside a with-block. The old file() builtin does not exist on
    Python 3, and the with-block closes the handle even if read() fails.
    '''
    with open(filename) as f:
      contents = f.read()
    return self.parse(contents)

class Addition(Vertex):
  '''
  Expression-graph vertex for addition; its identifier is '+'.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitAddition(self)'''
    visitor.visitAddition(self)

  def __init__(self):
    Vertex.__init__(self)
    # Label the node with its operator symbol
    self.setIdentifier('+')

class Cosine(Vertex):
  '''
  Expression-graph vertex for a cosine operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitCosine(self)'''
    visitor.visitCosine(self)

class Determinant(Vertex):
  '''
  Expression-graph vertex for a determinant operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitDeterminant(self)'''
    visitor.visitDeterminant(self)

class Exponential(Vertex):
  '''
  Expression-graph vertex for exponentiation; its identifier is '^'.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitExponential(self)'''
    visitor.visitExponential(self)

  def __init__(self):
    Vertex.__init__(self)
    # Label the node with its operator symbol
    self.setIdentifier('^')

class Inverse(Vertex):
  '''
  Expression-graph vertex for an inverse operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitInverse(self)'''
    visitor.visitInverse(self)

class Multiplication(Vertex):
  '''
  Expression-graph vertex for multiplication; its identifier is '*'.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitMultiplication(self)'''
    visitor.visitMultiplication(self)

  def __init__(self):
    Vertex.__init__(self)
    # Label the node with its operator symbol
    self.setIdentifier('*')

class Sine(Vertex):
  '''
  Expression-graph vertex for a sine operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitSine(self)'''
    visitor.visitSine(self)

class Subtraction(Vertex):
  '''
  Expression-graph vertex for subtraction; its identifier is '-'.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitSubtraction(self)'''
    visitor.visitSubtraction(self)

  def __init__(self):
    Vertex.__init__(self)
    # Label the node with its operator symbol
    self.setIdentifier('-')

class Transpose(Vertex):
  '''
  Expression-graph vertex for a transpose operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitTranspose(self)'''
    visitor.visitTranspose(self)

class Bracket(Vertex):
  '''
  Expression-graph vertex for a bracket operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitBracket(self)'''
    visitor.visitBracket(self)

class Curl(Vertex):
  '''
  Expression-graph vertex for a curl operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitCurl(self)'''
    visitor.visitCurl(self)

class Derivative(Vertex):
  '''
  Expression-graph vertex for a derivative operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitDerivative(self)'''
    visitor.visitDerivative(self)

class Divergence(Vertex):
  '''
  Expression-graph vertex for a divergence operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitDivergence(self)'''
    visitor.visitDivergence(self)

class Division(Vertex):
  '''
  Expression-graph vertex for division; its identifier is '/'.
  A flag (off by default) marks integer division, i.e. truncation of the
  result to the nearest integer.
  '''
  def __init__(self):
    Vertex.__init__(self)
    self.isInteger = 0  # integer-division flag
    self.setIdentifier('/')

  def getIntegerDivision(self):
    '''Return the integer-division (truncation) flag'''
    return self.isInteger

  def setIntegerDivision(self, isInteger):
    '''Store the integer-division (truncation) flag'''
    self.isInteger = isInteger

  def accept(self, visitor):
    '''Double-dispatch to visitor.visitDivision(self)'''
    visitor.visitDivision(self)

class Gradient(Vertex):
  '''
  Expression-graph vertex for a gradient operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitGradient(self)'''
    visitor.visitGradient(self)

class InnerProduct(Vertex):
  '''
  Expression-graph vertex for an inner product; its identifier is '.'.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitInnerProduct(self)'''
    visitor.visitInnerProduct(self)

  def __init__(self):
    Vertex.__init__(self)
    # Label the node with its operator symbol
    self.setIdentifier('.')

class List(Vertex):
  '''
  Expression-graph vertex for a list operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitList(self)'''
    visitor.visitList(self)

class OuterProduct(Vertex):
  '''
  Expression-graph vertex for an outer product.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitOuterProduct(self)'''
    visitor.visitOuterProduct(self)

class Subscript(Vertex):
  '''
  Expression-graph vertex for a subscript operation.
  Construction is inherited unchanged from Vertex.
  '''
  def accept(self, visitor):
    '''Double-dispatch to visitor.visitSubscript(self)'''
    visitor.visitSubscript(self)

class OriginalDepthFirstVisitor:
  '''
  Default depth-first visitor for expression graphs.

  Subclasses customise a node kind by overriding its visit*() hook, following

    def visit<element>(self, element):
      <discover element>
      self.traverse<element>(element)
      <finish element>

  State kept on the visitor:
    seen         -- vertices already reached, in the order they were reached
    reverseEdges -- when true, walk in-edges towards their sources instead of
                    out-edges towards their targets
    vertexNumber -- one counter shared by discovery and finish stamps
  '''
  def __init__(self):
    self.seen = []
    self.reverseEdges = 0
    self.vertexNumber = 0

  # ---- visit hooks ----

  def visitElement(self, element):
    '''A generic element requires no work'''
    pass

  def visitVertex(self, element):
    '''Discover the vertex, walk its edges, then finish it'''
    self.discoverVertex(element)
    self.traverseVertex(element)
    self.finishVertex(element)

  def visitEdge(self, element):
    '''Follow the edge to the vertex on its far side'''
    self.traverseEdge(element)

  # ---- edge orientation ----

  def getReverseEdges(self):
    '''Return the flag that flips edge orientation during traversal'''
    return self.reverseEdges

  def setReverseEdges(self, reverseEdges):
    '''Store the flag that flips edge orientation during traversal'''
    self.reverseEdges = reverseEdges

  # ---- traversal ----

  def traverseVertex(self, vertex):
    '''Mark vertex as seen and hand each of its (oriented) edges to accept()'''
    self.setVertexSeen(vertex, 1)
    if self.getReverseEdges():
      incident = vertex.getInEdges()
    else:
      incident = vertex.getOutEdges()
    for edge in incident:
      edge.accept(self)

  def traverseEdge(self, edge):
    '''Visit the far endpoint of edge unless it has already been seen'''
    far = edge.getSource() if self.getReverseEdges() else edge.getTarget()
    if self.getVertexSeen(far):
      return
    far.accept(self)

  # ---- numbering ----

  def resetVertexNumbering(self):
    '''Restart the shared discovery/finish counter at zero'''
    self.vertexNumber = 0

  def discoverVertex(self, vertex):
    '''Stamp vertex with the current counter as its discovery number, then advance'''
    vertex.setDiscoveryNumber(self.vertexNumber)
    self.vertexNumber += 1

  def finishVertex(self, vertex):
    '''Stamp vertex with the current counter as its finish number, then advance'''
    vertex.setFinishNumber(self.vertexNumber)
    self.vertexNumber += 1

  # ---- seen-set bookkeeping ----

  def getVerticesSeen(self):
    '''Return the (live) list of vertices reached so far'''
    return self.seen

  def setVerticesSeen(self, seen):
    '''Replace the seen vertices with a list copied from seen'''
    self.seen = list(seen)

  def getVertexSeen(self, vertex):
    '''Return whether vertex has been reached'''
    return vertex in self.seen

  def setVertexSeen(self, vertex, state):
    '''Add vertex to (state true) or drop it from (state false) the seen list'''
    present = vertex in self.seen
    if state and not present:
      self.seen.append(vertex)
    elif not state and present:
      self.seen.remove(vertex)

class DepthFirstVisitor(OriginalDepthFirstVisitor):
  '''
  Default DFS visitor for PDE expression graphs.

  Every visit<Kind>() hook below does the same work: discover the vertex,
  traverse its edges, then finish it. Subclasses override individual hooks
  to add per-kind behaviour.
  '''
  def __init__(self):
    OriginalDepthFirstVisitor.__init__(self)
    self.graph = None

  def getGraph(self):
    '''Return the graph being visited'''
    return self.graph

  def setGraph(self, graph):
    '''Store the graph being visited'''
    self.graph = graph

  def _discoverTraverseFinish(self, element):
    '''Shared body of every visit hook: discover, traverse, finish'''
    self.discoverVertex(element)
    self.traverseVertex(element)
    self.finishVertex(element)

  # ---- operands ----

  def visitVariable(self, element):
    '''Visit a Variable operand'''
    self._discoverTraverseFinish(element)

  def visitDiscreteVariable(self, element):
    '''Visit a DiscreteVariable operand'''
    self._discoverTraverseFinish(element)

  def visitConstant(self, element):
    '''Visit a Constant operand'''
    self._discoverTraverseFinish(element)

  # ---- unary operators ----

  def visitAbsoluteValue(self, element):
    '''Visit an AbsoluteValue operator'''
    self._discoverTraverseFinish(element)

  def visitNegation(self, element):
    '''Visit a Negation operator'''
    self._discoverTraverseFinish(element)

  def visitExponential(self, element):
    '''Visit an Exponential operator'''
    self._discoverTraverseFinish(element)

  def visitCosine(self, element):
    '''Visit a Cosine operator'''
    self._discoverTraverseFinish(element)

  def visitSine(self, element):
    '''Visit a Sine operator'''
    self._discoverTraverseFinish(element)

  def visitTranspose(self, element):
    '''Visit a Transpose operator'''
    self._discoverTraverseFinish(element)

  def visitDeterminant(self, element):
    '''Visit a Determinant operator'''
    self._discoverTraverseFinish(element)

  def visitInverse(self, element):
    '''Visit an Inverse operator'''
    self._discoverTraverseFinish(element)

  # ---- arithmetic ----

  def visitAddition(self, element):
    '''Visit an Addition operator'''
    self._discoverTraverseFinish(element)

  def visitSubtraction(self, element):
    '''Visit a Subtraction operator'''
    self._discoverTraverseFinish(element)

  def visitMultiplication(self, element):
    '''Visit a Multiplication operator'''
    self._discoverTraverseFinish(element)

  def visitDivision(self, element):
    '''Visit a Division operator'''
    self._discoverTraverseFinish(element)

  # ---- structural / tensor operators ----

  def visitList(self, element):
    '''Visit a List operator'''
    self._discoverTraverseFinish(element)

  def visitSubscript(self, element):
    '''Visit a Subscript operator'''
    self._discoverTraverseFinish(element)

  def visitInnerProduct(self, element):
    '''Visit an InnerProduct operator'''
    self._discoverTraverseFinish(element)

  def visitOuterProduct(self, element):
    '''Visit an OuterProduct operator'''
    self._discoverTraverseFinish(element)

  def visitBracket(self, element):
    '''Visit a Bracket operator'''
    self._discoverTraverseFinish(element)

  # ---- differential operators ----

  def visitDerivative(self, element):
    '''Visit a Derivative operator'''
    self._discoverTraverseFinish(element)

  def visitGradient(self, element):
    '''Visit a Gradient operator'''
    self._discoverTraverseFinish(element)

  def visitDivergence(self, element):
    '''Visit a Divergence operator'''
    self._discoverTraverseFinish(element)

  def visitCurl(self, element):
    '''Visit a Curl operator'''
    self._discoverTraverseFinish(element)

class DepthFirstSearch(DepthFirstVisitor):
  '''
  Default DFS for PDE Expression Graphs
  This should really also inherit from TOPS.Expression.DepthFirstSearch

  Two ways to walk the graph are offered:
    - searchGraph() drives the visit*() hooks over every vertex of the graph
    - the next*() methods pull elements one at a time from a depth-first
      generator that getIterator() creates lazily
  '''
  def __init__(self):
    DepthFirstVisitor.__init__(self)
    self.graph = None
    self.returnFinished = 0
    # Pull-style iterator state. Initialised here so getIterator() does not
    # raise AttributeError when the graph was assigned without setGraph().
    self.isDirty = 1
    self.iterator = None

  def searchGraph(self):
    '''
    Call visit routines for each element of the graph (roots first, then all
    vertices), skipping vertices already seen.

    Raises RuntimeError if no graph has been set.
    '''
    if self.graph is None:
      raise RuntimeError('Graph not set before search')
    for vertex in list(self.graph.getRoots()) + list(self.graph.getVertices()):
      if not (self.getVertexSeen(vertex)):
        vertex.accept(self)
        self.setVertexSeen(vertex, 1)

  def getReturnFinished(self):
    '''
    Retrieve the flag for returning elements after they are finished
    '''
    return self.returnFinished

  def setReturnFinished(self, returnFinished):
    '''
    Set the flag for returning elements after they are finished.
    Takes effect the next time the iterator is (re)built by getIterator().
    '''
    self.returnFinished = returnFinished

  def _nextMatching(self, typeName):
    '''
    Advance the shared iterator to the next element whose
    isInstanceOf(typeName) is true and return it, or None when exhausted.
    Non-matching elements are consumed and skipped.
    '''
    for element in self.getIterator():
      if element.isInstanceOf(typeName):
        return element
    return None

  def nextElement(self):
    '''
    Return the next element (vertex or edge), or None when exhausted
    '''
    # Builtin next() works on Python 2.6+ and 3; generator.next() is 2-only.
    return next(self.getIterator(), None)

  def nextVertex(self):
    '''
    Return the next vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.Vertex')

  def nextEdge(self):
    '''
    Return the next edge, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.Edge')

  def nextVariable(self):
    '''
    Return the next Variable vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Variable')

  def nextDiscreteVariable(self):
    '''
    Return the next DiscreteVariable vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.DiscreteVariable')

  def nextConstant(self):
    '''
    Return the next Constant vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Constant')

  def nextAbsoluteValue(self):
    '''
    Return the next AbsoluteValue vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.AbsoluteValue')

  def nextNegation(self):
    '''
    Return the next Negation vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Negation')

  def nextExponential(self):
    '''
    Return the next Exponential vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Exponential')

  def nextCosine(self):
    '''
    Return the next Cosine vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Cosine')

  def nextSine(self):
    '''
    Return the next Sine vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Sine')

  def nextTranspose(self):
    '''
    Return the next Transpose vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Transpose')

  def nextDeterminant(self):
    '''
    Return the next Determinant vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Determinant')

  def nextInverse(self):
    '''
    Return the next Inverse vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Inverse')

  def nextAddition(self):
    '''
    Return the next Addition vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Addition')

  def nextSubtraction(self):
    '''
    Return the next Subtraction vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Subtraction')

  def nextMultiplication(self):
    '''
    Return the next Multiplication vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Multiplication')

  def nextDivision(self):
    '''
    Return the next Division vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Division')

  def nextList(self):
    '''
    Return the next List vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.List')

  def nextSubscript(self):
    '''
    Return the next Subscript vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Subscript')

  def nextInnerProduct(self):
    '''
    Return the next InnerProduct vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.InnerProduct')

  def nextOuterProduct(self):
    '''
    Return the next OuterProduct vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.OuterProduct')

  def nextBracket(self):
    '''
    Return the next Bracket vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Bracket')

  def nextDerivative(self):
    '''
    Return the next Derivative vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Derivative')

  def nextGradient(self):
    '''
    Return the next Gradient vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Gradient')

  def nextDivergence(self):
    '''
    Return the next Divergence vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Divergence')

  def nextCurl(self):
    '''
    Return the next Curl vertex, or None when exhausted
    '''
    return self._nextMatching('TOPS.Expression.PDE.Curl')

  def getGraph(self):
    '''
    Retrieve the graph for this search
    '''
    return self.graph

  def setGraph(self, graph):
    '''
    Set the graph for this search, which also resets the iterator, the
    vertex numbering, and the seen list
    '''
    self.graph = graph
    self.isDirty = 1
    self.resetVertexNumbering()
    self.setVerticesSeen([])

  def getIterator(self):
    '''
    Return the shared depth-first generator, creating a new one (with the
    current returnFinished flag) when the search has been marked dirty.

    Raises RuntimeError if no graph has been set.
    '''
    if self.graph is None:
      raise RuntimeError('Graph not set before iteration')
    if self.isDirty:
      self.iterator = self.dfsIterator(self.returnFinished)
      self.isDirty = 0
    return self.iterator

  def dfsIterator(self, returnFinished):
    '''
    This is a generator returning vertices in a depth-first traversal
       - If returnFinished is True, return a vertex when it finishes
       - Otherwise, return a vertex when it is first seen
    '''
    # A for-loop already absorbs the inner generator's StopIteration, so no
    # try/except is needed around it.
    for vertex in self.graph.getVertices():
      if not (self.getVertexSeen(vertex)):
        for element in self.depthFirstVisit(vertex, returnFinished):
          yield element

  def depthFirstVisit(self, vertex, returnFinished):
    '''
    This is a generator returning vertices in a depth-first traversal only for the subtree rooted at vertex
       - If returnFinished is True, return a vertex when it finishes
       - Otherwise, return a vertex when it is first seen
    Edges are yielded too: before descending when returnFinished is false,
    after returning from the descent when it is true.
    '''
    self.setVertexSeen(vertex, 1)
    if not (returnFinished):
      self.discoverVertex(vertex)
      yield vertex
    if self.getReverseEdges():
      edges = vertex.getInEdges()
    else:
      edges = vertex.getOutEdges()
    for edge in edges:
      if not (returnFinished):
        yield edge
      if self.getReverseEdges():
        neighbor = edge.getSource()
      else:
        neighbor = edge.getTarget()
      if not (self.getVertexSeen(neighbor)):
        for element in self.depthFirstVisit(neighbor, returnFinished):
          yield element
      if returnFinished:
        yield edge
    if returnFinished:
      self.finishVertex(vertex)
      yield vertex

class VertexLocator(DepthFirstSearch):
  '''
  A VertexLocator detects the presence of a vertex in the tree, and returns all occurrences.

  As each vertex is discovered it is recorded if it satisfies every filter
  that is set:
    - name     -- vertex.getIdentifier() must equal it
    - typeName -- a dotted 'module.Class' path; the vertex must be an instance
  With neither filter set, nothing is recorded.
  '''
  def __init__(self):
    DepthFirstSearch.__init__(self)
    self.vertexNumber = 0
    self.name = ''
    self.typeName = ''
    self.vertices = []

  def getName(self):
    '''
    Retrieve the identifier to search for
    '''
    return self.name

  def setName(self, name):
    '''
    Set the identifier to search for (clears previous results)
    '''
    self.name = name
    self.vertices = []

  def getType(self):
    '''
    Retrieve the type to search for
    '''
    return self.typeName

  def setType(self, typeName):
    '''
    Set the dotted type path to search for (clears previous results)
    '''
    self.typeName = typeName
    self.vertices = []

  def hasVertex(self):
    '''
    Retrieve the flag indicating whether the vertex was found
    (the number of matches, so it is truthy exactly when something matched)
    '''
    return len(self.vertices)

  def getVertices(self):
    '''
    Retrieve all the vertices found
    '''
    return self.vertices

  def _matchesType(self, vertex):
    '''
    Return whether vertex is an instance of the class named by self.typeName
    '''
    path = self.typeName.split('.')
    mod = self.importModule('.'.join(path[:-1]))
    klass = getattr(mod, path[-1])
    try:
      return isinstance(vertex, klass)
    except TypeError:
      # The named attribute is not a class/type, so nothing can match it
      return False

  def discoverVertex(self, vertex):
    '''
    Number the vertex as discovered, then record it if it passes the filters.

    The type test previously compared isinstance()'s bool result against
    None, which never excluded anything; it now filters correctly.
    '''
    vertex.setDiscoveryNumber(self.vertexNumber)
    self.vertexNumber += 1
    if not (self.name or self.typeName):
      return
    if self.name and not (vertex.getIdentifier() == self.name):
      return
    if self.typeName and not self._matchesType(vertex):
      return
    self.vertices.append(vertex)

  def importModule(self, moduleName):
    '''STOLEN from script.py
    Import the named module, and return the module object
    - Works properly for fully qualified names
    '''
    module = __import__(moduleName)
    components = moduleName.split('.')
    for comp in components[1:]:
      module = getattr(module, comp)
    return module

class Printer(DepthFirstSearch):
  '''
  Writes a textual rendering of an expression graph to an output stream.

  Each visitX method writes the notation for one kind of vertex and calls
  visitVertex(element), inherited from DepthFirstSearch.
  '''
  def __init__(self, f = None):
    '''
    Create a printer that writes to the file-like object f (sys.stdout when omitted)
    '''
    DepthFirstSearch.__init__(self)
    if f is None:
      import sys
      f = sys.stdout
    self.f = f

  def unmarkVariables(self, element):
    '''
    Clear the seen flag on every expression.Variable in the subgraph rooted at element
    (presumably so variables already printed elsewhere are printed again inside it)
    '''
    locator = VertexLocator()
    locator.setType('expression.Variable')
    locator.setGraph(self.getGraph().getSubgraph(element))
    locator.searchGraph()
    # A plain loop: this is done only for its side effect on each vertex
    for vertex in locator.getVertices():
      self.setVertexSeen(vertex, 0)
    return
  
  def visitVariable(self, element):
    '''
    Visit a Variable operand, prefixing 'vec ' when it has rank 1
    '''
    if element.getRank() == 1:
      self.write('vec ')
    self.write(element.getIdentifier())
    self.visitVertex(element)
  
  def visitDiscreteVariable(self, element):
    '''
    Visit a DiscreteVariable operand, written as its expansion in the basis phi_i
    '''
    self.write('\\sum_i ' + element.getIdentifier() + '_i \\phi_i')
    self.visitVertex(element)
  
  def visitConstant(self, element):
    '''
    Visit a Constant operand

    A single-entry value is written as that entry; anything else is written whole.
    '''
    value = element.getValue_double()
    try:
      size = len(value)
    except TypeError:
      # Presumably a scalar (rank-0) value, which has no length
      size = None
    if size == 1:
      self.write(str(value[0]))
    else:
      self.write(str(value))
    self.visitVertex(element)
  
  def visitAbsoluteValue(self, element):
    '''
    Visit an AbsoluteValue operator
    '''
    self.write('|')
    self.visitVertex(element)
    self.write('|')
  
  def visitNegation(self, element):
    '''
    Visit a Negation operator
    '''
    self.write('-')
    self.visitVertex(element)
  
  def visitExponential(self, element):
    '''
    Visit an Exponential operator
    '''
    self.outputBinaryExpression(element)
  
  def visitCosine(self, element):
    '''
    Visit a Cosine operator
    '''
    self.write('cos(')
    self.visitVertex(element)
    self.write(')')
  
  def visitSine(self, element):
    '''
    Visit a Sine operator
    '''
    self.write('sin(')
    self.visitVertex(element)
    self.write(')')
  
  def visitTranspose(self, element):
    '''
    Visit a Transpose operator
    '''
    self.write('trans(')
    self.visitVertex(element)
    self.write(')')
  
  def visitDeterminant(self, element):
    '''
    Visit a Determinant operator
    '''
    self.write('det(')
    self.visitVertex(element)
    self.write(')')
  
  def visitInverse(self, element):
    '''
    Visit an Inverse operator
    '''
    self.write('inv(')
    self.visitVertex(element)
    self.write(')')

  def visitInnerProduct(self, element):
    '''
    Visit an InnerProduct operator
    '''
    self.outputBinaryExpression(element)

  def visitAddition(self, element):
    '''
    Visit an Addition operator
    '''
    self.outputBinaryExpression(element)
  
  def visitSubtraction(self, element):
    '''
    Visit a Subtraction operator
    '''
    self.outputBinaryExpression(element)
  
  def visitMultiplication(self, element):
    '''
    Visit a Multiplication operator
    '''
    self.outputBinaryExpression(element)
  
  def visitDivision(self, element):
    '''
    Visit a Division operator
    '''
    self.outputBinaryExpression(element)
  
  def visitList(self, element):
    '''
    Visit a List operator, written as {child, child, ...}
    '''
    self.write('{')
    for (c, child) in enumerate([edge.getTarget() for edge in element.getOutEdges()]):
      if c > 0:
        self.write(', ')
      child.accept(self)
    self.write('}')
    self.visitVertex(element)
  
  def visitSubscript(self, element):
    '''
    Visit a Subscript operator, written as base[index]

    Raises RuntimeError if element does not have exactly two children
    '''
    children = self._binaryChildren(element)
    children[0].accept(self)
    self.write('[')
    children[1].accept(self)
    self.write(']')
    self.visitVertex(element)
  
  def visitBracket(self, element):
    '''
    Visit a Bracket operator, written as <left, right>

    Raises RuntimeError if element does not have exactly two children
    '''
    self.unmarkVariables(element)
    children = self._binaryChildren(element)
    self.write('<')
    children[0].accept(self)
    self.write(', ')
    children[1].accept(self)
    self.write('>')
    self.visitVertex(element)
  
  def visitGradient(self, element):
    '''
    Visit a Gradient operator
    '''
    self.write('grad ')
    self.visitVertex(element)
  
  def visitDivergence(self, element):
    '''
    Visit a Divergence operator
    '''
    self.write('div ')
    self.visitVertex(element)
  
  def visitCurl(self, element):
    '''
    Visit a Curl operator
    '''
    self.write('curl ')
    self.visitVertex(element)
  
  def write(self, s):
    '''
    Write the string s to the output stream
    '''
    self.f.write(s)
    return 

  def _binaryChildren(self, element):
    '''
    Return the targets of element's out edges, which must number exactly two

    Raises RuntimeError otherwise
    '''
    children = [edge.getTarget() for edge in element.getOutEdges()]
    if not (len(children) == 2):
      raise RuntimeError('Binary expression must have two children, had ' + str(len(children)))
    return children
  
  def outputBinaryExpression(self, element):
    '''
    Write element as (left op right), where op is element's identifier

    Raises RuntimeError if element does not have exactly two children
    '''
    children = self._binaryChildren(element)
    self.write('(')
    children[0].accept(self)
    self.write(' ' + element.getIdentifier() + ' ')
    children[1].accept(self)
    self.visitVertex(element)
    self.write(')')
    return 
