source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 6089

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@6089
Revision 6089, 18.1 KB checked in by pnorton, 11 years ago (diff)

Imroved the colour bar code so that a legend colour bar can be used without specifying any intervals.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib.colors
19import matplotlib.cm
20import matplotlib.collections
21import numpy as N
22import numpy.ma as MA
23import geoplot.mpl_imports
24from matplotlib.colors import LightSource
25
26#internal modules
27from geoplot.array_util import *
28from geoplot.grid import Grid
29import geoplot.utils as geoplot_utils
30
31#setup the logging
32log = logging.getLogger(__name__)
33
34class GridDrawer(object):
35    """
36    Responsible for knwoing how to draw a Grid object onto a given matplotlib
37    axes
38    """
39
40    def __init__(self, drawValues=False, drawValuesFile=False, 
41                 valueFormat="%.2f", showGridLines=False,
42                 outline=False):
43        """
44        Constructs a GridDrawer object             
45        """
46        self.drawValues = drawValues
47        self.drawValuesFile = drawValuesFile
48        self.valueFormat = valueFormat
49        self.showGridLines = showGridLines
50        self.outline = outline
51
52    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
53             cmap=None, fontSize='medium', assumeBoundsParallel=None):
54        """
55        Draws the grid given on the axis.
56       
57        @param axes: the axes the grid will be drawn on
58        @type axes: matplotlib.axes
59        @param grid: the grid to be drawn
60        @type grid: geoplot.grid
61        @keyword limits: the limits of the drawing (in lat lon)
62        @type limits: tuple
63        @keyword basemap: the basemap instance to scale the grid values to be
64            drawn on the axis.
65        """
66
67
68        if cmap == None:
69            cmap = matplotlib.cm.get_cmap()
70            cmap.set_bad("w")
71               
72        log.debug("cmap._rgba_under  = %s" % (cmap._rgba_under ,))
73       
74        if norm == None:
75            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
76
77        xLimits, yLimits = self._transformLimits(limits, basemap)
78       
79        # check that the grid being drawn has data inside the limits, if there
80        # is no data in the limits there is no point drawing anything.
81        # can only check the grid is in the limits if they are provided
82        if limits != (None, None):
83            if not self._isGridDataWithinLimits(grid, xLimits, yLimits):
84                return
85               
86        grid_mapUnits = grid.transform(basemap)
87
88        self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm,
89                            assumeBoundsParallel)
90       
91#        log.debug("self.drawValues = %s" % (self.drawValues,))
92#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
93
94        if self.drawValues == True:
95            log.debug('drawing grid values')
96           
97            #draw the vlaues onto the grid
98            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
99       
100
101    def _transformLimits(self, limits, basemap):
102        """
103        Use the basemap object specified to transform the limits into map units.
104        """
105       
106        if limits == None:
107            return (None, None)
108        elif basemap == None:
109            return (limits[0], limits[1])
110        else:
111            return basemap(limits[0], limits[1])
112
113    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm,
114                       assumeBoundsParallel):
115        """
116        Draws the grid boxes onto the axes.
117        """
118       
119       
120        gridColour = 'k'
121        gridLineWidth = 0.25
122
123        kwargs = {'cmap':cmap, 'norm':norm}
124       
125        if self.outline == True :
126            # this will draw an outline on all grid cells that don't contain a
127            # missing value
128            kwargs['edgecolors'] = gridColour
129        else:
130            # needs to be a string rather than None or the rc default will be used
131            kwargs['edgecolors'] = 'None'
132
133
134        if geoplot.mpl_imports.oldMatplotlib == False:
135            # in the newer version of matplotlib you can set the linewidth +
136            # linestyle on the pcolor() function
137            kwargs['linewidth'] = gridLineWidth
138            kwargs['linestyle'] = 'solid'
139            kwargs['antialiased'] = True
140
141        #check the values aren't all masked
142       
143       
144        # Its possible that this test becomes time consuming for large arrays,
145        # it is pssible that using sum might be faster, but this will only work
146        # if True + True = 2.
147        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
148        # valuesFound = not grid.values.mask.sum() == reduce(lambda a,b: a*b, grid.values.shape)
149       
150        # for now this is so quick it doesn't matter.
151        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
152
153
154        #if all the values are masked make sure the outline is drawn
155        if self.showGridLines == True:
156            st = time.time()
157            # this will draw all the grid cells even if their contents is missing
158            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
159                           color=gridColour, linewidth=gridLineWidth)
160            log.debug("drawn grid lines in = %s" % (time.time() - st,))
161       
162        if valuesFound == True:
163           
164            log.debug("grid_mapUnits: boundsX.shape = %s, boundsY.shape = %s, values.shape = %s" 
165                      % (grid_mapUnits.boundsX.shape, grid_mapUnits.boundsY.shape, grid_mapUnits.values.shape))
166           
167            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
168           
169            values = grid_mapUnits.values
170            boundsShape = grid_mapUnits.boundsX.shape
171           
172            #try to fix the values if they are not the expected shape
173            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
174               
175                message = "Value array shape doesn't match the bounds shape." 
176               
177                #check if the axis were in the opposite order
178                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
179                    message += " It looks like the data may have been transposed."
180                    #log.warning("Transposing data to try and fit bounds to values")
181                    #values = N.transpose(values)
182               
183                raise Exception(message)
184                               
185            if assumeBoundsParallel == True:
186                useImshow = True
187            elif assumeBoundsParallel == False:
188                useImshow = False
189            else:
190                useImshow = self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY)
191           
192            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
193            if useImshow:
194               
195                self._drawImshow(axes, 
196                                 grid_mapUnits.boundsX, 
197                                 grid_mapUnits.boundsY, 
198                                 values, 
199                                 kwargs)
200               
201            else:
202                self._drawPcolormesh(axes, 
203                                     grid_mapUnits.boundsX, 
204                                     grid_mapUnits.boundsY, 
205                                     values, 
206                                     kwargs)
207               
208           
209           
210
211
212    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
213        """
214        Draw the vlaues Z at using the bounds X and Y and the pcolormesh method on the axes object.
215        The kwargs dictionary is passed to the pcolormesh method.
216        """
217        res = axes.pcolormesh(X, Y, Z, **kwargs)
218        return res
219   
220    def _drawImshow(self, axes, X, Y, Z, kwargs):
221        """
222        Draw the values Z using the bounds X and Y and the imshow method on the axis.
223        The kwargs dictionary is passed to the imshow method.
224        """
225        st1 = time.time()
226       
227        yStart = Y[0,0]
228        yEnd = Y[-1,-1]
229       
230        #log.debug("yStart = %s, yEnd = %s" % (yStart, yEnd))
231       
232        if yStart < yEnd:
233            extent=(X.min(), X.max(),Y.max(), Y.min())               
234        else:
235            extent=(X.min(), X.max(),Y.min(), Y.max())
236       
237        kwargs.pop('edgecolors')
238        kwargs.pop('linestyle')
239        kwargs.pop('linewidth')
240        kwargs.pop('antialiased')
241        #kwargs['origin'] = 'lower'
242       
243        log.debug("got extent = %s yStart = %s yEnd = %s" % (extent, yStart, yEnd,))
244       
245#        ls = LightSource(azdeg=0,altdeg=65)
246#        rgb = ls.shade(Z,kwargs['cmap'])
247
248       
249       
250        im = axes.imshow(Z, extent=extent,
251             interpolation='nearest',
252             **kwargs
253        )
254                   
255        axes.set_aspect('auto')
256       
257    def _areBoundsParralel(self, X, Y ):
258        """
259        checks if all the bound are parallel, this is a test to see if using
260        imshow is possible.
261       
262        Checks the equality to 5 significant figures.
263        """
264       
265        st = time.time()
266        equal = True
267       
268        for i in range(1, X.shape[0]):
269           
270            equal = N.allclose(X[0], X[i])
271           
272            if not equal:
273                log.debug("X[0,0:5] = %s" % (X[0,0:5],))
274                log.debug("X[i,0:5] = %s" % (X[i,0:5],))
275                break
276
277        xEq = equal
278       
279        if equal:
280           
281            for i in range(1, Y.shape[1]):
282               
283                equal = N.allclose( Y[:,0], Y[:,i])
284               
285                if not equal:
286                    break
287
288        yEq = equal
289       
290        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
291       
292        return equal
293           
294
295    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
296        """
297        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
298        made up of quadrilaterals.
299        """
300        Ny, Nx = X.shape
301   
302        points=[]
303        # create a list of points to build up the line collection
304        # assuming the grid is filled with quadrelaterals this should
305        # draw it
306        for y in range(Ny):
307            for x in range(Nx):
308                x1 = X[y,x]
309                y1 = Y[y,x]
310   
311                # if this isn't the last intersection in the x direction, draw a line
312                # between this intersection and the next one
313                if x+1 < Nx:
314                    x2 = X[y,x+1]
315                    y2 = Y[y,x+1]
316                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
317
318                # if this isn't the last intersection in the y direction, draw a line
319                # between this intersection and the next one   
320                if y+1 < Ny:
321                    x3 = X[y+1,x]
322                    y3 = Y[y+1,x]
323                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
324   
325
326        line_segments = matplotlib.collections.LineCollection(points,
327                                       linewidths    = (linewidth,),
328                                       colors    = (color),
329                                       linestyle = 'solid')
330       
331        axes.add_collection(line_segments)
332
333
334    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
335        """
336        Draw the values of each grid box onto the axes.
337        """
338       
339        # Parse drawValuesFile if specified:
340        if self.drawValuesFile != False:
341            drawValuesList = self._parseDrawValuesFile()
342        else:
343            drawValuesList = False
344
345        ni, nj = N.shape(grid_mapUnits.midpointsX)
346       
347        if xLimits != None:
348            xRangeMaximum = max(xLimits)
349            xRangeMinumum = min(xLimits)
350           
351        if yLimits != None:
352            yRangeMaximum = max(yLimits)
353            yRangeMinumum = min(yLimits)
354
355        for j in range(nj):
356            for i in range(ni):
357
358                # Check if we need to draw this value
359                if drawValuesList != False:
360                    if [i, j] not in drawValuesList: continue
361
362                x = grid_mapUnits.midpointsX[i][j]
363                y = grid_mapUnits.midpointsY[i][j]
364
365                v = grid_mapUnits.values[i,j]
366
367                #skip masked array values:
368                if v.__class__ == MA.MaskedArray:
369                    continue
370
371                # Avoid missing values with float conversion test
372                try:
373                    v = float(v)
374                except:
375                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
376                    continue
377
378                #get the bounds of the box (in map units), around the current position
379                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
380
381                xTop = max(xBounds)
382                yTop = max(yBounds)
383                xBottom = min(xBounds)
384                yBottom = min(yBounds)
385               
386                if xLimits != None:
387                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
388                        continue
389
390                if yLimits != None:
391                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
392                        continue
393
394                axes.text(float(x), float(y), self.valueFormat % v,
395                          ha='center',
396                          va='center',
397                          family='sans-serif',
398                          size=8)
399
400    def _getBounds(self, i, j, xMesh, yMesh):
401        """
402        Gets the points of each of the four corners for a given position in the mesh
403
404        @param i: the x index of the vlaue in the mesh
405        @type i: int
406        @param j: the x index of the vlaue in the mesh
407        @type j: int
408        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
409        @type xMesh: an instance of the numpy.ndarray class
410        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
411        @type yMesh: an instance of the numpy.ndarray class
412        """
413
414        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
415
416
417        f = lambda array: lambda c: array[c]
418        xBounds = map(f(xMesh), coords)
419        yBounds = map(f(yMesh), coords)
420
421        return (xBounds, yBounds)
422
423    def _parseDrawValuesFile(self):
424        """
425        Attempts to parse the CSV file containing index values of y,x per line.
426        Validates file or raises error.
427        On successful validation returns a list of (y_index, x_index) pairs.
428
429        """
430
431        # Safely read file
432        try:
433            valuesFile = open(self.drawValuesFile)
434            lines = valuesFile.readlines()
435            valuesFile.close()
436        except:
437            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
438
439        # Safely parse file contents
440        linePattern = re.compile("^(\d+),(\d+)$")
441        ln = 0
442        cellList = []
443
444        for line in lines:
445            ln += 1
446            line = line.strip()
447            match = linePattern.match(line)
448            if not match:
449                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
450            indices = [int(i) for i in match.groups()]
451            cellList.append(indices)
452
453        return cellList
454   
455    def _isGridDataWithinLimits(self, grid, xLimits, yLimits):
456       
457        xRange = (grid.boundsX.min(), grid.boundsX.max())
458        yRange = (grid.boundsY.min(), grid.boundsY.max())
459       
460        #log.debug("xRange = %s" % (xRange,))
461        #log.debug("yRange = %s" % (yRange,))
462       
463        isInLimits = geoplot_utils.isRangeInLimits(xRange, xLimits) and \
464                     geoplot_utils.isRangeInLimits(yRange, yLimits)
465       
466        return isInLimits
467   
468if __name__ == '__main__':
469    import geoplot.log_util as log_util
470    log_util.setupGeoplotConsoleHandler(log)
471
472    import pkg_resources, os
473    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
474
475    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
476
477    from numpy import meshgrid
478
479    x = [-6.,-4.,-2., 0., 2.]
480    xMid = [-5, -3, -1, 1]
481    xlim = len(x) -1
482    y = [54.,55.,56.,58.]
483    yMid = [54.5, 55.5, 57]
484    ylim = len(y) - 1
485    X,Y = meshgrid(x,y)
486   
487    X[0,0] = -7
488    Y[0,0] = 53
489    X[ylim, xlim] = 3
490    Y[ylim, xlim] = 59
491       
492    XMid, YMid = meshgrid(xMid, yMid)
493
494    Z=[[0]*xlim for i in range(ylim)]
495    for i in range(0 , xlim*ylim):
496        vX= i%xlim
497        vY= (i - i%xlim) / xlim
498        Z[vY][vX] = i
499    #Z = N.array(Z)
500    Z = MA.masked_values(Z, 2)
501    Z = MA.masked_values(Z, 6)
502       
503    grid = Grid(X,Y,XMid, YMid, Z)
504
505    from geoplot.mpl_imports import basemap
506    from matplotlib.backends.backend_agg import FigureCanvasAgg
507    from matplotlib.figure import Figure
508    from matplotlib import cm
509
510    figsize=(600 / 80, 800 / 80)
511    fig = Figure(figsize=figsize, dpi=80)
512    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
513
514
515    bm = basemap.Basemap(llcrnrlon=xLimits[0],
516                       llcrnrlat=yLimits[0],
517                       urcrnrlon=xLimits[1],
518                       urcrnrlat=yLimits[1],
519                       resolution='h',
520                       suppress_ticks=False)
521    cmap = cm.winter
522    cmap.set_bad("w")   
523    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
524                        outline=False)
525       
526    sm = drawer.draw(axes, grid, basemap=bm)
527   
528    bm.drawcoastlines(ax = axes)
529    bm.drawrivers(ax = axes, color='b')
530    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
531    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
532   
533   
534    cb = fig.colorbar(sm, orientation='vertical')
535    print "tick position", cb.ax.yaxis.get_ticks_position()
536    cb.ax.yaxis.set_ticks_position('default')
537    print "tick position", cb.ax.yaxis.get_ticks_position()   
538
539
540    canvas = FigureCanvasAgg(fig)
541    filename = "grid_drawer_test.png"
542   
543    fullPath = os.path.join(outputsDir, filename)
544    canvas.print_figure(fullPath)
545    log.debug("Wrote " + fullPath)
546
Note: See TracBrowser for help on using the repository browser.