source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 6024

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@6024
Revision 6024, 18.0 KB checked in by pnorton, 11 years ago (diff)

Modified the cairo renderer to work with the latest matplotlib version.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib.colors
19import matplotlib.cm
20import matplotlib.collections
21import numpy as N
22import numpy.ma as MA
23import geoplot.mpl_imports
24from matplotlib.colors import LightSource
25
26#internal modules
27from geoplot.array_util import *
28from geoplot.grid import Grid
29import geoplot.utils as geoplot_utils
30
31#setup the logging
32log = logging.getLogger(__name__)
33
34class GridDrawer(object):
35    """
36    Responsible for knwoing how to draw a Grid object onto a given matplotlib
37    axes
38    """
39
40    def __init__(self, drawValues=False, drawValuesFile=False, 
41                 valueFormat="%.2f", showGridLines=False,
42                 outline=False):
43        """
44        Constructs a GridDrawer object             
45        """
46        self.drawValues = drawValues
47        self.drawValuesFile = drawValuesFile
48        self.valueFormat = valueFormat
49        self.showGridLines = showGridLines
50        self.outline = outline
51
52    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
53             cmap=None, fontSize='medium', assumeBoundsParallel=None):
54        """
55        Draws the grid given on the axis.
56       
57        @param axes: the axes the grid will be drawn on
58        @type axes: matplotlib.axes
59        @param grid: the grid to be drawn
60        @type grid: geoplot.grid
61        @keyword limits: the limits of the drawing (in lat lon)
62        @type limits: tuple
63        @keyword basemap: the basemap instance to scale the grid values to be
64            drawn on the axis.
65        """
66
67        if cmap == None:
68            cmap = matplotlib.cm.get_cmap()
69            cmap.set_bad("w")   
70       
71        if norm == None:
72            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
73
74        xLimits, yLimits = self._transformLimits(limits, basemap)
75       
76        # check that the grid being drawn has data inside the limits, if there
77        # is no data in the limits there is no point drawing anything.
78        # can only check the grid is in the limits if they are provided
79        if limits != (None, None):
80            if not self._isGridDataWithinLimits(grid, xLimits, yLimits):
81                return
82               
83        grid_mapUnits = grid.transform(basemap)
84
85        self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm,
86                            assumeBoundsParallel)
87       
88#        log.debug("self.drawValues = %s" % (self.drawValues,))
89#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
90
91        if self.drawValues == True:
92            log.debug('drawing grid values')
93           
94            #draw the vlaues onto the grid
95            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
96       
97
98    def _transformLimits(self, limits, basemap):
99        """
100        Use the basemap object specified to transform the limits into map units.
101        """
102       
103        if limits == None:
104            return (None, None)
105        elif basemap == None:
106            return (limits[0], limits[1])
107        else:
108            return basemap(limits[0], limits[1])
109
110    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm,
111                       assumeBoundsParallel):
112        """
113        Draws the grid boxes onto the axes.
114        """
115       
116       
117        gridColour = 'k'
118        gridLineWidth = 0.25
119
120        kwargs = {'cmap':cmap, 'norm':norm}
121       
122        if self.outline == True :
123            # this will draw an outline on all grid cells that don't contain a
124            # missing value
125            kwargs['edgecolors'] = gridColour
126        else:
127            # needs to be a string rather than None or the rc default will be used
128            kwargs['edgecolors'] = 'None'
129
130
131        if geoplot.mpl_imports.oldMatplotlib == False:
132            # in the newer version of matplotlib you can set the linewidth +
133            # linestyle on the pcolor() function
134            kwargs['linewidth'] = gridLineWidth
135            kwargs['linestyle'] = 'solid'
136            kwargs['antialiased'] = True
137
138        #check the values aren't all masked
139       
140       
141        # Its possible that this test becomes time consuming for large arrays,
142        # it is pssible that using sum might be faster, but this will only work
143        # if True + True = 2.
144        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
145        # valuesFound = not grid.values.mask.sum() == reduce(lambda a,b: a*b, grid.values.shape)
146       
147        # for now this is so quick it doesn't matter.
148        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
149
150
151        #if all the values are masked make sure the outline is drawn
152        if self.showGridLines == True:
153            st = time.time()
154            # this will draw all the grid cells even if their contents is missing
155            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
156                           color=gridColour, linewidth=gridLineWidth)
157            log.debug("drawn grid lines in = %s" % (time.time() - st,))
158       
159        if valuesFound == True:
160           
161            log.debug("grid_mapUnits: boundsX.shape = %s, boundsY.shape = %s, values.shape = %s" 
162                      % (grid_mapUnits.boundsX.shape, grid_mapUnits.boundsY.shape, grid_mapUnits.values.shape))
163           
164            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
165           
166            values = grid_mapUnits.values
167            boundsShape = grid_mapUnits.boundsX.shape
168           
169            #try to fix the values if they are not the expected shape
170            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
171               
172                message = "Value array shape doesn't match the bounds shape." 
173               
174                #check if the axis were in the opposite order
175                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
176                    message += " It looks like the data may have been transposed."
177                    #log.warning("Transposing data to try and fit bounds to values")
178                    #values = N.transpose(values)
179               
180                raise Exception(message)
181                               
182            if assumeBoundsParallel == True:
183                useImshow = True
184            elif assumeBoundsParallel == False:
185                useImshow = False
186            else:
187                useImshow = self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY)
188           
189            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
190            if useImshow:
191               
192                self._drawImshow(axes, 
193                                 grid_mapUnits.boundsX, 
194                                 grid_mapUnits.boundsY, 
195                                 values, 
196                                 kwargs)
197               
198            else:
199                self._drawPcolormesh(axes, 
200                                     grid_mapUnits.boundsX, 
201                                     grid_mapUnits.boundsY, 
202                                     values, 
203                                     kwargs)
204               
205           
206           
207
208
209    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
210        """
211        Draw the vlaues Z at using the bounds X and Y and the pcolormesh method on the axes object.
212        The kwargs dictionary is passed to the pcolormesh method.
213        """
214        res = axes.pcolormesh(X, Y, Z, **kwargs)
215        return res
216   
217    def _drawImshow(self, axes, X, Y, Z, kwargs):
218        """
219        Draw the values Z using the bounds X and Y and the imshow method on the axis.
220        The kwargs dictionary is passed to the imshow method.
221        """
222        st1 = time.time()
223       
224        yStart = Y[0,0]
225        yEnd = Y[-1,-1]
226       
227        #log.debug("yStart = %s, yEnd = %s" % (yStart, yEnd))
228       
229        if yStart < yEnd:
230            extent=(X.min(), X.max(),Y.max(), Y.min())               
231        else:
232            extent=(X.min(), X.max(),Y.min(), Y.max())
233       
234        kwargs.pop('edgecolors')
235        kwargs.pop('linestyle')
236        kwargs.pop('linewidth')
237        kwargs.pop('antialiased')
238        #kwargs['origin'] = 'lower'
239       
240        log.debug("got extent = %s yStart = %s yEnd = %s" % (extent, yStart, yEnd,))
241       
242#        ls = LightSource(azdeg=0,altdeg=65)
243#        rgb = ls.shade(Z,kwargs['cmap'])
244
245       
246       
247        im = axes.imshow(Z, extent=extent,
248             interpolation='nearest',
249             **kwargs
250        )
251                   
252        axes.set_aspect('auto')
253       
254    def _areBoundsParralel(self, X, Y ):
255        """
256        checks if all the bound are parallel, this is a test to see if using
257        imshow is possible.
258       
259        Checks the equality to 5 significant figures.
260        """
261       
262        st = time.time()
263        equal = True
264       
265        for i in range(1, X.shape[0]):
266           
267            equal = N.allclose(X[0], X[i])
268           
269            if not equal:
270                log.debug("X[0,0:5] = %s" % (X[0,0:5],))
271                log.debug("X[i,0:5] = %s" % (X[i,0:5],))
272                break
273
274        xEq = equal
275       
276        if equal:
277           
278            for i in range(1, Y.shape[1]):
279               
280                equal = N.allclose( Y[:,0], Y[:,i])
281               
282                if not equal:
283                    break
284
285        yEq = equal
286       
287        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
288       
289        return equal
290           
291
292    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
293        """
294        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
295        made up of quadrilaterals.
296        """
297        Ny, Nx = X.shape
298   
299        points=[]
300        # create a list of points to build up the line collection
301        # assuming the grid is filled with quadrelaterals this should
302        # draw it
303        for y in range(Ny):
304            for x in range(Nx):
305                x1 = X[y,x]
306                y1 = Y[y,x]
307   
308                # if this isn't the last intersection in the x direction, draw a line
309                # between this intersection and the next one
310                if x+1 < Nx:
311                    x2 = X[y,x+1]
312                    y2 = Y[y,x+1]
313                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
314
315                # if this isn't the last intersection in the y direction, draw a line
316                # between this intersection and the next one   
317                if y+1 < Ny:
318                    x3 = X[y+1,x]
319                    y3 = Y[y+1,x]
320                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
321   
322
323        line_segments = matplotlib.collections.LineCollection(points,
324                                       linewidths    = (linewidth,),
325                                       colors    = (color),
326                                       linestyle = 'solid')
327       
328        axes.add_collection(line_segments)
329
330
331    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
332        """
333        Draw the values of each grid box onto the axes.
334        """
335       
336        # Parse drawValuesFile if specified:
337        if self.drawValuesFile != False:
338            drawValuesList = self._parseDrawValuesFile()
339        else:
340            drawValuesList = False
341
342        ni, nj = N.shape(grid_mapUnits.midpointsX)
343       
344        if xLimits != None:
345            xRangeMaximum = max(xLimits)
346            xRangeMinumum = min(xLimits)
347           
348        if yLimits != None:
349            yRangeMaximum = max(yLimits)
350            yRangeMinumum = min(yLimits)
351
352        for j in range(nj):
353            for i in range(ni):
354
355                # Check if we need to draw this value
356                if drawValuesList != False:
357                    if [i, j] not in drawValuesList: continue
358
359                x = grid_mapUnits.midpointsX[i][j]
360                y = grid_mapUnits.midpointsY[i][j]
361
362                v = grid_mapUnits.values[i,j]
363
364                #skip masked array values:
365                if v.__class__ == MA.MaskedArray:
366                    continue
367
368                # Avoid missing values with float conversion test
369                try:
370                    v = float(v)
371                except:
372                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
373                    continue
374
375                #get the bounds of the box (in map units), around the current position
376                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
377
378                xTop = max(xBounds)
379                yTop = max(yBounds)
380                xBottom = min(xBounds)
381                yBottom = min(yBounds)
382               
383                if xLimits != None:
384                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
385                        continue
386
387                if yLimits != None:
388                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
389                        continue
390
391                axes.text(float(x), float(y), self.valueFormat % v,
392                          ha='center',
393                          va='center',
394                          family='sans-serif',
395                          size=8)
396
397    def _getBounds(self, i, j, xMesh, yMesh):
398        """
399        Gets the points of each of the four corners for a given position in the mesh
400
401        @param i: the x index of the vlaue in the mesh
402        @type i: int
403        @param j: the x index of the vlaue in the mesh
404        @type j: int
405        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
406        @type xMesh: an instance of the numpy.ndarray class
407        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
408        @type yMesh: an instance of the numpy.ndarray class
409        """
410
411        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
412
413
414        f = lambda array: lambda c: array[c]
415        xBounds = map(f(xMesh), coords)
416        yBounds = map(f(yMesh), coords)
417
418        return (xBounds, yBounds)
419
420    def _parseDrawValuesFile(self):
421        """
422        Attempts to parse the CSV file containing index values of y,x per line.
423        Validates file or raises error.
424        On successful validation returns a list of (y_index, x_index) pairs.
425
426        """
427
428        # Safely read file
429        try:
430            valuesFile = open(self.drawValuesFile)
431            lines = valuesFile.readlines()
432            valuesFile.close()
433        except:
434            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
435
436        # Safely parse file contents
437        linePattern = re.compile("^(\d+),(\d+)$")
438        ln = 0
439        cellList = []
440
441        for line in lines:
442            ln += 1
443            line = line.strip()
444            match = linePattern.match(line)
445            if not match:
446                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
447            indices = [int(i) for i in match.groups()]
448            cellList.append(indices)
449
450        return cellList
451   
452    def _isGridDataWithinLimits(self, grid, xLimits, yLimits):
453       
454        xRange = (grid.boundsX.min(), grid.boundsX.max())
455        yRange = (grid.boundsY.min(), grid.boundsY.max())
456       
457        #log.debug("xRange = %s" % (xRange,))
458        #log.debug("yRange = %s" % (yRange,))
459       
460        isInLimits = geoplot_utils.isRangeInLimits(xRange, xLimits) and \
461                     geoplot_utils.isRangeInLimits(yRange, yLimits)
462       
463        return isInLimits
464   
465if __name__ == '__main__':
466    import geoplot.log_util as log_util
467    log_util.setupGeoplotConsoleHandler(log)
468
469    import pkg_resources, os
470    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
471
472    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
473
474    from numpy import meshgrid
475
476    x = [-6.,-4.,-2., 0., 2.]
477    xMid = [-5, -3, -1, 1]
478    xlim = len(x) -1
479    y = [54.,55.,56.,58.]
480    yMid = [54.5, 55.5, 57]
481    ylim = len(y) - 1
482    X,Y = meshgrid(x,y)
483   
484    X[0,0] = -7
485    Y[0,0] = 53
486    X[ylim, xlim] = 3
487    Y[ylim, xlim] = 59
488       
489    XMid, YMid = meshgrid(xMid, yMid)
490
491    Z=[[0]*xlim for i in range(ylim)]
492    for i in range(0 , xlim*ylim):
493        vX= i%xlim
494        vY= (i - i%xlim) / xlim
495        Z[vY][vX] = i
496    #Z = N.array(Z)
497    Z = MA.masked_values(Z, 2)
498    Z = MA.masked_values(Z, 6)
499       
500    grid = Grid(X,Y,XMid, YMid, Z)
501
502    from geoplot.mpl_imports import basemap
503    from matplotlib.backends.backend_agg import FigureCanvasAgg
504    from matplotlib.figure import Figure
505    from matplotlib import cm
506
507    figsize=(600 / 80, 800 / 80)
508    fig = Figure(figsize=figsize, dpi=80)
509    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
510
511
512    bm = basemap.Basemap(llcrnrlon=xLimits[0],
513                       llcrnrlat=yLimits[0],
514                       urcrnrlon=xLimits[1],
515                       urcrnrlat=yLimits[1],
516                       resolution='h',
517                       suppress_ticks=False)
518    cmap = cm.winter
519    cmap.set_bad("w")   
520    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
521                        outline=False)
522       
523    sm = drawer.draw(axes, grid, basemap=bm)
524   
525    bm.drawcoastlines(ax = axes)
526    bm.drawrivers(ax = axes, color='b')
527    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
528    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
529   
530   
531    cb = fig.colorbar(sm, orientation='vertical')
532    print "tick position", cb.ax.yaxis.get_ticks_position()
533    cb.ax.yaxis.set_ticks_position('default')
534    print "tick position", cb.ax.yaxis.get_ticks_position()   
535
536
537    canvas = FigureCanvasAgg(fig)
538    filename = "grid_drawer_test.png"
539   
540    fullPath = os.path.join(outputsDir, filename)
541    canvas.print_figure(fullPath)
542    log.debug("Wrote " + fullPath)
543
Note: See TracBrowser for help on using the repository browser.