source: qesdi/geoplot/trunk/lib/geoplot/grid_builder_base.py @ 5876

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_builder_base.py@5876
Revision 5876, 12.9 KB checked in by pnorton, 10 years ago (diff)

Added a faster layer drawer for the grid, a new style of colour bar and colour schemes which now hold the normalisation and colour map instances.

Line 
"""
grid_builder_base.py
====================

Holds the GridBuilderBase class. A GridBuilder is an object that knows how to
extract a lat-lon Grid object from a given cdms variable. This is an abstract
base class and can't be used directly.

"""

#python modules
import logging
import time

#third party modules
import numpy as N
import numpy.ma as MA

#internal modules
from geoplot.grid import Grid

#set the log
log = logging.getLogger(__name__)

class GridBuilderBase(object):
    """
    Implements the common GridBuilder functionality.

    This is an abstract class. Classes that inherit from it need to implement
    the _resizeVar, _buildGridBounds and _buildGridMidpoints methods. The
    default _buildGridValues calls ``cls._areAxisInOrderYX(cdmsVar)``, which
    this base class does not define, so subclasses must also provide it
    (or override _buildGridValues).
    """

    def __init__(self, cdmsVar):
        """
        Constructs the grid builder object and checks the variable's axes.

        Unexpected axes only produce log warnings; they do not raise.

        @param cdmsVar: the cdms variable that contains the grid data.
        @type cdmsVar: a cdms.variable object
        """
        self.cdmsVar = cdmsVar
        self._checkVariable()

    def buildGrid(self, xLimits=None, yLimits=None):
        """
        Builds a grid object from the data in the cdmsVar associated with the
        grid builder object.

        If xLimits or yLimits values are given then the resulting grid will
        contain the portion of the data in the cdms variable that falls within
        the limits given. The max/min values handed to the Grid are always
        taken from the full (unreduced) variable.

        @keyword xLimits: (optional) longitude limits of the resulting grid
        @type xLimits: a tuple of (MinLongitude, MaxLongitude)
        @keyword yLimits: (optional) latitude limits of the resulting grid
        @type yLimits: a tuple of (MinLatitude, MaxLatitude)
        @return: a grid built using the cdms variable data
        @rtype: geoplot.Grid
        """
        st = time.time()
        reducedVar = self._getResizedVar(xLimits, yLimits)
        log.debug("Reduced variable in %ss", time.time() - st)

        st = time.time()
        (gridBoundsX, gridBoundsY) = self._buildGridBounds(reducedVar)
        log.debug("built bounds in %ss", time.time() - st)

        st = time.time()
        (gridMidpointX, gridMidpointY) = self._buildGridMidpoints(reducedVar)
        log.debug("built midpoints in %ss", time.time() - st)

        st = time.time()
        gridValues = self._buildGridValues(reducedVar)
        log.debug("built values in %ss", time.time() - st)

        # Read the full variable's data once, rather than once for max() and
        # again for min().
        allValues = self.cdmsVar[:]

        return Grid(gridBoundsX, gridBoundsY, gridMidpointX, gridMidpointY,
                    gridValues, allValues.max(), allValues.min())

    def _getResizedVar(self, xLimits, yLimits):
        """
        Replaces any Nones in the limits with the max/min midpoint values, then
        returns the reduced variable from self._resizeVar().

        If both limits are None (or (None, None)) then the original
        self.cdmsVar is returned unchanged. If any part of the limits is None
        then it is replaced with the maximum or minimum midpoint value from
        self.cdmsVar.

        @param xLimits: (optional) longitude limits of the resulting grid
        @type xLimits: a sequence of (MinLongitude, MaxLongitude)
        @param yLimits: (optional) latitude limits of the resulting grid
        @type yLimits: a sequence of (MinLatitude, MaxLatitude)
        @return: A variable containing the subset of data
        @rtype: cdms.variable
        """
        # Normalise to tuples so that lists and numpy arrays compare safely
        # against (None, None). Comparing an array with == produces an
        # elementwise result, and using that in an ``if`` raises.
        xLimits = (None, None) if xLimits is None else tuple(xLimits)
        yLimits = (None, None) if yLimits is None else tuple(yLimits)

        if xLimits == (None, None) and yLimits == (None, None):
            return self.cdmsVar

        if None in xLimits or None in yLimits:
            xLimits, yLimits = self._replaceNoneInLimitsWithMaxMin(xLimits,
                                                                   yLimits)

        return self._resizeVar(xLimits, yLimits)

    def _resizeVar(self, xLimits, yLimits):
        """
        Returns a cdms variable that contains the subset of self.cdmsVar that
        is between the limits. Must be implemented by subclasses.

        @param xLimits: longitude limits of the resulting grid
        @type xLimits: a tuple of (MinLongitude, MaxLongitude)
        @param yLimits: latitude limits of the resulting grid
        @type yLimits: a tuple of (MinLatitude, MaxLatitude)
        @return: A variable containing the subset of data
        @rtype: cdms.variable
        """
        raise NotImplementedError

    def _buildGridBounds(self, cdmsVar):
        """
        Builds two 2d Numpy arrays of the x and y positions of the grid
        boundaries. Must be implemented by subclasses.

        Returns one array for the x bounds and one for the y bounds. These
        arrays contain the lower boundary for a given grid box.

        e.g. xBounds[x,y] will give the lower x boundary for the gridbox x,y.

        @param cdmsVar: the variable to extract the boundary data from
        @type cdmsVar: cdms.variable
        @return: the position data as (xPositions, yPositions)
        @rtype: 2 Numpy Arrays
        """
        raise NotImplementedError

    def _buildGridMidpoints(self, cdmsVar):
        """
        Builds two 2d Numpy arrays for the x and y midpoints of each grid box.
        This position corresponds to the location of the grid box value.
        Must be implemented by subclasses.

        Returns one array for the midpoint x position and another for the y
        position.

        e.g. xMidpoints[x,y] will give the midpoint of the gridbox x,y and the
        position of the associated measurement (value[x,y]).

        @param cdmsVar: the variable to extract the midpoint data from
        @type cdmsVar: cdms.variable
        @return: the midpoint data as (xMidpoints, yMidpoints)
        @rtype: 2 Numpy Arrays
        """
        raise NotImplementedError

    @staticmethod
    def _getBoundsFromAxis(axis):
        """
        Returns the 1D array of cell edges for the values in an axis.

        The edges are taken from the axis bounds when present. Otherwise they
        are generated from the axis values. In both cases the result has one
        more element than the axis has values.

        @param axis: the axis to get bounds for
        @type axis: cdms.axis.Axis
        @return: the axis bounds
        @rtype: numpy.ndarray
        """
        # Call getBounds() once and test identity: bounds is usually a numpy
        # array, and ``array == None`` is an elementwise comparison whose
        # truth value is ambiguous.
        bounds = axis.getBounds()
        if bounds is None:
            return GridBuilderBase._createBoundsFormList(axis.getValue())
        return GridBuilderBase._mergeBounds(bounds)

    @staticmethod
    def _createBoundsFormList(values):
        """
        Creates an array of boundaries from a given sequence of values.

        The bounding values are created by halving the distance between the
        first two items and adding this shift to every item. The first
        bounding value is the first item minus this shift. The spacing of the
        first two values is therefore assumed for the whole sequence.

        @param values: a sequence of at least two values to create bounds from
        @type values: a list of int or float (or a 1D numpy array)
        @return: the boundaries, one more than the number of values
        @rtype: numpy.ndarray
        """
        values = N.asarray(values)
        # 2.0 (not 2) so that integer spacing is not truncated under Python 2.
        shift = (values[1] - values[0]) / 2.0
        return N.concatenate(([values[0] - shift], values + shift))

    @staticmethod
    def _mergeBounds(bounds):
        """
        Folds a bounds array of shape (x, 2) into a 1D array of shape (x + 1,).

        We assume that grid boxes are contiguous, i.e. the upper edge of box i
        is the same as the lower edge of box i + 1.

        @param bounds: the (lower, upper) bounds of each cell
        @type bounds: numpy array of shape (x, 2)
        @return: the merged cell edges
        @rtype: numpy array of shape (x + 1,)
        """
        # Lower edge of every box, followed by the upper edge of the last box.
        return N.append(bounds[:, 0], bounds[-1, 1])

    @staticmethod
    def _fillMissingLimitsFromArray(limits, array):
        """
        Returns a (min, max) tuple where any None in limits is replaced by
        array.min() / array.max() respectively.

        @param limits: a (min, max) pair, either of which may be None
        @param array: the values to take the replacement min/max from
        @rtype: tuple
        """
        low, high = limits[0], limits[1]

        if low is None:
            low = array.min()

        if high is None:
            high = array.max()

        return (low, high)

    def _replaceNoneInLimitsWithMaxMin(self, xLimits, yLimits):
        """
        Fills any None entries in the x/y limits with the min/max of the
        midpoints of self.cdmsVar.

        @param xLimits: (min, max) longitude limits, either may be None
        @param yLimits: (min, max) latitude limits, either may be None
        @return: the filled limits as (xLimits, yLimits)
        @rtype: tuple of two tuples
        """
        newXLimits, newYLimits = xLimits, yLimits
        xMidpoints, yMidpoints = self._buildGridMidpoints(self.cdmsVar)

        if None in xLimits:
            newXLimits = GridBuilderBase._fillMissingLimitsFromArray(
                xLimits, xMidpoints)

        if None in yLimits:
            newYLimits = GridBuilderBase._fillMissingLimitsFromArray(
                yLimits, yMidpoints)

        return (newXLimits, newYLimits)

    def _checkVariable(self):
        """
        Checks the cdms variable to make sure that it is suitable for plotting.
        """
        self._checkVariableAxis()

    def _checkVariableAxis(self):
        """
        Checks the axes on the variable. If axes that aren't expected are found
        (more than two axes, a time axis or a level axis), a warning is written
        to the log. Nothing is raised and the variable is not modified.
        """
        nAxes = len(self.cdmsVar.getAxisList())
        if nAxes > 2:
            log.warning('cdms variable contains %s axes.', nAxes)

        if self.cdmsVar.getTime() is not None:
            log.warning('cdms variable contains a time axes.')

        if self.cdmsVar.getLevel() is not None:
            log.warning('cdms variable contains a level axes.')

    def _buildGridValues(self, cdmsVar):
        """
        Builds a masked numpy array of values for each of the grid boxes.

        e.g. values[y,x] is the value for grid box y, x.

        Values equal to the variable's missing value (or 1e20 if the variable
        defines none) are masked.

        @param cdmsVar: the variable to extract the values from
        @type cdmsVar: cdms.variable
        @return: the grid values indexed as [y, x]
        @rtype: numpy.ma.MaskedArray
        """
        data = cdmsVar.getValue()
        missing = cdmsVar.getMissing()

        # The extracted data is indexed data[xIndex, yIndex] if the axes are
        # in x/y order, and data[yIndex, xIndex] if they are in y/x order.
        # Callers (e.g. imshow) need data[y, x], so swap when not already y/x.
        if not self.__class__._areAxisInOrderYX(cdmsVar):
            log.warning("axis are in order x,y. Swapping them.")
            data = data.swapaxes(0, 1)

        if missing is None:
            missing = 1e20

        return MA.masked_values(data, missing)

def _logAxis(axis):
    """
    Writes a one-line summary of an axis (kind, id, number of values and
    first/last value) to log.debug.

    @param axis: the axis to describe
    @type axis: cdms.axis.Axis
    """
    if axis.isLongitude():
        kind = 'Longitude Axis'
    elif axis.isLatitude():
        kind = 'Latitude Axis'
    elif axis.isLevel():
        kind = 'Level Axis'
    elif axis.isTime():
        kind = 'Time Axis'
    else:
        kind = 'Unknown axis'

    # Fetch the values once instead of calling getValue() three times.
    values = axis.getValue()

    # Avoid an IndexError on values[0] for an empty axis.
    if len(values) == 0:
        log.debug('%s %s(0) : empty', kind, axis.id)
        return

    log.debug('%s %s(%s) :%s - %s', kind, axis.id, len(values),
              values[0], values[-1])
Note: See TracBrowser for help on using the repository browser.