source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 5612

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@5612
Revision 5612, 16.3 KB checked in by pnorton, 12 years ago (diff)

Have added some code that allows the grid drawer to decide if it should use imshow or pcolormesh to render the grid. This was done to get arround the problem of the grid cells having outlines when pcolormesh is used.

Also implemented a map function that chooses which basemap resolution to use if one isn't set.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib
19import matplotlib.colors
20import matplotlib.cm
21import matplotlib.collections
22import numpy as N
23import numpy.ma as MA
24import geoplot.mpl_imports
25
26#internal modules
27from geoplot.array_util import *
28from geoplot.grid import Grid
29from geoplot.grid_builder_base import GridBuilderBase
30import geoplot.utils as geoplot_utils
31
32#setup the logging
33log = logging.getLogger(__name__)
34
35class GridDrawer(object):
36    """
37    Responsible for knwoing how to draw a Grid object onto a given matplotlib
38    axes
39    """
40
41    def __init__(self, drawValues=False, drawValuesFile=False, 
42                 valueFormat="%.2f", showGridLines=False,
43                 outline=False):
44        """
45        Constructs a GridDrawer object             
46        """
47        self.drawValues = drawValues
48        self.drawValuesFile = drawValuesFile
49        self.valueFormat = valueFormat
50        self.showGridLines = showGridLines
51        self.outline = outline
52
53    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
54             cmap=None, fontSize='medium' ):
55        """
56        Draws the grid given on the axis.
57       
58        @param axes: the axes the grid will be drawn on
59        @type axes: matplotlib.axes
60        @param grid: the grid to be drawn
61        @type grid: geoplot.grid
62        @keyword limits: the limits of the drawing (in lat lon)
63        @type limits: tuple
64        @keyword basemap: the basemap instance to scale the grid values to be
65            drawn on the axis.
66        """
67
68        if cmap == None:
69            cmap = matplotlib.cm.get_cmap()
70            cmap.set_bad("w")   
71       
72        if norm == None:
73            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
74
75#        log.info("Drawing Grid")
76
77        xLimits, yLimits = self._transformLimits(limits, basemap)
78       
79        grid_mapUnits = grid.transform(basemap)
80       
81       
82       
83       
84               
85#        log.debug("xLimits = %s" % (xLimits,))
86#        log.debug("yLimits = %s" % (yLimits,))
87
88        #draw the grid return the quadmesh for use
89        sm = self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm)
90       
91       
92       
93#        log.debug("self.drawValues = %s" % (self.drawValues,))
94#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
95
96        if self.drawValues == True:
97            log.debug('drawing grid values')
98
99            #draw the vlaues onto the grid
100            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
101       
102
103    def _transformLimits(self, limits, basemap):
104        if limits == None:
105            return (None, None)
106        elif basemap == None:
107            return (limits[0], limits[1])
108        else:
109            return basemap(limits[0], limits[1])
110
111    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm):
112        """
113        Draws the grid boxes onto the axes.
114        """
115
116        gridColour = 'k'
117        gridLineWidth = 0.25
118
119        kwargs = {'cmap':cmap, 'norm':norm}
120       
121        if self.outline == True :
122            # this will draw an outline on all grid cells that don't contain a
123            # missing value
124            kwargs['edgecolors'] = gridColour
125        else:
126            # needs to be a string rather than None or the rc default will be used
127            kwargs['edgecolors'] = 'None'
128
129
130        if geoplot.mpl_imports.oldMatplotlib == False:
131            # in the newer version of matplotlib you can set the linewidth +
132            # linestyle on the pcolor() function
133            kwargs['linewidth'] = gridLineWidth
134            kwargs['linestyle'] = 'solid'
135            kwargs['antialiased'] = True
136
137        #check the values aren't all masked
138       
139        # Its possible that this test becomes time consuming for large arrays,
140        # it is pssible that using sum might be faster, but this will only work
141        # if True + True = 2.
142        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
143       
144        # for now this is so quick it doesn't matter.
145        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
146
147       
148        #if all the values are masked make sure the outline is drawn
149        if self.showGridLines == True:
150            st = time.time()
151            # this will draw all the grid cells even if their contents is missing
152            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
153                           color=gridColour, linewidth=gridLineWidth)
154            log.debug("drawn grid lines in = %s" % (time.time() - st,))
155       
156        if valuesFound == True:
157           
158            st = time.time()
159           
160            log.debug("grid_mapUnits.boundsX.shape = %s" % (grid_mapUnits.boundsX.shape,))
161            log.debug("grid_mapUnits.boundsY.shape = %s" % (grid_mapUnits.boundsY.shape,))
162            log.debug("grid_mapUnits.values.shape = %s" % (grid_mapUnits.values.shape,))
163           
164            values = grid_mapUnits.values
165           
166            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
167           
168            boundsShape = grid_mapUnits.boundsX.shape
169           
170            #try to fix the values if they are not the expected shape
171            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
172               
173                #check if the axis were in the opposite order
174                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
175                    values = N.transpose(values)           
176
177            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
178            if self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY):
179               
180                self._drawImshow(axes, 
181                                 grid_mapUnits.boundsX, 
182                                 grid_mapUnits.boundsY, 
183                                 values, 
184                                 kwargs)
185                log.debug("drawn imshow in %s" % (time.time() - st ,))
186               
187            else:
188                self._drawPcolormesh(axes, 
189                                     grid_mapUnits.boundsX, 
190                                     grid_mapUnits.boundsY, 
191                                     values, 
192                                     kwargs)
193                log.debug("drawn mesh in %s" % (time.time() - st ,))
194           
195           
196
197
198    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
199        res = axes.pcolormesh(X, Y, Z, **kwargs)
200        return res
201   
202    def _drawImshow(self, axes, X, Y, Z, kwargs):
203        yStart = Y[0,0]
204        yEnd = Y[-1,-1]
205       
206        log.debug("yStart = %s" % (yStart,))
207        log.debug("yEnd = %s" % (yEnd,))
208       
209        if yStart < yEnd:
210            extent=(X.min(), X.max(),Y.max(), Y.min())               
211        else:
212            extent=(X.min(), X.max(),Y.min(), Y.max())
213       
214        kwargs.pop('edgecolors')
215        kwargs.pop('linestyle')
216        kwargs.pop('linewidth')
217        kwargs.pop('antialiased')
218       
219        im = axes.imshow(Z, extent=extent,
220             interpolation='nearest',
221             **kwargs
222        )
223                   
224        axes.set_aspect('auto')
225       
226    def _areBoundsParralel(self, X, Y ):
227        """
228        checks if all the bound are parallel, this is a test to see if using
229        imshow is possible.
230       
231        Checks the equality to 5 significant figures.
232        """
233        n = 5
234        frmt = "%." + str(n) + "e"
235        st = time.time()
236       
237        equal = True
238        x_first = [frmt % x for x in X[0]]
239       
240        for i in range(1, X.shape[0]):
241            x_row = [frmt % x for x in X[i]]
242            equal = x_first == x_row
243            if not equal:
244                break
245
246        xEq = equal
247       
248        if equal:
249            y_first = [frmt % y for y in Y[:,0]]
250           
251            for i in range(1, Y.shape[1]):
252                y_col = [frmt % y for y in Y[:,i]]
253                equal = y_first == y_col
254                if not equal:
255                    break
256
257        yEq = equal
258        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
259       
260        return equal
261           
262
263    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
264        """
265        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
266        made up of quadrilaterals.
267        """
268        Ny, Nx = X.shape
269   
270        points=[]
271        # create a list of points to build up the line collection
272        # assuming the grid is filled with quadrelaterals this should
273        # draw it
274        for y in range(Ny):
275            for x in range(Nx):
276                x1 = X[y,x]
277                y1 = Y[y,x]
278   
279                # if this isn't the last intersection in the x direction, draw a line
280                # between this intersection and the next one
281                if x+1 < Nx:
282                    x2 = X[y,x+1]
283                    y2 = Y[y,x+1]
284                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
285
286                # if this isn't the last intersection in the y direction, draw a line
287                # between this intersection and the next one   
288                if y+1 < Ny:
289                    x3 = X[y+1,x]
290                    y3 = Y[y+1,x]
291                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
292   
293
294        line_segments = matplotlib.collections.LineCollection(points,
295                                       linewidths    = (linewidth,),
296                                       colors    = (color),
297                                       linestyle = 'solid')
298       
299        axes.add_collection(line_segments)
300
301
302    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
303        """
304        Draw the values of each grid box onto the axes.
305        """
306       
307        # Parse drawValuesFile if specified:
308        if self.drawValuesFile != False:
309            drawValuesList = self._parseDrawValuesFile()
310        else:
311            drawValuesList = False
312
313        ni, nj = N.shape(grid_mapUnits.midpointsX)
314       
315        if xLimits != None:
316            xRangeMaximum = max(xLimits)
317            xRangeMinumum = min(xLimits)
318           
319        if yLimits != None:
320            yRangeMaximum = max(yLimits)
321            yRangeMinumum = min(yLimits)
322
323        for j in range(nj):
324            for i in range(ni):
325
326                # Check if we need to draw this value
327                if drawValuesList != False:
328                    if [i, j] not in drawValuesList: continue
329
330                x = grid_mapUnits.midpointsX[i][j]
331                y = grid_mapUnits.midpointsY[i][j]
332
333                v = grid_mapUnits.values[i,j]
334
335                #skip masked array values:
336                if v.__class__ == MA.MaskedArray:
337                    continue
338
339                # Avoid missing values with float conversion test
340                try:
341                    v = float(v)
342                except:
343                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
344                    continue
345
346                #get the bounds of the box (in map units), around the current position
347                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
348
349                xTop = max(xBounds)
350                yTop = max(yBounds)
351                xBottom = min(xBounds)
352                yBottom = min(yBounds)
353               
354                if xLimits != None:
355                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
356                        continue
357
358                if yLimits != None:
359                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
360                        continue
361
362                axes.text(float(x), float(y), self.valueFormat % v,
363                          ha='center',
364                          va='center',
365                          family='sans-serif',
366                          size=8)
367
368    def _getBounds(self, i, j, xMesh, yMesh):
369        """
370        Gets the points of each of the four corners for a given position in the mesh
371
372        @param i: the x index of the vlaue in the mesh
373        @type i: int
374        @param j: the x index of the vlaue in the mesh
375        @type j: int
376        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
377        @type xMesh: an instance of the numpy.ndarray class
378        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
379        @type yMesh: an instance of the numpy.ndarray class
380        """
381
382        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
383
384
385        f = lambda array: lambda c: array[c]
386        xBounds = map(f(xMesh), coords)
387        yBounds = map(f(yMesh), coords)
388
389        return (xBounds, yBounds)
390
391    def _parseDrawValuesFile(self):
392        """
393        Attempts to parse the CSV file containing index values of y,x per line.
394        Validates file or raises error.
395        On successful validation returns a list of (y_index, x_index) pairs.
396
397        """
398
399        # Safely read file
400        try:
401            valuesFile = open(self.drawValuesFile)
402            lines = valuesFile.readlines()
403            valuesFile.close()
404        except:
405            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
406
407        # Safely parse file contents
408        linePattern = re.compile("^(\d+),(\d+)$")
409        ln = 0
410        cellList = []
411
412        for line in lines:
413            ln += 1
414            line = line.strip()
415            match = linePattern.match(line)
416            if not match:
417                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
418            indices = [int(i) for i in match.groups()]
419            cellList.append(indices)
420
421        return cellList
422 
423if __name__ == '__main__':
424    import geoplot.log_util as log_util
425    log_util.setupGeoplotConsoleHandler(log)
426
427    import pkg_resources, os
428    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
429
430    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
431
432    from numpy import meshgrid
433    from cdms_utils.cdms_compat import N, MA
434
435    x = [-6.,-4.,-2., 0., 2.]
436    xMid = [-5, -3, -1, 1]
437    xlim = len(x) -1
438    y = [54.,55.,56.,58.]
439    yMid = [54.5, 55.5, 57]
440    ylim = len(y) - 1
441    X,Y = meshgrid(x,y)
442   
443    X[0,0] = -7
444    Y[0,0] = 53
445    X[ylim, xlim] = 3
446    Y[ylim, xlim] = 59
447       
448    XMid, YMid = meshgrid(xMid, yMid)
449
450    Z=[[0]*xlim for i in range(ylim)]
451    for i in range(0 , xlim*ylim):
452        vX= i%xlim
453        vY= (i - i%xlim) / xlim
454        Z[vY][vX] = i
455    #Z = N.array(Z)
456    Z = MA.masked_values(Z, 2)
457    Z = MA.masked_values(Z, 6)
458       
459    grid = Grid(X,Y,XMid, YMid, Z)
460
461    from geoplot.mpl_imports import basemap
462    from matplotlib.backends.backend_agg import FigureCanvasAgg
463    from matplotlib.figure import Figure
464    from matplotlib import cm
465
466    figsize=(600 / 80, 800 / 80)
467    fig = Figure(figsize=figsize, dpi=80)
468    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
469
470
471    bm = basemap.Basemap(llcrnrlon=xLimits[0],
472                       llcrnrlat=yLimits[0],
473                       urcrnrlon=xLimits[1],
474                       urcrnrlat=yLimits[1],
475                       resolution='h',
476                       suppress_ticks=False)
477    cmap = cm.winter
478    cmap.set_bad("w")   
479    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
480                        outline=False)
481       
482    sm = drawer.draw(axes, grid, basemap=bm)
483   
484    bm.drawcoastlines(ax = axes)
485    bm.drawrivers(ax = axes, color='b')
486    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
487    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
488   
489   
490    cb = fig.colorbar(sm, orientation='vertical')
491    print "tick position", cb.ax.yaxis.get_ticks_position()
492    cb.ax.yaxis.set_ticks_position('default')
493    print "tick position", cb.ax.yaxis.get_ticks_position()   
494
495
496    canvas = FigureCanvasAgg(fig)
497    filename = "grid_drawer_test.png"
498   
499    fullPath = os.path.join(outputsDir, filename)
500    canvas.print_figure(fullPath)
501    log.debug("Wrote " + fullPath)
502
Note: See TracBrowser for help on using the repository browser.