source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 5876

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@5876
Revision 5876, 17.9 KB checked in by pnorton, 10 years ago (diff)

Added a faster layer drawer for the grid, a new style of colour bar and colour schemes which now hold the normalisation and colour map instances.

Line 
1"""
2grid_drawer.py
3============
4
5A GridDrawere knows how to draw a lat-lon grid onto a axes so that it will match
6with a basemap drawing on the same axis.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib.colors
19import matplotlib.cm
20import matplotlib.collections
21import numpy as N
22import numpy.ma as MA
23import geoplot.mpl_imports
24
25#internal modules
26from geoplot.array_util import *
27from geoplot.grid import Grid
28import geoplot.utils as geoplot_utils
29
30#setup the logging
31log = logging.getLogger(__name__)
32
33class GridDrawer(object):
34    """
35    Responsible for knwoing how to draw a Grid object onto a given matplotlib
36    axes
37    """
38
39    def __init__(self, drawValues=False, drawValuesFile=False, 
40                 valueFormat="%.2f", showGridLines=False,
41                 outline=False):
42        """
43        Constructs a GridDrawer object             
44        """
45        self.drawValues = drawValues
46        self.drawValuesFile = drawValuesFile
47        self.valueFormat = valueFormat
48        self.showGridLines = showGridLines
49        self.outline = outline
50
51    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
52             cmap=None, fontSize='medium', assumeBoundsParallel=None):
53        """
54        Draws the grid given on the axis.
55       
56        @param axes: the axes the grid will be drawn on
57        @type axes: matplotlib.axes
58        @param grid: the grid to be drawn
59        @type grid: geoplot.grid
60        @keyword limits: the limits of the drawing (in lat lon)
61        @type limits: tuple
62        @keyword basemap: the basemap instance to scale the grid values to be
63            drawn on the axis.
64        """
65
66        if cmap == None:
67            cmap = matplotlib.cm.get_cmap()
68            cmap.set_bad("w")   
69       
70        if norm == None:
71            norm = matplotlib.colors.Normalize(grid.values.min(), grid.values.max())     
72
73        xLimits, yLimits = self._transformLimits(limits, basemap)
74       
75        # check that the grid being drawn has data inside the limits, if there
76        # is no data in the limits there is no point drawing anything.
77        # can only check the grid is in the limits if they are provided
78        if limits != (None, None):
79            if not self._isGridDataWithinLimits(grid, xLimits, yLimits):
80                return
81               
82        grid_mapUnits = grid.transform(basemap)
83
84        self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm,
85                            assumeBoundsParallel)
86       
87#        log.debug("self.drawValues = %s" % (self.drawValues,))
88#        log.debug("self.drawValuesFile = %s" % (self.drawValuesFile,))
89
90        if self.drawValues == True:
91            log.debug('drawing grid values')
92           
93            #draw the vlaues onto the grid
94            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)
95       
96
97    def _transformLimits(self, limits, basemap):
98        """
99        Use the basemap object specified to transform the limits into map units.
100        """
101       
102        if limits == None:
103            return (None, None)
104        elif basemap == None:
105            return (limits[0], limits[1])
106        else:
107            return basemap(limits[0], limits[1])
108
109    def _drawGridBoxes(self, axes,  grid_mapUnits, basemap, cmap, norm,
110                       assumeBoundsParallel):
111        """
112        Draws the grid boxes onto the axes.
113        """
114       
115       
116        gridColour = 'k'
117        gridLineWidth = 0.25
118
119        kwargs = {'cmap':cmap, 'norm':norm}
120       
121        if self.outline == True :
122            # this will draw an outline on all grid cells that don't contain a
123            # missing value
124            kwargs['edgecolors'] = gridColour
125        else:
126            # needs to be a string rather than None or the rc default will be used
127            kwargs['edgecolors'] = 'None'
128
129
130        if geoplot.mpl_imports.oldMatplotlib == False:
131            # in the newer version of matplotlib you can set the linewidth +
132            # linestyle on the pcolor() function
133            kwargs['linewidth'] = gridLineWidth
134            kwargs['linestyle'] = 'solid'
135            kwargs['antialiased'] = True
136
137        #check the values aren't all masked
138       
139       
140        # Its possible that this test becomes time consuming for large arrays,
141        # it is pssible that using sum might be faster, but this will only work
142        # if True + True = 2.
143        # valuesFound = not grid.values.mask.sum() == len(grid.values.mask.flatten())
144        # valuesFound = not grid.values.mask.sum() == reduce(lambda a,b: a*b, grid.values.shape)
145       
146        # for now this is so quick it doesn't matter.
147        valuesFound = not reduce(operator.and_, grid_mapUnits.values.mask.flat)
148
149
150        #if all the values are masked make sure the outline is drawn
151        if self.showGridLines == True:
152            st = time.time()
153            # this will draw all the grid cells even if their contents is missing
154            self._drawGridAsLines(axes, grid_mapUnits.boundsX, grid_mapUnits.boundsY, 
155                           color=gridColour, linewidth=gridLineWidth)
156            log.debug("drawn grid lines in = %s" % (time.time() - st,))
157       
158        if valuesFound == True:
159           
160            log.debug("grid_mapUnits: boundsX.shape = %s, boundsY.shape = %s, values.shape = %s" 
161                      % (grid_mapUnits.boundsX.shape, grid_mapUnits.boundsY.shape, grid_mapUnits.values.shape))
162           
163            assert grid_mapUnits.boundsX.shape == grid_mapUnits.boundsY.shape
164           
165            values = grid_mapUnits.values
166            boundsShape = grid_mapUnits.boundsX.shape
167           
168            #try to fix the values if they are not the expected shape
169            if values.shape[0] != boundsShape[0] -1 and values.shape[1] != boundsShape[1] -1:
170               
171                message = "Value array shape doesn't match the bounds shape." 
172               
173                #check if the axis were in the opposite order
174                if values.shape[1] == boundsShape[0] -1 and values.shape[0] == boundsShape[1] -1:
175                    message += " It looks like the data may have been transposed."
176                    #log.warning("Transposing data to try and fit bounds to values")
177                    #values = N.transpose(values)
178               
179                raise Exception(message)
180                               
181            if assumeBoundsParallel == True:
182                useImshow = True
183            elif assumeBoundsParallel == False:
184                useImshow = False
185            else:
186                useImshow = self._areBoundsParralel(grid_mapUnits.boundsX, grid_mapUnits.boundsY)
187           
188            #if the bounds are parralel then use imshow, otherwise stick with pcolormesh
189            if useImshow:
190               
191                self._drawImshow(axes, 
192                                 grid_mapUnits.boundsX, 
193                                 grid_mapUnits.boundsY, 
194                                 values, 
195                                 kwargs)
196               
197            else:
198                self._drawPcolormesh(axes, 
199                                     grid_mapUnits.boundsX, 
200                                     grid_mapUnits.boundsY, 
201                                     values, 
202                                     kwargs)
203               
204           
205           
206
207
208    def _drawPcolormesh(self, axes, X, Y, Z, kwargs):
209        """
210        Draw the vlaues Z at using the bounds X and Y and the pcolormesh method on the axes object.
211        The kwargs dictionary is passed to the pcolormesh method.
212        """
213        res = axes.pcolormesh(X, Y, Z, **kwargs)
214        return res
215   
216    def _drawImshow(self, axes, X, Y, Z, kwargs):
217        """
218        Draw the values Z using the bounds X and Y and the imshow method on the axis.
219        The kwargs dictionary is passed to the imshow method.
220        """
221        st1 = time.time()
222       
223        yStart = Y[0,0]
224        yEnd = Y[-1,-1]
225       
226        #log.debug("yStart = %s, yEnd = %s" % (yStart, yEnd))
227       
228        if yStart < yEnd:
229            extent=(X.min(), X.max(),Y.max(), Y.min())               
230        else:
231            extent=(X.min(), X.max(),Y.min(), Y.max())
232       
233        kwargs.pop('edgecolors')
234        kwargs.pop('linestyle')
235        kwargs.pop('linewidth')
236        kwargs.pop('antialiased')
237        #kwargs['origin'] = 'lower'
238       
239        log.debug("got extent = %s yStart = %s yEnd = %s" % (extent, yStart, yEnd,))
240       
241        im = axes.imshow(Z, extent=extent,
242             interpolation='nearest',
243             **kwargs
244        )
245                   
246        axes.set_aspect('auto')
247       
248    def _areBoundsParralel(self, X, Y ):
249        """
250        checks if all the bound are parallel, this is a test to see if using
251        imshow is possible.
252       
253        Checks the equality to 5 significant figures.
254        """
255       
256        st = time.time()
257        equal = True
258       
259        for i in range(1, X.shape[0]):
260           
261            equal = N.allclose(X[0], X[i])
262           
263            if not equal:
264                log.debug("X[0,0:5] = %s" % (X[0,0:5],))
265                log.debug("X[i,0:5] = %s" % (X[i,0:5],))
266                break
267
268        xEq = equal
269       
270        if equal:
271           
272            for i in range(1, Y.shape[1]):
273               
274                equal = N.allclose( Y[:,0], Y[:,i])
275               
276                if not equal:
277                    break
278
279        yEq = equal
280       
281        log.debug("xEq = %s, yEq = %s worked out in = %s" % (xEq, yEq, time.time() - st,))
282       
283        return equal
284           
285
286    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
287        """
288        Draws a grid as a series of lines onto the given axis. This will work as long as the grid is
289        made up of quadrilaterals.
290        """
291        Ny, Nx = X.shape
292   
293        points=[]
294        # create a list of points to build up the line collection
295        # assuming the grid is filled with quadrelaterals this should
296        # draw it
297        for y in range(Ny):
298            for x in range(Nx):
299                x1 = X[y,x]
300                y1 = Y[y,x]
301   
302                # if this isn't the last intersection in the x direction, draw a line
303                # between this intersection and the next one
304                if x+1 < Nx:
305                    x2 = X[y,x+1]
306                    y2 = Y[y,x+1]
307                    points.append( ((x1, y1), (x2,y2), (x2,y2)) )
308
309                # if this isn't the last intersection in the y direction, draw a line
310                # between this intersection and the next one   
311                if y+1 < Ny:
312                    x3 = X[y+1,x]
313                    y3 = Y[y+1,x]
314                    points.append( ((x1, y1), (x3,y3), (x3,y3)) )
315   
316
317        line_segments = matplotlib.collections.LineCollection(points,
318                                       linewidths    = (linewidth,),
319                                       colors    = (color),
320                                       linestyle = 'solid')
321       
322        axes.add_collection(line_segments)
323
324
325    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
326        """
327        Draw the values of each grid box onto the axes.
328        """
329       
330        # Parse drawValuesFile if specified:
331        if self.drawValuesFile != False:
332            drawValuesList = self._parseDrawValuesFile()
333        else:
334            drawValuesList = False
335
336        ni, nj = N.shape(grid_mapUnits.midpointsX)
337       
338        if xLimits != None:
339            xRangeMaximum = max(xLimits)
340            xRangeMinumum = min(xLimits)
341           
342        if yLimits != None:
343            yRangeMaximum = max(yLimits)
344            yRangeMinumum = min(yLimits)
345
346        for j in range(nj):
347            for i in range(ni):
348
349                # Check if we need to draw this value
350                if drawValuesList != False:
351                    if [i, j] not in drawValuesList: continue
352
353                x = grid_mapUnits.midpointsX[i][j]
354                y = grid_mapUnits.midpointsY[i][j]
355
356                v = grid_mapUnits.values[i,j]
357
358                #skip masked array values:
359                if v.__class__ == MA.MaskedArray:
360                    continue
361
362                # Avoid missing values with float conversion test
363                try:
364                    v = float(v)
365                except:
366                    log.warning("value %s (class=%s) is not a float" , v, v.__class__)
367                    continue
368
369                #get the bounds of the box (in map units), around the current position
370                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX, grid_mapUnits.boundsY)
371
372                xTop = max(xBounds)
373                yTop = max(yBounds)
374                xBottom = min(xBounds)
375                yBottom = min(yBounds)
376               
377                if xLimits != None:
378                    if xTop > xRangeMaximum or xBottom < xRangeMinumum:
379                        continue
380
381                if yLimits != None:
382                    if yTop > yRangeMaximum or yBottom < yRangeMinumum:
383                        continue
384
385                axes.text(float(x), float(y), self.valueFormat % v,
386                          ha='center',
387                          va='center',
388                          family='sans-serif',
389                          size=8)
390
391    def _getBounds(self, i, j, xMesh, yMesh):
392        """
393        Gets the points of each of the four corners for a given position in the mesh
394
395        @param i: the x index of the vlaue in the mesh
396        @type i: int
397        @param j: the x index of the vlaue in the mesh
398        @type j: int
399        @param xMesh: the array of the cdmsVar value x bounds that makes up the grid
400        @type xMesh: an instance of the numpy.ndarray class
401        @param yMesh: the array of the cdmsVar value y bounds that makes up the grid
402        @type yMesh: an instance of the numpy.ndarray class
403        """
404
405        coords = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))
406
407
408        f = lambda array: lambda c: array[c]
409        xBounds = map(f(xMesh), coords)
410        yBounds = map(f(yMesh), coords)
411
412        return (xBounds, yBounds)
413
414    def _parseDrawValuesFile(self):
415        """
416        Attempts to parse the CSV file containing index values of y,x per line.
417        Validates file or raises error.
418        On successful validation returns a list of (y_index, x_index) pairs.
419
420        """
421
422        # Safely read file
423        try:
424            valuesFile = open(self.drawValuesFile)
425            lines = valuesFile.readlines()
426            valuesFile.close()
427        except:
428            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))
429
430        # Safely parse file contents
431        linePattern = re.compile("^(\d+),(\d+)$")
432        ln = 0
433        cellList = []
434
435        for line in lines:
436            ln += 1
437            line = line.strip()
438            match = linePattern.match(line)
439            if not match:
440                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(ln) + ": " + str(self.drawValuesFile))
441            indices = [int(i) for i in match.groups()]
442            cellList.append(indices)
443
444        return cellList
445   
446    def _isGridDataWithinLimits(self, grid, xLimits, yLimits):
447       
448        xRange = (grid.boundsX.min(), grid.boundsX.max())
449        yRange = (grid.boundsY.min(), grid.boundsY.max())
450       
451        #log.debug("xRange = %s" % (xRange,))
452        #log.debug("yRange = %s" % (yRange,))
453       
454        isInLimits = geoplot_utils.isRangeInLimits(xRange, xLimits) and \
455                     geoplot_utils.isRangeInLimits(yRange, yLimits)
456       
457        return isInLimits
458   
459if __name__ == '__main__':
460    import geoplot.log_util as log_util
461    log_util.setupGeoplotConsoleHandler(log)
462
463    import pkg_resources, os
464    outputsDir = os.path.abspath(pkg_resources.resource_filename('geoplot', '../../outputs'))   
465
466    xLimits=(-13.0, 6.2) ; yLimits=(47.0, 61.0)
467
468    from numpy import meshgrid
469
470    x = [-6.,-4.,-2., 0., 2.]
471    xMid = [-5, -3, -1, 1]
472    xlim = len(x) -1
473    y = [54.,55.,56.,58.]
474    yMid = [54.5, 55.5, 57]
475    ylim = len(y) - 1
476    X,Y = meshgrid(x,y)
477   
478    X[0,0] = -7
479    Y[0,0] = 53
480    X[ylim, xlim] = 3
481    Y[ylim, xlim] = 59
482       
483    XMid, YMid = meshgrid(xMid, yMid)
484
485    Z=[[0]*xlim for i in range(ylim)]
486    for i in range(0 , xlim*ylim):
487        vX= i%xlim
488        vY= (i - i%xlim) / xlim
489        Z[vY][vX] = i
490    #Z = N.array(Z)
491    Z = MA.masked_values(Z, 2)
492    Z = MA.masked_values(Z, 6)
493       
494    grid = Grid(X,Y,XMid, YMid, Z)
495
496    from geoplot.mpl_imports import basemap
497    from matplotlib.backends.backend_agg import FigureCanvasAgg
498    from matplotlib.figure import Figure
499    from matplotlib import cm
500
501    figsize=(600 / 80, 800 / 80)
502    fig = Figure(figsize=figsize, dpi=80)
503    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
504
505
506    bm = basemap.Basemap(llcrnrlon=xLimits[0],
507                       llcrnrlat=yLimits[0],
508                       urcrnrlon=xLimits[1],
509                       urcrnrlat=yLimits[1],
510                       resolution='h',
511                       suppress_ticks=False)
512    cmap = cm.winter
513    cmap.set_bad("w")   
514    drawer = GridDrawer(cmap, drawValues=True, valueFormat="%0.2f", showGridLines=False,
515                        outline=False)
516       
517    sm = drawer.draw(axes, grid, basemap=bm)
518   
519    bm.drawcoastlines(ax = axes)
520    bm.drawrivers(ax = axes, color='b')
521    bm.drawmeridians([-10,-5,0,5,10], ax=axes, dashes=(None,None))
522    bm.drawparallels([50,55,60], ax=axes, dashes=(None,None))
523   
524   
525    cb = fig.colorbar(sm, orientation='vertical')
526    print "tick position", cb.ax.yaxis.get_ticks_position()
527    cb.ax.yaxis.set_ticks_position('default')
528    print "tick position", cb.ax.yaxis.get_ticks_position()   
529
530
531    canvas = FigureCanvasAgg(fig)
532    filename = "grid_drawer_test.png"
533   
534    fullPath = os.path.join(outputsDir, filename)
535    canvas.print_figure(fullPath)
536    log.debug("Wrote " + fullPath)
537
Note: See TracBrowser for help on using the repository browser.