source: qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py @ 5610

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/grid_drawer.py@5610
Revision 5610, 14.7 KB checked in by pnorton, 12 years ago (diff)

Modified the existing plot_base into a grid_plot class that uses the new map and grid factories.

I've also modified the plot_lat_lon, plot_national and plot_rotated to use the grid_plot as a base class.

Line 
1"""
grid_drawer.py
==============

A GridDrawer knows how to draw a lat-lon grid onto an axes so that it will match
with a basemap drawing on the same axes.
7
8"""
9#python modules
10import logging
11import operator
12import re
13import time
14from math import *
15
16#third party modules
17from geoplot.mpl_imports import basemap
18import matplotlib
19import matplotlib.colors
20import matplotlib.cm
21import matplotlib.collections
22import numpy as N
23import numpy.ma as MA
24import geoplot.mpl_imports
25
26#internal modules
27from geoplot.array_util import *
28from geoplot.grid import Grid
29from geoplot.grid_builder_base import GridBuilderBase
30
31#setup the logging
32log = logging.getLogger(__name__)
33
class GridDrawer(object):
    """
    Responsible for knowing how to draw a Grid object onto a given matplotlib
    axes.
    """

    def __init__(self, drawValues=False, drawValuesFile=False,
                 valueFormat="%.2f", showGridLines=False,
                 outline=False):
        """
        Constructs a GridDrawer object.

        @keyword drawValues: if true, the value of each grid box is written
            as text at the box's midpoint
        @keyword drawValuesFile: path to a CSV file of "i,j" index pairs (one
            per line); when set, only the listed boxes have their values
            written. A false value means every box is written.
        @keyword valueFormat: %-style format string used for the value text
        @keyword showGridLines: if true, every grid cell is outlined, even
            cells whose value is missing
        @keyword outline: if true, cells that contain a value are drawn with
            a black edge
        """
        self.drawValues = drawValues
        self.drawValuesFile = drawValuesFile
        self.valueFormat = valueFormat
        self.showGridLines = showGridLines
        self.outline = outline

    def draw(self, axes, grid, limits=None, basemap=None, norm=None,
             cmap=None, fontSize='medium'):
        """
        Draws the grid given on the axis.

        @param axes: the axes the grid will be drawn on
        @type axes: matplotlib.axes
        @param grid: the grid to be drawn
        @type grid: geoplot.grid
        @keyword limits: the limits of the drawing (in lat lon), as a pair
            (xLimits, yLimits); only used to decide which box values are
            written when drawValues is set
        @type limits: tuple
        @keyword basemap: the basemap instance to scale the grid values to be
            drawn on the axis.
        @keyword norm: colour normalisation; defaults to the min/max of
            grid.values
        @keyword cmap: colour map; defaults to a copy of matplotlib's default
            colour map with masked values drawn in white
        @keyword fontSize: currently unused (value labels are drawn at size 8)
        @return: the object returned by axes.pcolormesh (usable as the
            mappable for a colorbar), or None if every value is masked
        """
        if cmap is None:
            # Copy before set_bad() so matplotlib's shared default colour map
            # is not modified for every other plot in the process.
            import copy
            cmap = copy.deepcopy(matplotlib.cm.get_cmap())
            cmap.set_bad("w")

        if norm is None:
            norm = matplotlib.colors.Normalize(grid.values.min(),
                                               grid.values.max())

        xLimits, yLimits = self._transformLimits(limits, basemap)

        grid_mapUnits = grid.transform(basemap)

        mesh = self._drawGridBoxes(axes, grid_mapUnits, basemap, cmap, norm)

        if self.drawValues:
            log.debug('drawing grid values')
            self._drawMeshValues(axes, grid_mapUnits, xLimits, yLimits)

        return mesh

    def _transformLimits(self, limits, basemap):
        """
        Converts the (lon, lat) drawing limits into map units.

        @param limits: None, or a pair (xLimits, yLimits)
        @param basemap: None, or a callable basemap projection
        @return: (None, None) when no limits are given, the limits unchanged
            when there is no basemap, otherwise basemap(xLimits, yLimits)
        """
        if limits is None:
            return (None, None)
        if basemap is None:
            return (limits[0], limits[1])
        return basemap(limits[0], limits[1])

    def _drawGridBoxes(self, axes, grid_mapUnits, basemap, cmap, norm):
        """
        Draws the grid boxes onto the axes.

        @return: the object returned by axes.pcolormesh, or None when every
            value is masked and no boxes are drawn
        """
        gridColour = 'k'
        gridLineWidth = 0.25

        kwargs = {'cmap': cmap, 'norm': norm}

        if self.outline:
            # outlines every grid cell that doesn't contain a missing value
            kwargs['edgecolors'] = gridColour
        else:
            # must be the string 'None' rather than None, otherwise the rc
            # default edge colour is used
            kwargs['edgecolors'] = 'None'

        if not geoplot.mpl_imports.oldMatplotlib:
            # newer matplotlib versions accept line width/style on pcolormesh
            kwargs['linewidth'] = gridLineWidth
            kwargs['linestyle'] = 'solid'
            kwargs['antialiased'] = False

        if self.showGridLines:
            # draws every grid cell, even those whose value is missing
            self._drawGridAsLines(axes, grid_mapUnits.boundsX,
                                  grid_mapUnits.boundsY,
                                  color=gridColour, linewidth=gridLineWidth)

        values = grid_mapUnits.values

        # getmaskarray copes with plain ndarrays and MA.nomask; .all() is a
        # single vectorised pass (and is False-safe on an empty array).
        if MA.getmaskarray(values).all():
            log.debug("all grid values are masked, no boxes drawn")
            return None

        st = time.time()

        boundsX = grid_mapUnits.boundsX
        boundsY = grid_mapUnits.boundsY

        log.debug("grid_mapUnits.boundsX.shape = %s", boundsX.shape)
        log.debug("grid_mapUnits.boundsY.shape = %s", boundsY.shape)
        log.debug("grid_mapUnits.values.shape = %s", values.shape)

        if boundsX.shape != boundsY.shape:
            raise ValueError("boundsX shape %s does not match boundsY shape %s"
                             % (boundsX.shape, boundsY.shape))

        # pcolormesh wants one value per cell, i.e. (ny-1, nx-1) values for
        # (ny, nx) bounds; if the values look transposed, flip them.
        expectedShape = (boundsX.shape[0] - 1, boundsX.shape[1] - 1)
        if values.shape != expectedShape and values.shape[::-1] == expectedShape:
            values = N.transpose(values)

        log.debug("values.shape = %s", values.shape)

        mesh = axes.pcolormesh(boundsX, boundsY, values, **kwargs)

        log.debug("drawn mesh in %s", time.time() - st)

        return mesh

    def _drawGridAsLines(self, axes, X, Y, color='0.25', linewidth=0.3):
        """
        Draws a grid as a series of line segments onto the given axes.

        Each intersection is joined to its neighbour along both array axes,
        so this works as long as the grid is made up of quadrilaterals.

        @param X: 2-D array of intersection x coordinates
        @param Y: 2-D array of intersection y coordinates, indexed like X
        """
        Ny, Nx = X.shape

        segments = []
        for row in range(Ny):
            for col in range(Nx):
                start = (X[row, col], Y[row, col])

                # segment to the next intersection along the second axis
                if col + 1 < Nx:
                    segments.append((start, (X[row, col + 1], Y[row, col + 1])))

                # segment to the next intersection along the first axis
                if row + 1 < Ny:
                    segments.append((start, (X[row + 1, col], Y[row + 1, col])))

        lineSegments = matplotlib.collections.LineCollection(
            segments,
            linewidths=(linewidth,),
            colors=color,
            linestyle='solid')

        axes.add_collection(lineSegments)

    def _drawMeshValues(self, axes, grid_mapUnits, xLimits, yLimits):
        """
        Writes the value of each grid box as text at the box's midpoint.

        Boxes are skipped when their value is masked or not convertible to
        float, when they are not listed in drawValuesFile (if set), or when
        their bounds extend outside xLimits / yLimits (if given).
        """
        if self.drawValuesFile:
            # a set gives an O(1) membership test per cell
            drawCells = set(tuple(cell) for cell in self._parseDrawValuesFile())
        else:
            drawCells = None

        values = grid_mapUnits.values
        # Indexing a masked element yields MA.masked (a MaskedConstant, not
        # exactly MaskedArray), so test the mask array directly instead.
        mask = MA.getmaskarray(values)

        ni, nj = N.shape(grid_mapUnits.midpointsX)

        if xLimits is not None:
            xRangeMaximum = max(xLimits)
            xRangeMinimum = min(xLimits)

        if yLimits is not None:
            yRangeMaximum = max(yLimits)
            yRangeMinimum = min(yLimits)

        for j in range(nj):
            for i in range(ni):

                if drawCells is not None and (i, j) not in drawCells:
                    continue

                if mask[i, j]:
                    continue

                raw = values[i, j]
                try:
                    v = float(raw)
                except (TypeError, ValueError, OverflowError):
                    log.warning("value %s (class=%s) is not a float",
                                raw, raw.__class__)
                    continue

                # bounds of the box (in map units) around the current position
                xBounds, yBounds = self._getBounds(i, j, grid_mapUnits.boundsX,
                                                   grid_mapUnits.boundsY)

                if xLimits is not None:
                    if max(xBounds) > xRangeMaximum or min(xBounds) < xRangeMinimum:
                        continue

                if yLimits is not None:
                    if max(yBounds) > yRangeMaximum or min(yBounds) < yRangeMinimum:
                        continue

                x = grid_mapUnits.midpointsX[i][j]
                y = grid_mapUnits.midpointsY[i][j]

                axes.text(float(x), float(y), self.valueFormat % v,
                          ha='center',
                          va='center',
                          family='sans-serif',
                          size=8)

    def _getBounds(self, i, j, xMesh, yMesh):
        """
        Gets the coordinates of the four corners of the grid box at (i, j).

        @param i: the first-axis (row) index of the box in the mesh
        @type i: int
        @param j: the second-axis (column) index of the box in the mesh
        @type j: int
        @param xMesh: the array of x bounds that makes up the grid
        @type xMesh: numpy.ndarray
        @param yMesh: the array of y bounds that makes up the grid
        @type yMesh: numpy.ndarray
        @return: (xBounds, yBounds), two lists of the corner values ordered
            (i, j), (i, j+1), (i+1, j), (i+1, j+1)
        """
        corners = ((i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1))

        # real lists (not map() iterators) so callers can take max() and min()
        xBounds = [xMesh[c] for c in corners]
        yBounds = [yMesh[c] for c in corners]

        return (xBounds, yBounds)

    def _parseDrawValuesFile(self):
        """
        Parses the CSV file named by drawValuesFile, which holds one "i,j"
        pair of grid box indices per line. Blank lines are ignored.

        @return: a list of [i, j] integer pairs, in file order
        @raise Exception: if the file cannot be read, or a non-blank line is
            not of the form "<int>,<int>"
        """
        try:
            valuesFile = open(self.drawValuesFile)
            try:
                lines = valuesFile.readlines()
            finally:
                # close even when reading fails
                valuesFile.close()
        except (IOError, OSError, ValueError):
            raise Exception("Unable to parse 'drawValuesFile': " + str(self.drawValuesFile))

        linePattern = re.compile(r"^(\d+),(\d+)$")
        cellList = []

        for index, line in enumerate(lines):
            line = line.strip()
            if not line:
                continue
            match = linePattern.match(line)
            if match is None:
                raise Exception("Unable to parse 'drawValuesFile' at line number " + str(index + 1) + ": " + str(self.drawValuesFile))
            cellList.append([int(group) for group in match.groups()])

        return cellList
if __name__ == '__main__':
    # Manual smoke test: draws a small, partly masked grid over the British
    # Isles and writes it to <geoplot>/../../outputs/grid_drawer_test.png.
    import copy
    import os
    import pkg_resources

    import geoplot.log_util as log_util
    log_util.setupGeoplotConsoleHandler(log)

    outputsDir = os.path.abspath(
        pkg_resources.resource_filename('geoplot', '../../outputs'))

    xLimits = (-13.0, 6.2)
    yLimits = (47.0, 61.0)

    from numpy import meshgrid
    from cdms_utils.cdms_compat import N, MA

    x = [-6., -4., -2., 0., 2.]
    xMid = [-5, -3, -1, 1]
    xlim = len(x) - 1
    y = [54., 55., 56., 58.]
    yMid = [54.5, 55.5, 57]
    ylim = len(y) - 1
    X, Y = meshgrid(x, y)

    # move the outer corners so the grid is not perfectly regular
    X[0, 0] = -7
    Y[0, 0] = 53
    X[ylim, xlim] = 3
    Y[ylim, xlim] = 59

    XMid, YMid = meshgrid(xMid, yMid)

    # cell values 0 .. xlim*ylim-1 in row-major order, then mask 2 and 6
    Z = [[row * xlim + col for col in range(xlim)] for row in range(ylim)]
    Z = MA.masked_values(Z, 2)
    Z = MA.masked_values(Z, 6)

    grid = Grid(X, Y, XMid, YMid, Z)

    from geoplot.mpl_imports import basemap
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    from matplotlib.colors import Normalize
    from matplotlib import cm

    # 600 x 800 pixels at 80 dpi (float division so Python 2 doesn't
    # truncate the width to 7 inches)
    figsize = (600.0 / 80, 800.0 / 80)
    fig = Figure(figsize=figsize, dpi=80)
    axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])

    bm = basemap.Basemap(llcrnrlon=xLimits[0],
                         llcrnrlat=yLimits[0],
                         urcrnrlon=xLimits[1],
                         urcrnrlat=yLimits[1],
                         resolution='h',
                         suppress_ticks=False)

    # copy so set_bad() doesn't alter matplotlib's shared 'winter' map
    cmap = copy.deepcopy(cm.winter)
    cmap.set_bad("w")
    norm = Normalize(Z.min(), Z.max())

    # the colour map is a draw() argument, not a constructor argument
    drawer = GridDrawer(drawValues=True, valueFormat="%0.2f",
                        showGridLines=False, outline=False)

    drawer.draw(axes, grid, basemap=bm, cmap=cmap, norm=norm)

    bm.drawcoastlines(ax=axes)
    bm.drawrivers(ax=axes, color='b')
    bm.drawmeridians([-10, -5, 0, 5, 10], ax=axes, dashes=(None, None))
    bm.drawparallels([50, 55, 60], ax=axes, dashes=(None, None))

    # build the colorbar from the same cmap/norm used for the mesh
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array(Z)
    cb = fig.colorbar(sm, ax=axes, orientation='vertical')
    print("tick position %s" % (cb.ax.yaxis.get_ticks_position(),))
    cb.ax.yaxis.set_ticks_position('default')
    print("tick position %s" % (cb.ax.yaxis.get_ticks_position(),))

    canvas = FigureCanvasAgg(fig)
    filename = "grid_drawer_test.png"

    fullPath = os.path.join(outputsDir, filename)
    canvas.print_figure(fullPath)
    log.debug("Wrote " + fullPath)
Note: See TracBrowser for help on using the repository browser.