source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 5826

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@5826
Revision 5826, 17.7 KB checked in by pnorton, 10 years ago (diff)

Improved geoplot's behaviour when dealing with variables with axis in the order of lon/lat instead of lat/lon.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib.colors
19import matplotlib.cm
20import matplotlib.collections
21import numpy as N
22import numpy.ma as MA
23import geoplot.mpl_imports
24
25#internal modules
26from geoplot.array_util import *
27from geoplot.grid import Grid
28import geoplot.utils as geoplot_utils
29
30#setup the logging
31log = logging.getLogger(__name__)
32
33class GridDrawer(object):
34    """
35    Responsible for knwoing how to draw a Grid object onto a given matplotlib
36    axes
37    """
38
39    def __init__(self, drawValues=False, drawValuesFile=False, 
40                 valueFormat="%.2f", showGridLines=False,
41                 outline=False, forcePcolor=False):
42        """
43        Constructs a GridDrawer object             
44        """
45        self.drawValues = drawValues
46        self.drawValuesFile = drawValuesFile
47        self.valueFormat = valueFormat
48        self.showGridLines = showGridLines
49        self.outline = outline
50        self.forcePcolor = forcePcolor
51
52    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
53             cmap=None, fontSize='medium' ):
54        """
55        Draws the grid given on the axis.
56       
57        @param axes: the axes the grid will be drawn on
58        @type axes: matplotlib.axes
59        @param grid: the grid to be drawn
60        @type grid: geoplot.grid
61        @keyword limits: the limits of the drawing (in lat lon)
62        @type limits: tuple
63        @keyword basemap: the basemap instance to scale the grid values to be
64            drawn on the axis.
65        """
66
67        if cmap == None:
68            cmap = matplotlib.cm.get_cmap()
69            cmap.set_bad("w")   
70       
71        if norm == None:
72            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
73
74        xLimits, yLimits = self._transformLimits(limits, basemap)
75       
76        # check that the grid being drawn has data inside the limits, if there
77        # is no data in the limits there is no point drawing anything.
78        # can only check the grid is in the limits if they are provided
79        if limits != None:
80            if not self._isGridDataWithinLimits(grid, xLimits, yLimits):
81                return
82               
83        grid_mapUnits = grid.transform(basemap)
84
85        self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm)
86       
87#        log.debug("self.drawValues = %s" % (self.drawValues,))
88#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
89
90        if self.drawValues == True:
91            log.debug('drawing grid values')
92           
93            #draw the vlaues onto the grid
94            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
95       
96
97    def _transformLimits(self, limits, basemap):
98        """
99        Use the basemap object specified to transform the limits into map units.
100        """
101       
102        if limits == None:
103            return (None, None)
104        elif basemap == None:
105            return (limits[0], limits[1])
106        else:
107            return basemap(limits[0], limits[1])
108
109    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm):
110        """
111        Draws the grid boxes onto the axes.
112        """
113       
114        gridColour = 'k'
115        gridLineWidth = 0.25
116
117        kwargs = {'cmap':cmap, 'norm':norm}
118       
119        if self.outline == True :
120            # this will draw an outline on all grid cells that don't contain a
121            # missing value
122            kwargs['edgecolors'] = gridColour
123        else:
124            # needs to be a string rather than None or the rc default will be used
125            kwargs['edgecolors'] = 'None'
126
127
128        if geoplot.mpl_imports.oldMatplotlib == False:
129            # in the newer version of matplotlib you can set the linewidth +
130            # linestyle on the pcolor() function
131            kwargs['linewidth'] = gridLineWidth
132            kwargs['linestyle'] = 'solid'
133            kwargs['antialiased'] = True
134
135        #check the values aren't all masked
136       
137        # Its possible that this test becomes time consuming for large arrays,
138        # it is pssible that using sum might be faster, but this will only work
139        # if True + True = 2.
140        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
141        # valuesFound = not grid.values.mask.sum() == reduce(lambda a,b: a*b, grid.values.shape)
142       
143        # for now this is so quick it doesn't matter.
144        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
145
146        #if all the values are masked make sure the outline is drawn
147        if self.showGridLines == True:
148            st = time.time()
149            # this will draw all the grid cells even if their contents is missing
150            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
151                           color=gridColour, linewidth=gridLineWidth)
152            log.debug("drawn grid lines in = %s" % (time.time() - st,))
153       
154        if valuesFound == True:
155           
156            st = time.time()
157           
158           
159            log.debug("grid_mapUnits: boundsX.shape = %s, boundsY.shape = %s, values.shape = %s" 
160                      % (grid_mapUnits.boundsX.shape, grid_mapUnits.boundsY.shape, grid_mapUnits.values.shape))
161           
162            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
163           
164            values = grid_mapUnits.values
165            boundsShape = grid_mapUnits.boundsX.shape
166           
167            #try to fix the values if they are not the expected shape
168            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
169               
170                message = "Value array shape doesn't match the bounds shape." 
171               
172                #check if the axis were in the opposite order
173                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
174                    message += " It looks like the data may have been transposed."
175                    #log.warning("Transposing data to try and fit bounds to values")
176                    #values = N.transpose(values)
177               
178                raise Exception(message)
179                               
180
181
182            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
183            if self.forcePcolor == False and \
184               self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY):
185               
186                self._drawImshow(axes, 
187                                 grid_mapUnits.boundsX, 
188                                 grid_mapUnits.boundsY, 
189                                 values, 
190                                 kwargs)
191                log.debug("drawn imshow in %s" % (time.time() - st ,))
192               
193            else:
194                self._drawPcolormesh(axes, 
195                                     grid_mapUnits.boundsX, 
196                                     grid_mapUnits.boundsY, 
197                                     values, 
198                                     kwargs)
199               
200                log.debug("drawn pcolormesh in %s" % (time.time() - st ,))
201           
202           
203
204
205    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
206        """
207        Draw the vlaues Z at using the bounds X and Y and the pcolormesh method on the axes object.
208        The kwargs dictionary is passed to the pcolormesh method.
209        """
210        res = axes.pcolormesh(X, Y, Z, **kwargs)
211        return res
212   
213    def _drawImshow(self, axes, X, Y, Z, kwargs):
214        """
215        Draw the values Z using the bounds X and Y and the imshow method on the axis.
216        The kwargs dictionary is passed to the imshow method.
217        """
218        yStart = Y[0,0]
219        yEnd = Y[-1,-1]
220       
221        log.debug("yStart = %s, yEnd = %s" % (yStart, yEnd))
222       
223        if yStart < yEnd:
224            extent=(X.min(), X.max(),Y.max(), Y.min())               
225        else:
226            extent=(X.min(), X.max(),Y.min(), Y.max())
227       
228        kwargs.pop('edgecolors')
229        kwargs.pop('linestyle')
230        kwargs.pop('linewidth')
231        kwargs.pop('antialiased')
232        #kwargs['origin'] = 'lower'
233       
234        log.debug("extent = %s" % (extent,))
235       
236        im = axes.imshow(Z, extent=extent,
237             interpolation='nearest',
238             **kwargs
239        )
240                   
241        axes.set_aspect('auto')
242       
243    def _areBoundsParralel(self, X, Y ):
244        """
245        checks if all the bound are parallel, this is a test to see if using
246        imshow is possible.
247       
248        Checks the equality to 5 significant figures.
249        """
250       
251        st = time.time()
252        equal = True
253       
254        for i in range(1, X.shape[0]):
255           
256            equal = N.allclose(X[0], X[i])
257           
258            if not equal:
259                log.debug("X[0,0:5] = %s" % (X[0,0:5],))
260                log.debug("X[i,0:5] = %s" % (X[i,0:5],))
261                break
262
263        xEq = equal
264       
265        if equal:
266           
267            for i in range(1, Y.shape[1]):
268               
269                equal = N.allclose( Y[:,0], Y[:,i])
270               
271                if not equal:
272                    break
273
274        yEq = equal
275        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
276       
277        return equal
278           
279
280    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
281        """
282        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
283        made up of quadrilaterals.
284        """
285        Ny, Nx = X.shape
286   
287        points=[]
288        # create a list of points to build up the line collection
289        # assuming the grid is filled with quadrelaterals this should
290        # draw it
291        for y in range(Ny):
292            for x in range(Nx):
293                x1 = X[y,x]
294                y1 = Y[y,x]
295   
296                # if this isn't the last intersection in the x direction, draw a line
297                # between this intersection and the next one
298                if x+1 < Nx:
299                    x2 = X[y,x+1]
300                    y2 = Y[y,x+1]
301                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
302
303                # if this isn't the last intersection in the y direction, draw a line
304                # between this intersection and the next one   
305                if y+1 < Ny:
306                    x3 = X[y+1,x]
307                    y3 = Y[y+1,x]
308                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
309   
310
311        line_segments = matplotlib.collections.LineCollection(points,
312                                       linewidths    = (linewidth,),
313                                       colors    = (color),
314                                       linestyle = 'solid')
315       
316        axes.add_collection(line_segments)
317
318
319    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
320        """
321        Draw the values of each grid box onto the axes.
322        """
323       
324        # Parse drawValuesFile if specified:
325        if self.drawValuesFile != False:
326            drawValuesList = self._parseDrawValuesFile()
327        else:
328            drawValuesList = False
329
330        ni, nj = N.shape(grid_mapUnits.midpointsX)
331       
332        if xLimits != None:
333            xRangeMaximum = max(xLimits)
334            xRangeMinumum = min(xLimits)
335           
336        if yLimits != None:
337            yRangeMaximum = max(yLimits)
338            yRangeMinumum = min(yLimits)
339
340        for j in range(nj):
341            for i in range(ni):
342
343                # Check if we need to draw this value
344                if drawValuesList != False:
345                    if [i, j] not in drawValuesList: continue
346
347                x = grid_mapUnits.midpointsX[i][j]
348                y = grid_mapUnits.midpointsY[i][j]
349
350                v = grid_mapUnits.values[i,j]
351
352                #skip masked array values:
353                if v.__class__ == MA.MaskedArray:
354                    continue
355
356                # Avoid missing values with float conversion test
357                try:
358                    v = float(v)
359                except:
360                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
361                    continue
362
363                #get the bounds of the box (in map units), around the current position
364                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
365
366                xTop = max(xBounds)
367                yTop = max(yBounds)
368                xBottom = min(xBounds)
369                yBottom = min(yBounds)
370               
371                if xLimits != None:
372                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
373                        continue
374
375                if yLimits != None:
376                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
377                        continue
378
379                axes.text(float(x), float(y), self.valueFormat % v,
380                          ha='center',
381                          va='center',
382                          family='sans-serif',
383                          size=8)
384
385    def _getBounds(self, i, j, xMesh, yMesh):
386        """
387        Gets the points of each of the four corners for a given position in the mesh
388
389        @param i: the x index of the vlaue in the mesh
390        @type i: int
391        @param j: the x index of the vlaue in the mesh
392        @type j: int
393        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
394        @type xMesh: an instance of the numpy.ndarray class
395        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
396        @type yMesh: an instance of the numpy.ndarray class
397        """
398
399        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
400
401
402        f = lambda array: lambda c: array[c]
403        xBounds = map(f(xMesh), coords)
404        yBounds = map(f(yMesh), coords)
405
406        return (xBounds, yBounds)
407
408    def _parseDrawValuesFile(self):
409        """
410        Attempts to parse the CSV file containing index values of y,x per line.
411        Validates file or raises error.
412        On successful validation returns a list of (y_index, x_index) pairs.
413
414        """
415
416        # Safely read file
417        try:
418            valuesFile = open(self.drawValuesFile)
419            lines = valuesFile.readlines()
420            valuesFile.close()
421        except:
422            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
423
424        # Safely parse file contents
425        linePattern = re.compile("^(\d+),(\d+)$")
426        ln = 0
427        cellList = []
428
429        for line in lines:
430            ln += 1
431            line = line.strip()
432            match = linePattern.match(line)
433            if not match:
434                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
435            indices = [int(i) for i in match.groups()]
436            cellList.append(indices)
437
438        return cellList
439   
440    def _isGridDataWithinLimits(self, grid, xLimits, yLimits):
441       
442        xRange = (grid.boundsX.min(), grid.boundsX.max())
443        yRange = (grid.boundsY.min(), grid.boundsY.max())
444       
445        log.debug("xRange = %s" % (xRange,))
446        log.debug("yRange = %s" % (yRange,))
447       
448        isInLimits = geoplot_utils.isRangeInLimits(xRange, xLimits) and \
449                     geoplot_utils.isRangeInLimits(yRange, yLimits)
450       
451        return isInLimits
452   
453if __name__ == '__main__':
454    import geoplot.log_util as log_util
455    log_util.setupGeoplotConsoleHandler(log)
456
457    import pkg_resources, os
458    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
459
460    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
461
462    from numpy import meshgrid
463
464    x = [-6.,-4.,-2., 0., 2.]
465    xMid = [-5, -3, -1, 1]
466    xlim = len(x) -1
467    y = [54.,55.,56.,58.]
468    yMid = [54.5, 55.5, 57]
469    ylim = len(y) - 1
470    X,Y = meshgrid(x,y)
471   
472    X[0,0] = -7
473    Y[0,0] = 53
474    X[ylim, xlim] = 3
475    Y[ylim, xlim] = 59
476       
477    XMid, YMid = meshgrid(xMid, yMid)
478
479    Z=[[0]*xlim for i in range(ylim)]
480    for i in range(0 , xlim*ylim):
481        vX= i%xlim
482        vY= (i - i%xlim) / xlim
483        Z[vY][vX] = i
484    #Z = N.array(Z)
485    Z = MA.masked_values(Z, 2)
486    Z = MA.masked_values(Z, 6)
487       
488    grid = Grid(X,Y,XMid, YMid, Z)
489
490    from geoplot.mpl_imports import basemap
491    from matplotlib.backends.backend_agg import FigureCanvasAgg
492    from matplotlib.figure import Figure
493    from matplotlib import cm
494
495    figsize=(600 / 80, 800 / 80)
496    fig = Figure(figsize=figsize, dpi=80)
497    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
498
499
500    bm = basemap.Basemap(llcrnrlon=xLimits[0],
501                       llcrnrlat=yLimits[0],
502                       urcrnrlon=xLimits[1],
503                       urcrnrlat=yLimits[1],
504                       resolution='h',
505                       suppress_ticks=False)
506    cmap = cm.winter
507    cmap.set_bad("w")   
508    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
509                        outline=False)
510       
511    sm = drawer.draw(axes, grid, basemap=bm)
512   
513    bm.drawcoastlines(ax = axes)
514    bm.drawrivers(ax = axes, color='b')
515    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
516    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
517   
518   
519    cb = fig.colorbar(sm, orientation='vertical')
520    print "tick position", cb.ax.yaxis.get_ticks_position()
521    cb.ax.yaxis.set_ticks_position('default')
522    print "tick position", cb.ax.yaxis.get_ticks_position()   
523
524
525    canvas = FigureCanvasAgg(fig)
526    filename = "grid_drawer_test.png"
527   
528    fullPath = os.path.join(outputsDir, filename)
529    canvas.print_figure(fullPath)
530    log.debug("Wrote " + fullPath)
531
Note: See TracBrowser for help on using the repository browser.