source: TI02-CSML/trunk/csml/csmllibs/NumericTransform.py @ 2473

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI02-CSML/trunk/csml/csmllibs/NumericTransform.py@2473
Revision 2473, 5.6 KB checked in by domlowe, 13 years ago (diff)

merging PML code branch with trunk

  • Property svn:executable set to *
Line 
#!/usr/bin/python
# Author: Mark Henning <mhen@pml.ac.uk>
#
# A simple (and safe) floating point mathematical expression solver
# that can handle both infix and postfix (reverse Polish) notation.
#
# TODO: Currently quite intolerant of incorrect syntax.
# TODO: Implementation of shunting yard too opaque. Possibly rewrite.
#
# Change History
# 2007-05-04 mhen Modified to handle Numeric arrays

13from math import *
14import Numeric
15import re
16
# Symbol table shared by the infix compiler and the postfix evaluator.
#
# Each entry maps a token to a dict with:
#   'args':  number of operands the symbol consumes when it is evaluated.
#   'def':   callable applied to those operands; it *MUST* accept exactly
#            'args' positional arguments.
#
# Entries that are infix operators additionally define:
#   'prec':  integer precedence; higher values bind more tightly.
#   'assoc': 'yes' if the operator is associative, otherwise 'left' or
#            'right' for a left- or right-associative (non-associative)
#            operator.
#
# A constant is simply an entry whose callable takes no arguments.
#
# NOTE(review): the functions below come from the math module, which works on
# scalars, so they presumably fail for Numeric-array operands even though the
# arithmetic operators handle arrays -- confirm before using e.g. sin(x) with
# an array-valued x.
_symbols = {
   # Infix operators (operands arrive in left-to-right order).
   '*': {'args':2, 'prec':5, 'assoc':'yes',   'def':lambda lhs, rhs: lhs * rhs},
   '-': {'args':2, 'prec':0, 'assoc':'left',  'def':lambda lhs, rhs: lhs - rhs},
   '/': {'args':2, 'prec':5, 'assoc':'left',  'def':lambda lhs, rhs: lhs / rhs},
   '+': {'args':2, 'prec':0, 'assoc':'yes',   'def':lambda lhs, rhs: lhs + rhs},
   '^': {'args':2, 'prec':9, 'assoc':'right', 'def':lambda lhs, rhs: lhs ** rhs},
   '%': {'args':2, 'prec':5, 'assoc':'left',  'def':lambda lhs, rhs: lhs % rhs},
   # Functions. Note that 'log' takes the base first: log(base, value).
   'log':   {'args':2, 'def':lambda base, value: log(value, base)},
   'sin':   {'args':1, 'def':sin},
   'cos':   {'args':1, 'def':cos},
   'tan':   {'args':1, 'def':tan},
   'asin':  {'args':1, 'def':asin},
   'acos':  {'args':1, 'def':acos},
   'atan':  {'args':1, 'def':atan},
   'sqrt':  {'args':1, 'def':sqrt},
   'floor': {'args':1, 'def':floor},
   'ceil':  {'args':1, 'def':ceil},
   'abs':   {'args':1, 'def':fabs},
   # Constants.
   'pi':    {'args':0, 'def':lambda: pi},
   'e':     {'args':0, 'def':lambda: e},
}
54
#
# Implementation of Dijkstra's Shunting Yard Algorithm to compile infix
# notation into postfix (reverse Polish) notation.
#
def dijkstraShuntingYard(tokens):
   """Reorder a sequence of infix tokens into postfix (RPN) order.

   Numeric literals, constants (symbol-table entries taking no arguments) and
   tokens that are neither symbols nor punctuation (assumed to be variables)
   go straight to the output.  Functions and '(' wait on the operator stack
   until the matching ')'; ',' flushes operators back to the enclosing '('.
   Infix operators are ordered using their 'prec' and 'assoc' entries.

   Returns the postfix tokens as a list.
   Raises SyntaxError on unbalanced parentheses or a misplaced ','.
   """
   def isFunction(token):
      # Symbol-table entries without a precedence are functions/constants.
      return token in _symbols and 'prec' not in _symbols[token]

   def isOperator(token):
      return token in _symbols and 'prec' in _symbols[token]

   stack = []   # pending '(' markers, functions and infix operators
   output = []  # tokens in postfix order
   for t in tokens:
      if isNumericString(t) or not (t in _symbols or t in ('(', ',', ')')):
         # Float literal or unknown symbolic token (assumed to be a variable).
         output.append(t)
      elif t in _symbols and _symbols[t]['args'] == 0:
         # Constants are operands.  Pushing them onto the stack like functions
         # let later operators overtake them: "2+pi*3" compiled to
         # "2 3 * pi +" and "pi*2" to "2 * pi".
         output.append(t)
      elif t == '(' or isFunction(t):
         stack.append(t)
      elif t in (')', ','):
         # Flush operators back to the innermost '('.
         while stack and stack[-1] != '(':
            output.append(stack.pop())
         if not stack:
            raise SyntaxError("Could not parse expression.")
         if t == ')':
            stack.pop()  # discard the matching '('
            if stack and isFunction(stack[-1]):
               # The parenthesised group was this function's argument list.
               output.append(stack.pop())
      else:
         # Infix operator: first emit stacked operators that bind at least as
         # tightly (strictly more tightly when this one is right-associative).
         prec = _symbols[t]['prec']
         rightAssoc = _symbols[t]['assoc'] == 'right'
         while stack and isOperator(stack[-1]):
            topPrec = _symbols[stack[-1]]['prec']
            if prec < topPrec or (prec == topPrec and not rightAssoc):
               output.append(stack.pop())
            else:
               break
         stack.append(t)
   # Flush what is left; a leftover '(' means it was never closed.
   while stack:
      t = stack.pop()
      if t not in _symbols:
         raise SyntaxError("Could not parse expression.")
      output.append(t)
   return output
88
# Report whether a token is a floating point literal.
def isNumericString(candidate):
   """Return True if float() accepts *candidate*, False on ValueError.

   Errors other than ValueError raised by float() propagate to the caller.
   """
   numeric = True
   try:
      float(candidate)
   except ValueError:
      numeric = False
   return numeric
94
# Parse and evaluate postfix (reverse Polish) expressions.
class postfixExpression:
   """A postfix expression whose tokens are separated by whitespace.

   The expression is tokenised once and may then be solved any number of
   times with different variable values.
   """

   def __init__(self, expression):
      self.tokens = expression.split()

   # Solve the tokenised postfix expression stored in self.tokens and
   # return the solution (or None if no unique solution could be found).
   def solve(self, **variables):
      """Evaluate self.tokens on a value stack and return the result.

      Each token is resolved, in this order, as: a keyword argument in
      *variables* (so a variable shadows a symbol of the same name), an entry
      of the module symbol table (applied to the topmost 'args' stack values,
      in push order), or a float literal.  The working stack is left in
      self.stack.

      Returns the single value left on the stack, or None when zero or
      several values remain.
      Raises NameError for a token that cannot be resolved, and SyntaxError
      when a symbol needs more operands than the stack holds.
      """
      self.stack = []
      for token in self.tokens:
         if token in variables:
            self.stack.append(variables[token])
         elif token in _symbols:
            arity = _symbols[token]['args']
            if arity > len(self.stack):
               raise SyntaxError("Could not parse expression.")
            # Split at an explicit index: a stack[-arity:] slice is the whole
            # stack when arity == 0, which made constants swallow (and be
            # called with) every pending operand.
            split = len(self.stack) - arity
            args = self.stack[split:]
            del self.stack[split:]
            self.stack.append(_symbols[token]['def'](*args))
         else:
            # Not a variable or a symbol, so it must be a numeric literal.
            try:
               value = float(token)
            except ValueError:
               raise NameError("Could not resolve symbol '" + token + "'.")
            self.stack.append(value)
      if len(self.stack) == 1:
         return self.stack[0]
      return None
120
121
# Tokeniser for infix expressions: unsigned numeric literals, identifiers
# (function, constant and variable names), or any other single non-blank
# character (operators, parentheses and argument separators).
_INFIX_TOKEN_RE = re.compile(r"[0-9.]+|\w+|\S")
_NUMERIC_TOKEN_RE = re.compile(r"[0-9.]+$")


def _endsOperand(token):
   # True if *token* is (or closes) an operand, making a following '-' binary.
   last = token[-1]
   return last.isalnum() or last in '_.)'


def _tokenizeInfix(expression):
   """Split an infix expression string into a list of tokens.

   A '-' is folded into the numeric literal that immediately follows it only
   when it is in unary position (at the start of the expression or after an
   operator, '(' or ','), so "+-0.42" and "2*-3" yield negative literals
   while "10-2" and "x-1" are read as subtractions.  A unary minus in front
   of anything other than a numeric literal (e.g. "-x") is left as a bare
   '-' and is not supported.
   """
   raw = _INFIX_TOKEN_RE.findall(expression)
   tokens = []
   i = 0
   while i < len(raw):
      token = raw[i]
      unary = not tokens or not _endsOperand(tokens[-1])
      if (token == '-' and unary and i + 1 < len(raw)
            and _NUMERIC_TOKEN_RE.match(raw[i + 1])):
         tokens.append('-' + raw[i + 1])
         i += 2
      else:
         tokens.append(token)
         i += 1
   return tokens


# Parse and evaluate infix expressions by compiling them to postfix form.
class infixExpression(postfixExpression):
   """Infix expression, compiled to postfix tokens when constructed.

   Solve it with the inherited postfixExpression.solve().
   """

   def __init__(self, expression):
      self.tokens = dijkstraShuntingYard(_tokenizeInfix(expression))
127   
128
129# Regression Test:
130if __name__ == "__main__":
131   expression_str = " (10* (sqrt(16) + 3) - 20)*x/10^2+-0.42  "
132   print "Raw Expression:", expression_str
133   expression = infixExpression(expression_str)
134   print "Compiled (RPN) Expression:", expression.tokens
135   print "== Basic Tests: =="
136   print "  Solving for x=2, (Answer should be 0.58): ", expression.solve( x=2 )
137   print "  Solving for x=3, (Answer should be 1.08): ", expression.solve( x=3 )
138   print "== NumPy Tests: =="
139   print "  Solving for x=[2 3 4] (Answer should be [0.58 1.08 1.58]): ", expression.solve(x=Numeric.array((2,3,4)))
Note: See TracBrowser for help on using the repository browser.