source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 5636

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@5636
Revision 5636, 16.8 KB checked in by pnorton, 11 years ago (diff)

Fixed a problem with imshow drawing a grid when there was no data to draw. Also fixed some of the tests.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib
19import matplotlib.colors
20import matplotlib.cm
21import matplotlib.collections
22import numpy as N
23import numpy.ma as MA
24import geoplot.mpl_imports
25
26#internal modules
27from geoplot.array_util import *
28from geoplot.grid import Grid
29from geoplot.grid_builder_base import GridBuilderBase
30import geoplot.utils as geoplot_utils
31
32#setup the logging
33log = logging.getLogger(__name__)
34
35class GridDrawer(object):
36    """
37    Responsible for knwoing how to draw a Grid object onto a given matplotlib
38    axes
39    """
40
41    def __init__(self, drawValues=False, drawValuesFile=False, 
42                 valueFormat="%.2f", showGridLines=False,
43                 outline=False):
44        """
45        Constructs a GridDrawer object             
46        """
47        self.drawValues = drawValues
48        self.drawValuesFile = drawValuesFile
49        self.valueFormat = valueFormat
50        self.showGridLines = showGridLines
51        self.outline = outline
52
53    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
54             cmap=None, fontSize='medium' ):
55        """
56        Draws the grid given on the axis.
57       
58        @param axes: the axes the grid will be drawn on
59        @type axes: matplotlib.axes
60        @param grid: the grid to be drawn
61        @type grid: geoplot.grid
62        @keyword limits: the limits of the drawing (in lat lon)
63        @type limits: tuple
64        @keyword basemap: the basemap instance to scale the grid values to be
65            drawn on the axis.
66        """
67
68        if cmap == None:
69            cmap = matplotlib.cm.get_cmap()
70            cmap.set_bad("w")   
71       
72        if norm == None:
73            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
74
75        xLimits, yLimits = self._transformLimits(limits, basemap)
76       
77        # check that the grid being drawn has data inside the limits, if there
78        # is no data in the limits there is no point drawing anything.
79        # can only check the grid is in the limits if they are provided
80        if limits != None:
81            if not self._isGridDataWithinLimits(grid, xLimits, yLimits):
82                return
83               
84        grid_mapUnits = grid.transform(basemap)
85
86        self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm)
87       
88#        log.debug("self.drawValues = %s" % (self.drawValues,))
89#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
90
91        if self.drawValues == True:
92            log.debug('drawing grid values')
93           
94            #draw the vlaues onto the grid
95            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
96       
97
98    def _transformLimits(self, limits, basemap):
99        if limits == None:
100            return (None, None)
101        elif basemap == None:
102            return (limits[0], limits[1])
103        else:
104            return basemap(limits[0], limits[1])
105
106    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm):
107        """
108        Draws the grid boxes onto the axes.
109        """
110
111        gridColour = 'k'
112        gridLineWidth = 0.25
113
114        kwargs = {'cmap':cmap, 'norm':norm}
115       
116        if self.outline == True :
117            # this will draw an outline on all grid cells that don't contain a
118            # missing value
119            kwargs['edgecolors'] = gridColour
120        else:
121            # needs to be a string rather than None or the rc default will be used
122            kwargs['edgecolors'] = 'None'
123
124
125        if geoplot.mpl_imports.oldMatplotlib == False:
126            # in the newer version of matplotlib you can set the linewidth +
127            # linestyle on the pcolor() function
128            kwargs['linewidth'] = gridLineWidth
129            kwargs['linestyle'] = 'solid'
130            kwargs['antialiased'] = True
131
132        #check the values aren't all masked
133       
134        # Its possible that this test becomes time consuming for large arrays,
135        # it is pssible that using sum might be faster, but this will only work
136        # if True + True = 2.
137        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
138       
139        # for now this is so quick it doesn't matter.
140        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
141
142       
143        #if all the values are masked make sure the outline is drawn
144        if self.showGridLines == True:
145            st = time.time()
146            # this will draw all the grid cells even if their contents is missing
147            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
148                           color=gridColour, linewidth=gridLineWidth)
149            log.debug("drawn grid lines in = %s" % (time.time() - st,))
150       
151        if valuesFound == True:
152           
153            st = time.time()
154           
155            log.debug("grid_mapUnits.boundsX.shape = %s" % (grid_mapUnits.boundsX.shape,))
156            log.debug("grid_mapUnits.boundsY.shape = %s" % (grid_mapUnits.boundsY.shape,))
157            log.debug("grid_mapUnits.values.shape = %s" % (grid_mapUnits.values.shape,))
158           
159            values = grid_mapUnits.values
160           
161            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
162           
163            boundsShape = grid_mapUnits.boundsX.shape
164           
165            #try to fix the values if they are not the expected shape
166            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
167               
168                #check if the axis were in the opposite order
169                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
170                    values = N.transpose(values)           
171
172            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
173            if self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY):
174               
175                self._drawImshow(axes, 
176                                 grid_mapUnits.boundsX, 
177                                 grid_mapUnits.boundsY, 
178                                 values, 
179                                 kwargs)
180                log.debug("drawn imshow in %s" % (time.time() - st ,))
181               
182            else:
183                self._drawPcolormesh(axes, 
184                                     grid_mapUnits.boundsX, 
185                                     grid_mapUnits.boundsY, 
186                                     values, 
187                                     kwargs)
188                log.debug("drawn mesh in %s" % (time.time() - st ,))
189           
190           
191
192
193    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
194        res = axes.pcolormesh(X, Y, Z, **kwargs)
195        return res
196   
197    def _drawImshow(self, axes, X, Y, Z, kwargs):
198        yStart = Y[0,0]
199        yEnd = Y[-1,-1]
200       
201        log.debug("yStart = %s" % (yStart,))
202        log.debug("yEnd = %s" % (yEnd,))
203       
204        if yStart < yEnd:
205            extent=(X.min(), X.max(),Y.max(), Y.min())               
206        else:
207            extent=(X.min(), X.max(),Y.min(), Y.max())
208       
209        kwargs.pop('edgecolors')
210        kwargs.pop('linestyle')
211        kwargs.pop('linewidth')
212        kwargs.pop('antialiased')
213       
214        log.debug("extent = %s" % (extent,))
215        log.debug("Z.shape = %s" % (Z.shape,))
216       
217        im = axes.imshow(Z, extent=extent,
218             interpolation='nearest',
219             **kwargs
220        )
221                   
222        axes.set_aspect('auto')
223       
224    def _areBoundsParralel(self, X, Y ):
225        """
226        checks if all the bound are parallel, this is a test to see if using
227        imshow is possible.
228       
229        Checks the equality to 5 significant figures.
230        """
231       
232        st = time.time()
233        equal = True
234       
235        for i in range(1, X.shape[0]):
236           
237            equal = N.allclose(X[0], X[i])
238           
239            if not equal:
240                break
241
242        xEq = equal
243       
244        if equal:
245           
246            for i in range(1, Y.shape[1]):
247               
248                equal = N.allclose( Y[:,0], Y[:,i])
249               
250                if not equal:
251                    break
252
253        yEq = equal
254        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
255       
256        return equal
257           
258
259    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
260        """
261        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
262        made up of quadrilaterals.
263        """
264        Ny, Nx = X.shape
265   
266        points=[]
267        # create a list of points to build up the line collection
268        # assuming the grid is filled with quadrelaterals this should
269        # draw it
270        for y in range(Ny):
271            for x in range(Nx):
272                x1 = X[y,x]
273                y1 = Y[y,x]
274   
275                # if this isn't the last intersection in the x direction, draw a line
276                # between this intersection and the next one
277                if x+1 < Nx:
278                    x2 = X[y,x+1]
279                    y2 = Y[y,x+1]
280                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
281
282                # if this isn't the last intersection in the y direction, draw a line
283                # between this intersection and the next one   
284                if y+1 < Ny:
285                    x3 = X[y+1,x]
286                    y3 = Y[y+1,x]
287                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
288   
289
290        line_segments = matplotlib.collections.LineCollection(points,
291                                       linewidths    = (linewidth,),
292                                       colors    = (color),
293                                       linestyle = 'solid')
294       
295        axes.add_collection(line_segments)
296
297
298    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
299        """
300        Draw the values of each grid box onto the axes.
301        """
302       
303        # Parse drawValuesFile if specified:
304        if self.drawValuesFile != False:
305            drawValuesList = self._parseDrawValuesFile()
306        else:
307            drawValuesList = False
308
309        ni, nj = N.shape(grid_mapUnits.midpointsX)
310       
311        if xLimits != None:
312            xRangeMaximum = max(xLimits)
313            xRangeMinumum = min(xLimits)
314           
315        if yLimits != None:
316            yRangeMaximum = max(yLimits)
317            yRangeMinumum = min(yLimits)
318
319        for j in range(nj):
320            for i in range(ni):
321
322                # Check if we need to draw this value
323                if drawValuesList != False:
324                    if [i, j] not in drawValuesList: continue
325
326                x = grid_mapUnits.midpointsX[i][j]
327                y = grid_mapUnits.midpointsY[i][j]
328
329                v = grid_mapUnits.values[i,j]
330
331                #skip masked array values:
332                if v.__class__ == MA.MaskedArray:
333                    continue
334
335                # Avoid missing values with float conversion test
336                try:
337                    v = float(v)
338                except:
339                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
340                    continue
341
342                #get the bounds of the box (in map units), around the current position
343                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
344
345                xTop = max(xBounds)
346                yTop = max(yBounds)
347                xBottom = min(xBounds)
348                yBottom = min(yBounds)
349               
350                if xLimits != None:
351                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
352                        continue
353
354                if yLimits != None:
355                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
356                        continue
357
358                axes.text(float(x), float(y), self.valueFormat % v,
359                          ha='center',
360                          va='center',
361                          family='sans-serif',
362                          size=8)
363
364    def _getBounds(self, i, j, xMesh, yMesh):
365        """
366        Gets the points of each of the four corners for a given position in the mesh
367
368        @param i: the x index of the vlaue in the mesh
369        @type i: int
370        @param j: the x index of the vlaue in the mesh
371        @type j: int
372        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
373        @type xMesh: an instance of the numpy.ndarray class
374        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
375        @type yMesh: an instance of the numpy.ndarray class
376        """
377
378        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
379
380
381        f = lambda array: lambda c: array[c]
382        xBounds = map(f(xMesh), coords)
383        yBounds = map(f(yMesh), coords)
384
385        return (xBounds, yBounds)
386
387    def _parseDrawValuesFile(self):
388        """
389        Attempts to parse the CSV file containing index values of y,x per line.
390        Validates file or raises error.
391        On successful validation returns a list of (y_index, x_index) pairs.
392
393        """
394
395        # Safely read file
396        try:
397            valuesFile = open(self.drawValuesFile)
398            lines = valuesFile.readlines()
399            valuesFile.close()
400        except:
401            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
402
403        # Safely parse file contents
404        linePattern = re.compile("^(\d+),(\d+)$")
405        ln = 0
406        cellList = []
407
408        for line in lines:
409            ln += 1
410            line = line.strip()
411            match = linePattern.match(line)
412            if not match:
413                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
414            indices = [int(i) for i in match.groups()]
415            cellList.append(indices)
416
417        return cellList
418   
419    def _isGridDataWithinLimits(self, grid, xLimits, yLimits):
420       
421        xRange = (grid.boundsX.min(), grid.boundsX.max())
422        yRange = (grid.boundsY.min(), grid.boundsY.max())
423       
424        log.debug("xRange = %s" % (xRange,))
425        log.debug("yRange = %s" % (yRange,))
426       
427        isInLimits = geoplot_utils.isRangeInLimits(xRange, xLimits) and \
428                     geoplot_utils.isRangeInLimits(yRange, yLimits)
429       
430        return isInLimits
431   
432if __name__ == '__main__':
433    import geoplot.log_util as log_util
434    log_util.setupGeoplotConsoleHandler(log)
435
436    import pkg_resources, os
437    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
438
439    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
440
441    from numpy import meshgrid
442    from cdms_utils.cdms_compat import N, MA
443
444    x = [-6.,-4.,-2., 0., 2.]
445    xMid = [-5, -3, -1, 1]
446    xlim = len(x) -1
447    y = [54.,55.,56.,58.]
448    yMid = [54.5, 55.5, 57]
449    ylim = len(y) - 1
450    X,Y = meshgrid(x,y)
451   
452    X[0,0] = -7
453    Y[0,0] = 53
454    X[ylim, xlim] = 3
455    Y[ylim, xlim] = 59
456       
457    XMid, YMid = meshgrid(xMid, yMid)
458
459    Z=[[0]*xlim for i in range(ylim)]
460    for i in range(0 , xlim*ylim):
461        vX= i%xlim
462        vY= (i - i%xlim) / xlim
463        Z[vY][vX] = i
464    #Z = N.array(Z)
465    Z = MA.masked_values(Z, 2)
466    Z = MA.masked_values(Z, 6)
467       
468    grid = Grid(X,Y,XMid, YMid, Z)
469
470    from geoplot.mpl_imports import basemap
471    from matplotlib.backends.backend_agg import FigureCanvasAgg
472    from matplotlib.figure import Figure
473    from matplotlib import cm
474
475    figsize=(600 / 80, 800 / 80)
476    fig = Figure(figsize=figsize, dpi=80)
477    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
478
479
480    bm = basemap.Basemap(llcrnrlon=xLimits[0],
481                       llcrnrlat=yLimits[0],
482                       urcrnrlon=xLimits[1],
483                       urcrnrlat=yLimits[1],
484                       resolution='h',
485                       suppress_ticks=False)
486    cmap = cm.winter
487    cmap.set_bad("w")   
488    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
489                        outline=False)
490       
491    sm = drawer.draw(axes, grid, basemap=bm)
492   
493    bm.drawcoastlines(ax = axes)
494    bm.drawrivers(ax = axes, color='b')
495    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
496    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
497   
498   
499    cb = fig.colorbar(sm, orientation='vertical')
500    print "tick position", cb.ax.yaxis.get_ticks_position()
501    cb.ax.yaxis.set_ticks_position('default')
502    print "tick position", cb.ax.yaxis.get_ticks_position()   
503
504
505    canvas = FigureCanvasAgg(fig)
506    filename = "grid_drawer_test.png"
507   
508    fullPath = os.path.join(outputsDir, filename)
509    canvas.print_figure(fullPath)
510    log.debug("Wrote " + fullPath)
511
Note: See TracBrowser for help on using the repository browser.