source: TI05-delivery/ows_framework/branches/ows_framework-ddp/ows_server/ows_server/lib/ddp_render.py @ 2744

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI05-delivery/ows_framework/branches/ows_framework-ddp/ows_server/ows_server/lib/ddp_render.py@2744
Revision 2744, 7.5 KB checked in by spascoe, 12 years ago (diff)
Line 
1"""
2Experiment to render a rotated grid using matplotlib.
3
4"""
5
6import pylab, cdms
7from matplotlib.toolkits.basemap import Basemap
8from cdms import selectors
9import Numeric as N
10import Image, ImageDraw
11
12import logging
13
def grid_corners_to_mesh(grid_corner_lat, grid_corner_lon,
                         bl_index=0, basemap=None):
    """
    Converts a scrip-style grid description into a matplotlib
    compatible mesh.

    We assume that grid boxes are contiguous.  I.e. the right-hand
    edge of grid box (y, x) is the same as the left-hand edge of grid
    box (y, x+1), and the top edge of grid box (y, x) is the same as
    the bottom edge of grid box (y+1, x).

    The corners of each box must be listed anticlockwise: starting at
    the bottom-left corner (index bl_index), the following indices
    (modulo 4) are the bottom-right, top-right and top-left corners.

    @param grid_corner_lat: An array of latitude coordinates of the
            corners of each grid box.  The shape must be (nlat, nlon, 4).
    @param grid_corner_lon: The longitude equivalent of grid_corner_lat.
    @param bl_index: The index in grid_corner_* representing the
        minimum lat/lon (bottom-left) corner of the rotated grid.
    @param basemap: If not None this is a Basemap instance used to
        transform the mesh to projection coordinates.

    @return: (mesh_x, mesh_y), each of shape (nlat+1, nlon+1) (or
        whatever the basemap returns for them).  Element [i, j] is the
        bottom-left corner of grid box (i, j); the extra final row and
        column hold the top and right-hand edges of the grid.
    @raise ValueError: If the last dimension of grid_corner_lat is not 4.

    """
    # Get the index of the other corners.
    # This relies on corners being listed anticlockwise.
    br_index = (bl_index + 1) % 4
    tr_index = (bl_index + 2) % 4
    tl_index = (bl_index + 3) % 4

    (nlat, nlon, ncorners) = grid_corner_lat.shape
    # Explicit check rather than assert so it survives python -O.
    if ncorners != 4:
        raise ValueError('Expected 4 corners per grid box, got %d'
                         % ncorners)

    # Initialise the mesh arrays ('d' = double precision floats).
    mesh_shape = (nlat + 1, nlon + 1)
    mesh_x = N.zeros(mesh_shape, 'd')
    mesh_y = N.zeros(mesh_shape, 'd')

    # Mesh point [i, j] is the bottom-left corner of grid box (i, j).
    mesh_x[:-1, :-1] = grid_corner_lon[:, :, bl_index]
    mesh_y[:-1, :-1] = grid_corner_lat[:, :, bl_index]

    # Top edge (final mesh row): the top-left corners of the last row
    # of boxes, i.e. the points directly above their bottom-left corners.
    mesh_x[-1, :-1] = grid_corner_lon[-1, :, tl_index]
    mesh_y[-1, :-1] = grid_corner_lat[-1, :, tl_index]

    # Right-hand edge (final mesh column): the bottom-right corners of
    # the last column of boxes.
    mesh_x[:-1, -1] = grid_corner_lon[:, -1, br_index]
    mesh_y[:-1, -1] = grid_corner_lat[:, -1, br_index]

    # Then the top-right corner point of the whole grid.
    mesh_x[-1, -1] = grid_corner_lon[-1, -1, tr_index]
    mesh_y[-1, -1] = grid_corner_lat[-1, -1, tr_index]

    if basemap is not None:
        return basemap(mesh_x, mesh_y)
    return (mesh_x, mesh_y)
71
class RotatedGridRenderer(object):
    """
    Handle plotting of rotated grids using matplotlib.

    We plot rotated grids by converting the grid boxes into a mesh and
    using pylab.pcolormesh().  An alternative approach would be to
    plot each box using matplotlib.collections.PolyCollection(), but
    it was decided not to try and figure out how matplotlib works at
    this level.

    """
    def __init__(self, shape, grid_corner_lat, grid_corner_lon, basemap=None,
                 bl_index=3):
        """
        Build the plotting mesh from SCRIP-style grid-box corners.

        @param shape: The 2-D shape of the required grid as (rows,
            columns); presumably (nlat, nlon) -- confirm against callers.
        @param grid_corner_lat: An array of latitude coordinates of the
            corners of each grid box, e.g. of shape (rows*columns, 4).
            It is reshaped to shape + (4,), so it must contain exactly
            rows*columns*4 values.
        @param grid_corner_lon: The longitude equivalent of
            grid_corner_lat.
        @param basemap: If not None, a Basemap instance used to transform
            the mesh to projection coordinates.
        @param bl_index: Index, within each box's 4 corners, of the
            bottom-left corner.  Defaults to 3, the value this class has
            always used.

        """
        # One 4-element vector of corners per grid box.
        g_shape = tuple(shape) + (4,)

        g_lat = N.reshape(grid_corner_lat, g_shape)
        g_lon = N.reshape(grid_corner_lon, g_shape)

        self.mesh_x, self.mesh_y = grid_corners_to_mesh(g_lat, g_lon,
                                                        bl_index,
                                                        basemap=basemap)

    def plotSlice(self, var, minValue, maxValue):
        """
        Plot a single lat/lon slice on the mesh with pylab.pcolormesh().

        @param var: The 2-D values to plot, one per grid box.
        @param minValue: Passed to pcolormesh as vmin.
        @param maxValue: Passed to pcolormesh as vmax.
        @return: The result of pylab.pcolormesh().

        """
        return pylab.pcolormesh(self.mesh_x, self.mesh_y, var,
                                vmin=minValue, vmax=maxValue)

    def iterPlotTimes(self, var, minValue, maxValue):
        """
        Generator that plots each time value in var.

        Use this in conjunction with a loop that draws the figure to
        whatever device is desired: each slice is plotted before its
        index is yielded.

        @yield: time index plotted
        @raise ValueError: If var has no time axis or time is not its
            first dimension.  As this is a generator, the error is raised
            on the first iteration, not when it is called.
        @note: Assumes the shape of var is (time, lat, lon)

        """
        time_axis = var.getTime()
        # Guard against a missing time axis (getTime() returning None),
        # which would otherwise surface as an AttributeError on .id.
        if time_axis is None or time_axis.id != var.getAxisIds()[0]:
            raise ValueError('Time must be the first dimension')

        for i, svar in enumerate(var):
            self.plotSlice(svar, minValue, maxValue)
            yield i

    def drawValues(self, var, size):
        """
        Annotate the current axes with the value of each element of var,
        formatted to one decimal place and centred on its lon/lat
        position.

        @param var: Variable providing getLatitude(), getLongitude() and
            the values to print.
        @param size: Font size passed to pylab.text().
        @note: Positions are the raw longitude/latitude values; they are
            not transformed by the basemap given to __init__, so they
            presumably only line up with the mesh when no basemap was
            used -- confirm.

        """
        lat = var.getLatitude()
        lon = var.getLongitude()

        from itertools import izip

        # NOTE(review): lon, lat and var are zipped element by element,
        # which assumes they hold the same number of elements in the same
        # order (e.g. 2-D curvilinear coordinates) -- confirm.
        for (x, y, v) in izip(lon.flat, lat.flat, var.flat):
            pylab.text(float(x), float(y), '%.1f' % v,
                       size=size,
                       family='sans-serif',
                       ha='center',
                       va='center'
                       )
140       
141
def main(filename, grid_filename, output_prefix, only_variables=('temp',)):
    """
    Render variables from a cdms file as one PNG image per time step.

    Each variable that has a grid and at most one level is plotted on a
    Mercator Basemap covering a fixed bounding box, using grid-box
    corners read from a separate grid file.

    @param filename: Path of the data file, opened with cdms.open().
    @param grid_filename: Path of a file, opened with cdms.open(), that
        provides the 'grid_corner_lat' and 'grid_corner_lon' variables.
    @param output_prefix: Prefix for the output images, which are named
        '<output_prefix>_<variable>_<time index>.png'.
    @param only_variables: If not None, a collection of variable names to
        process; all other variables are skipped.  Defaults to ('temp',),
        matching the debugging shortcut previously hard-coded here.  Pass
        None to process every variable.

    """
    # Bounding box as (min lon, min lat, max lon, max lat).
    bbox = (-16.5, 47.0, 4.5, 61.5)

    # Initialise logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('mpl_render')

    # Create the map projection
    logger.info('Initialising Basemap')
    bm = Basemap(projection='merc',
                 resolution='l',
                 llcrnrlon=bbox[0],
                 llcrnrlat=bbox[1],
                 urcrnrlon=bbox[2],
                 urcrnrlat=bbox[3],
                 lat_ts=0,
                 )

    # Some configuration options
    dpi = 100
    # x and y limits of the bounding box in projection coordinates.
    xlim, ylim = bm(bbox[::2], bbox[1::2])

    f = cdms.open(filename)
    g = None
    # Make sure both datasets are closed however processing ends.
    try:
        g = cdms.open(grid_filename)
        r = None
        gridshape = None

        for variable in f.listvariables():

            if only_variables is not None and variable not in only_variables:
                continue

            # Only do variables with a grid (presumably this skips the
            # coordinate axis variables).
            if f[variable].getGrid() is None:
                continue

            # Only do variables on a single level for now
            level = f[variable].getLevel()
            if level is not None and level.shape != (1,):
                logger.warning('Skipping variable %s: more than one level' % variable)
                continue

            logger.info('Processing variable %s' % variable)
            # Squeeze out the level dimension.
            var = f(variable, squeeze=1)

            # (Re)build the gridbox mesh only when the grid shape changes.
            if r is None or gridshape != var.getGrid().shape:
                gridshape = var.getGrid().shape
                logger.info('Initialising gridbox coordinates for grid %s' % (gridshape,))
                r = RotatedGridRenderer(gridshape,
                                        g['grid_corner_lat'], g['grid_corner_lon'],
                                        basemap=bm)

            # Get min/max of field.  This might be a configuration option in the future.
            minValue = min(var.flat)
            maxValue = max(var.flat)

            # Initialise the matplotlib figure
            fig = pylab.gcf()
            # Make sure the aspect ratio of the figure is correct
            fig_width = 8.0
            fig_height = (ylim[1]-ylim[0])/(xlim[1]-xlim[0]) * fig_width
            logger.info('Setting figure size to %s, %s' % (fig_width, fig_height))
            fig.set_size_inches(fig_width, fig_height)

            # Initialise the matplotlib axes
            ax = pylab.gca()
            # Turn off annotations
            ax.set_axis_off()
            ax.set_frame_on(False)
            ax.set_position((0,0,1,1))

            # NOTE(review): the same figure/axes are reused for every time
            # step and variable and nothing is cleared, so each pcolormesh
            # and drawcoastlines call is added on top of the previous ones
            # -- confirm this is intended (masked cells could show data
            # from an earlier step).
            for i in r.iterPlotTimes(var, minValue, maxValue):
                bm.drawcoastlines()
                logger.info('Plotted index %d of %s' % (i, variable))
                figname = '%s_%s_%s.png' % (output_prefix, variable, i)
                ax.set_xlim(xlim)
                ax.set_ylim(ylim)
                pylab.savefig(figname, dpi=dpi)
                logger.info('Saved figure to %s' % figname)
    finally:
        if g is not None:
            g.close()
        f.close()
229
if __name__ == '__main__':
    # Hard-coded example inputs and output location for running this
    # module directly as a script.
    main(filename='afixaa.pei3aug.pp',
         grid_filename='ddp_grid.nc',
         output_prefix='./out/rotgrid')
Note: See TracBrowser for help on using the repository browser.