# source: qesdi/geoplot/trunk/lib/geoplot/grid_builder_base.py @ 6308
#
# Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_builder_base.py@6308
# Revision 6308, 15.1 KB checked in by pnorton, 10 years ago (diff)
#
# Applied the fix to _mergeBounds so that it will cope with reversed bounds.

"""
grid_builder_base.py
====================

Holds the GridBuilderBase class. A GridBuilder is an object that knows how to
extract a lat-lon Grid object from a given cdms variable. This is an abstract
base class and can't be used directly.

"""

#python modules
import logging
import time

#third party modules
import numpy as N
import numpy.ma as MA

#internal modules
from geoplot.grid import Grid

#set the log
log = logging.getLogger(__name__)
25
class GridBuilderBase(object):
    """
    Implements the common GridBuilder functionality.

    This is an abstract class. Any class that inherits from this class needs
    to implement the _resizeVar, _buildGridBounds and _buildGridMidpoints
    methods, and must provide _areAxisInOrderYX (as a static or class method,
    since it is called on the class with a single argument) for the default
    _buildGridValues implementation to work.
    """

    def __init__(self, cdmsVar):
        """
        Constructs the grid builder object and checks the variable's axes.

        @param cdmsVar: the cdms variable that contains the grid data.
        @type cdmsVar: a cdms.variable object
        """
        self.cdmsVar = cdmsVar
        self._checkVariable()

    def buildGrid(self, xLimits=None, yLimits=None):
        """
        Builds a grid object from the data in the cdms variable associated
        with this grid builder object.

        If xLimits or yLimits values are given then the resulting grid will
        contain the portion of the data in the cdms variable that falls within
        the limits given. The varMin/varMax handed to the Grid are always
        computed from the full, un-reduced variable rather than the subset.

        @keyword xLimits: (optional) longitude limits of the resulting grid
        @type xLimits: a tuple of (MinLongitude, MaxLongitude)
        @keyword yLimits: (optional) latitude limits of the resulting grid
        @type yLimits: a tuple of (MinLatitude, MaxLatitude)
        @return: a grid built using the cdms variable data
        @rtype: geoplot.Grid
        """
        varMin, varMax = self._getVarMinAndMax()
        infValueFound, varMax = self._checkVarMaxForInf(varMax)

        start = time.time()
        reducedVar = self._getResizedVar(xLimits, yLimits)
        log.debug("Reduced variable in %ss", time.time() - start)

        start = time.time()
        gridBoundsX, gridBoundsY = self._buildGridBounds(reducedVar)
        log.debug("built bounds in %ss", time.time() - start)

        start = time.time()
        gridMidpointX, gridMidpointY = self._buildGridMidpoints(reducedVar)
        log.debug("built midpoints in %ss", time.time() - start)

        start = time.time()
        gridValues = self._buildGridValues(reducedVar)
        log.debug("built values in %ss", time.time() - start)

        if infValueFound:
            # an inf was found in the full variable, so the reduced one may
            # contain one too; mask any that are present
            gridValues = self._maskInfValsInVar(gridValues)

        return Grid(gridBoundsX, gridBoundsY, gridMidpointX, gridMidpointY,
                    gridValues, varMax, varMin)

    def _checkVarMaxForInf(self, varMax):
        """
        Checks whether varMax is +inf and, if so, recomputes the maximum of
        the full variable with the inf values masked out.

        Values that carry a 'mask' attribute are not compared against inf.

        @param varMax: the maximum previously computed from self.cdmsVar
        @return: (infValueFound, varMax) where varMax is the non-infinite
                 maximum when an inf was found, otherwise the value passed in
        @rtype: (bool, number)
        """
        infValueFound = False

        if not hasattr(varMax, 'mask') and varMax == N.inf:
            infValueFound = True
            # take the max from the full (non-reduced) variable
            varMax = self._maskInfValsInVar(self.cdmsVar).max()

        return infValueFound, varMax

    def _getResizedVar(self, xLimits, yLimits):
        """
        Returns the subset of self.cdmsVar that lies within the given limits.

        A limits argument of None is treated as (None, None). Limits may be
        given as any 2-item sequence (tuple, list, array). If both limits are
        (None, None) the original self.cdmsVar is returned unchanged.
        Otherwise any None is replaced with the minimum or maximum midpoint
        of self.cdmsVar and self._resizeVar does the resizing.

        @param xLimits: (optional) longitude limits of the resulting grid
        @type xLimits: a tuple of (MinLongitude, MaxLongitude)
        @param yLimits: (optional) latitude limits of the resulting grid
        @type yLimits: a tuple of (MinLatitude, MaxLatitude)
        @return: A variable containing the subset of data
        @rtype: cdms.variable
        """
        # normalise to tuples so that e.g. [None, None] is recognised as
        # "no limit" (a list never compares equal to a tuple)
        xLimits = (None, None) if xLimits is None else tuple(xLimits)
        yLimits = (None, None) if yLimits is None else tuple(yLimits)

        if xLimits == (None, None) and yLimits == (None, None):
            return self.cdmsVar

        if None in xLimits or None in yLimits:
            xLimits, yLimits = self._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)

        return self._resizeVar(xLimits, yLimits)

    def _resizeVar(self, xLimits, yLimits):
        """
        Returns a cdms variable that contains the subset of self.cdmsVar that
        is between the limits. Must be implemented by subclasses.

        @param xLimits: longitude limits of the resulting grid
        @type xLimits: a tuple of (MinLongitude, MaxLongitude)
        @param yLimits: latitude limits of the resulting grid
        @type yLimits: a tuple of (MinLatitude, MaxLatitude)
        @return: A variable containing the subset of data
        @rtype: cdms.variable
        """
        raise NotImplementedError

    def _buildGridBounds(self, cdmsVar):
        """
        Builds two 2d Numpy arrays of the x and y positions of the grid
        boundaries. Must be implemented by subclasses.

        Returns one array for the x bounds and one for the y bounds. These
        arrays contain the lower boundary for a given grid box,
        e.g. xBounds[x,y] will give the lower x boundary for the gridbox x,y.

        @param cdmsVar: the variable to extract the boundary data from
        @type cdmsVar: cdms.variable
        @return: the position data as (xPositions, yPositions)
        @rtype: 2 Numpy Arrays
        """
        raise NotImplementedError

    def _buildGridMidpoints(self, cdmsVar):
        """
        Builds two 2d Numpy arrays for the x and y midpoints of each grid
        box; this position corresponds to the location of the grid box value.
        Must be implemented by subclasses.

        e.g. xMidpoints[x,y] will give the midpoint of the gridbox x,y and the
        position of the associated measurement (value[x,y]).

        @param cdmsVar: the variable to extract the midpoint data from
        @type cdmsVar: cdms.variable
        @return: the midpoint data as (xMidpoints, yMidpoints)
        @rtype: 2 Numpy Arrays
        """
        raise NotImplementedError

    @staticmethod
    def _areAxisInOrderYX(cdmsVar):
        """
        Returns a true value when cdmsVar's data is laid out as [y, x]; a
        false value means [x, y] and _buildGridValues will swap the axes.
        Must be provided by subclasses.

        @param cdmsVar: the variable whose axis order is checked
        @type cdmsVar: cdms.variable
        @rtype: bool
        """
        raise NotImplementedError

    @staticmethod
    def _getBoundsFromAxis(axis):
        """
        Returns the 1D array of cell edges for the values in an axis.

        If the axis carries bounds (axis.getBounds() is not None) they are
        folded into edges with _mergeBounds; otherwise edges are generated
        from the axis values with _createBoundsFormList. For contiguous cells
        the result has one more element than the axis has values.

        @param axis: the axis to get bounds for
        @type axis: cdms.axis.Axis
        @return: the axis bounds
        @rtype: numpy.ndarray
        """
        # identity check: '== None' on a numpy bounds array compares
        # element-wise and cannot be used in an 'if'
        bounds = axis.getBounds()
        if bounds is None:
            return GridBuilderBase._createBoundsFormList(axis.getValue())
        return GridBuilderBase._mergeBounds(bounds)

    @staticmethod
    def _createBoundsFormList(values):
        """
        Creates an array of boundaries from a given list of values.

        The bounding values are created by halving the distance between the
        first two items in the list and adding this shift to every item. The
        first bounding value is the first item minus this shift. This assumes
        the values are evenly spaced and requires at least two values.

        @param values: a list of values to create bounds from
        @type values: a list of int or float
        @return: len(values) + 1 boundary values
        @rtype: numpy.ndarray
        """
        values = N.asarray(values)
        # 2.0 (not 2) so integer axes are not truncated by integer division
        shift = (values[1] - values[0]) / 2.0
        return N.concatenate(([values[0] - shift], values + shift))

    @staticmethod
    def _mergeBounds(bounds):
        """
        Folds a bounds array of shape (x, 2) into a 1D array of shape (x + 1,).

        We assume that grid boxes are contiguous, i.e. the right-hand edge of
        grid box (x, y) is the same as the left-hand edge of grid box
        (x + 1, y), and similarly in y.

        There are 4 different types of bounds that might be received:

            1. Descending with each bound low-to-high
            2. Descending with each bound high-to-low
            3. Ascending with each bound low-to-high
            4. Ascending with each bound high-to-low

        To cater for these we take all unique values in the bounds (fine given
        the contiguity assumption), which are sorted ascending, and reverse
        them when the bounds are descending.

        @param bounds: the (x, 2) array of per-cell lower/upper bounds
        @return: the merged edges, ordered in the same direction as the cells
        @rtype: numpy.ndarray
        """
        # direction of the axis, judged from the first bound of the first
        # and last cells
        descending = bounds[0, 0] > bounds[-1, 0]

        # N.unique returns the distinct values already sorted ascending
        merged = N.unique(bounds)

        if len(merged) != bounds.shape[0] + 1:
            log.error("Length of merged bounds ('%s') does not equal length of bounds + 1 ('%s').",
                      len(merged), bounds.shape[0] + 1)

        if descending:
            merged = merged[::-1]

        return merged

    @staticmethod
    def _fillMissingLimitsFromArray(limits, array):
        """
        Returns a 2-tuple copy of limits with a None lower limit replaced by
        array.min() and a None upper limit replaced by array.max().

        @param limits: (lower, upper); either may be None
        @param array: the values to take the replacement min/max from
        @rtype: tuple
        """
        lower, upper = limits[0], limits[1]

        if lower is None:
            lower = array.min()

        if upper is None:
            upper = array.max()

        return (lower, upper)

    def _replaceNoneInLimitsWithMaxMin(self, xLimits, yLimits):
        """
        Replaces any None in the limits with the min/max midpoint of
        self.cdmsVar along the same direction.

        @return: (xLimits, yLimits) with no None values
        """
        xMidpoints, yMidpoints = self._buildGridMidpoints(self.cdmsVar)

        if None in xLimits:
            xLimits = GridBuilderBase._fillMissingLimitsFromArray(xLimits, xMidpoints)

        if None in yLimits:
            yLimits = GridBuilderBase._fillMissingLimitsFromArray(yLimits, yMidpoints)

        return (xLimits, yLimits)

    def _checkVariable(self):
        """
        Checks the cdms variable to make sure that it is suitable for plotting.
        """
        self._checkVariableAxis()

    def _checkVariableAxis(self):
        """
        Checks the axes on self.cdmsVar; if axes that aren't expected are
        found then a warning is written to the log. The variable itself is
        not modified.
        """
        axes = self.cdmsVar.getAxisList()
        if len(axes) > 2:
            log.warning('cdms variable contains %s axes.', len(axes))

        if self.cdmsVar.getTime() is not None:
            log.warning('cdms variable contains a time axes.')

        if self.cdmsVar.getLevel() is not None:
            log.warning('cdms variable contains a level axes.')

    def _buildGridValues(self, cdmsVar):
        """
        Builds a masked numpy array of values for each of the grid boxes,
        e.g. values[y,x] is the value for grid box y, x.

        Values equal to the variable's missing value (1e20 when the variable
        reports none) are masked.

        @param cdmsVar: the variable to take the values from
        @type cdmsVar: cdms.variable
        @rtype: numpy.ma.MaskedArray
        """
        data = cdmsVar.getValue()
        missing = cdmsVar.getMissing()

        # the extracted data is data[xIndex, yIndex] when the axes are in
        # x/y order and data[yIndex, xIndex] when in y/x order; the imshow
        # call needs data[y, x], so swap when the axes are x/y
        if not self.__class__._areAxisInOrderYX(cdmsVar):
            log.warning("axis are in order x,y. Swapping them.")
            data = data.swapaxes(0, 1)

        if missing is None:
            missing = 1e20

        return MA.masked_values(data, missing)

    def _getVarMinAndMax(self):
        """
        Returns (min, max) of the full self.cdmsVar, logging the time taken.
        """
        start = time.time()
        varMax = self.cdmsVar.max()
        varMin = self.cdmsVar.min()

        log.debug("got min (%s) and max (%s) in %ss", varMin, varMax, time.time() - start)
        return varMin, varMax

    def _maskInfValsInVar(self, var):
        """
        Returns var as a masked array with every +inf value masked.
        """
        return MA.masked_equal(var, N.inf)
406   
def _logAxis(axis):
    """
    Writes a one-line description of an axis to log.debug.

    The line holds the axis kind, its id, its length and its first and last
    values, e.g. "Latitude Axis lat(3) :-10.0 - 10.0". An axis with no values
    is reported as empty instead of raising IndexError.

    @param axis: the axis to describe
    @type axis: cdms.axis.Axis
    """
    if axis.isLongitude():
        kind = 'Longitude Axis'
    elif axis.isLatitude():
        kind = 'Latitude Axis'
    elif axis.isLevel():
        kind = 'Level Axis'
    elif axis.isTime():
        kind = 'Time Axis'
    else:
        kind = 'Unknown axis'

    # fetch the values once rather than once per use
    values = axis.getValue()
    if len(values) == 0:
        # no first/last value to report
        log.debug("%s %s(0) : empty", kind, axis.id)
        return

    log.debug("%s %s(%s) :%s - %s", kind, axis.id, len(values), values[0], values[-1])
# Note: See TracBrowser for help on using the repository browser.