source: TI05-delivery/ows_framework/branches/ows_framework-ddp/ows_server/ows_server/lib/ddp_render.py @ 2747

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI05-delivery/ows_framework/branches/ows_framework-ddp/ows_server/ows_server/lib/ddp_render.py@2747
Revision 2747, 7.5 KB checked in by spascoe, 12 years ago (diff)

Added Mercator stuff for Ag's demo.

Line 
1"""
2Experiment to render a rotated grid using matplotlib.
3
4"""
5
6import pylab, cdms
7from matplotlib.toolkits.basemap import Basemap
8from cdms import selectors
9import Numeric as N
10import Image, ImageDraw
11
12import logging
13
def grid_corners_to_mesh(grid_corner_lat, grid_corner_lon,
                         bl_index=0, basemap=None):
    """
    Converts a scrip-style grid description into a matplotlib
    compatible mesh.

    We assume that grid boxes are contiguous.  I.e. the right-hand edge
    of grid box (ilat, ilon) is the same as the left-hand edge of grid
    box (ilat, ilon+1), and its top edge is the same as the bottom edge
    of grid box (ilat+1, ilon).

    @param grid_corner_lat: An array of latitude coordinates of the
        corners of each grid box.  The shape must be (nlat, nlon, 4).
    @param grid_corner_lon: The longitude equivalent of grid_corner_lat.
        It must have the same shape.
    @param bl_index: The index in grid_corner_* representing the
        minimum lat/lon (bottom-left) corner of the rotated grid.  The
        other three corners must follow it in anticlockwise order
        (bottom-right, top-right, top-left), wrapping around modulo 4.
    @param basemap: If not None this is a Basemap instance used to
        transform the mesh to projection coordinates.

    @return: (mesh_x, mesh_y), each of shape (nlat+1, nlon+1).  mesh_x
        holds longitudes and mesh_y latitudes (or their projected
        equivalents when basemap is given).
    @raise ValueError: if the corner arrays do not have the shape
        (nlat, nlon, 4) or do not match each other.
    """
    # Get the index of the other corners.  Going anticlockwise from the
    # bottom-left corner visits bottom-right, top-right, then top-left.
    br_index = (bl_index + 1) % 4
    tr_index = (bl_index + 2) % 4
    tl_index = (bl_index + 3) % 4

    g_shape = tuple(grid_corner_lat.shape)
    if len(g_shape) != 3 or g_shape[2] != 4:
        raise ValueError('grid corner arrays must have shape '
                         '(nlat, nlon, 4); got %s' % (g_shape,))
    if tuple(grid_corner_lon.shape) != g_shape:
        raise ValueError('grid_corner_lon shape %s does not match '
                         'grid_corner_lat shape %s'
                         % (tuple(grid_corner_lon.shape), g_shape))
    nlat, nlon = g_shape[0], g_shape[1]

    # Initialise the mesh arrays.  'd' is the double-precision typecode
    # understood by both Numeric and numpy.
    mesh_shape = (nlat + 1, nlon + 1)
    mesh_x = N.zeros(mesh_shape, 'd')
    mesh_y = N.zeros(mesh_shape, 'd')

    # Every mesh point except the top row and the right-hand column is
    # the bottom-left corner of exactly one grid box.
    mesh_x[:-1, :-1] = grid_corner_lon[:, :, bl_index]
    mesh_y[:-1, :-1] = grid_corner_lat[:, :, bl_index]

    # Top row (excluding the top-right point): the top-left corners of
    # the last row of boxes.  This agrees with the interior, where
    # mesh[i+1, j] is the bottom-left of box (i+1, j), i.e. the
    # top-left of box (i, j).
    mesh_x[-1, :-1] = grid_corner_lon[-1, :, tl_index]
    mesh_y[-1, :-1] = grid_corner_lat[-1, :, tl_index]

    # Right-hand column (excluding the top-right point): the
    # bottom-right corners of the last column of boxes.
    mesh_x[:-1, -1] = grid_corner_lon[:, -1, br_index]
    mesh_y[:-1, -1] = grid_corner_lat[:, -1, br_index]

    # Then the top-right corner of the last box.
    mesh_x[-1, -1] = grid_corner_lon[-1, -1, tr_index]
    mesh_y[-1, -1] = grid_corner_lat[-1, -1, tr_index]

    if basemap is not None:
        return basemap(mesh_x, mesh_y)
    return (mesh_x, mesh_y)
71
class RotatedGridRenderer(object):
    """
    Handle plotting of rotated grids using matplotlib.

    We plot rotated grids by converting the grid boxes into a mesh and
    using pylab.pcolormesh().  An alternative approach would be to
    plot each box using matplotlib.collections.PolyCollection(); it was
    decided not to try and figure out how matplotlib works at this
    level.

    """
    def __init__(self, shape, grid_corner_lat, grid_corner_lon, basemap=None,
                 bl_index=3):
        """
        Build the plotting mesh for a grid of the given shape.

        @param shape: The shape of the required grid, e.g. (nlat, nlon).
        @param grid_corner_lat: An array of latitude coordinates of the
            corners of each grid box, holding 4 corners per box for
            every point of shape (e.g. shape (nlat*nlon, 4)).  It is
            reshaped to shape + (4,).
        @param grid_corner_lon: The longitude equivalent of
            grid_corner_lat.
        @param basemap: If not None, a Basemap instance used to transform
            the mesh to projection coordinates.
        @param bl_index: Index of the bottom-left corner within each box's
            4 corners (see grid_corners_to_mesh).  Defaults to 3, the
            value this class has always used.
        """
        g_shape = tuple(shape) + (4,)

        g_lat = N.reshape(grid_corner_lat, g_shape)
        g_lon = N.reshape(grid_corner_lon, g_shape)

        self.mesh_x, self.mesh_y = grid_corners_to_mesh(g_lat, g_lon, bl_index,
                                                        basemap=basemap)

    def plotSlice(self, var, minValue, maxValue):
        """
        Plot a single lat/lon slice onto the current axes.

        @param var: 2-D field matching the mesh built in __init__.
        @param minValue: Value mapped to the bottom of the colour scale.
        @param maxValue: Value mapped to the top of the colour scale.
        @return: Whatever pylab.pcolormesh() returns.
        """
        return pylab.pcolormesh(self.mesh_x, self.mesh_y, var,
                                vmin=minValue, vmax=maxValue)

    def iterPlotTimes(self, var, minValue, maxValue):
        """
        Generator that plots each time value in var.

        Use this in conjunction with a loop that draws the figure to
        whatever device is desired: each slice is plotted into the
        current figure *before* its index is yielded.

        @yield: time index plotted
        @raise ValueError: if var has no time axis or time is not its
            first dimension.
        @note: Assumes the shape of var is (time, lat, lon)
        """
        time_axis = var.getTime()
        # getTime() returns None when there is no time axis; report that
        # as the same ValueError instead of an AttributeError on .id.
        if time_axis is None or time_axis.id != var.getAxisIds()[0]:
            raise ValueError('Time must be the first dimension')

        for i, svar in enumerate(var):
            self.plotSlice(svar, minValue, maxValue)
            yield i

    def drawValues(self, var, size):
        """
        Write each value of var as text at its (longitude, latitude)
        position on the current axes, formatted to one decimal place.

        @param var: cdms variable providing getLatitude()/getLongitude().
        @param size: Font size passed to pylab.text().

        @note: izip stops at the shortest input, so this presumably
            expects lat/lon to flatten to the same length and order as
            var (e.g. 2-D curvilinear coordinates) -- TODO confirm.
        @note: NOTE(review): positions are raw lon/lat and are not passed
            through a Basemap, so they only line up with a mesh built
            without one -- confirm before use with projections.
        """
        lat = var.getLatitude()
        lon = var.getLongitude()

        from itertools import izip

        for (x, y, v) in izip(lon.flat, lat.flat, var.flat):
            pylab.text(float(x), float(y), '%.1f' % v,
                       size=size,
                       family='sans-serif',
                       ha='center',
                       va='center'
                       )
def _init_plot(xlim, ylim, fig_width=8.0):
    """
    Create a new pylab figure whose aspect ratio matches the map extent.

    @param xlim: (xmin, xmax) of the map in projection coordinates.
    @param ylim: (ymin, ymax) of the map in projection coordinates.
    @param fig_width: Figure width in inches; the height is derived so
        the figure has the same aspect ratio as the map.
    @return: The new figure's current axes, with annotations and frame
        turned off and stretched to fill the whole figure.
    """
    fig = pylab.figure()
    # Make sure the aspect ratio of the figure is correct
    fig_height = (ylim[1] - ylim[0]) / (xlim[1] - xlim[0]) * fig_width
    fig.set_size_inches(fig_width, fig_height)

    ax = pylab.gca()
    # Turn off annotations so only the map itself is drawn.
    ax.set_axis_off()
    ax.set_frame_on(False)
    ax.set_position((0, 0, 1, 1))
    return ax


def _plot_variable(renderer, var, variable, bm, xlim, ylim,
                   output_prefix, dpi, logger):
    """
    Save one PNG per time step of var, named
    '<output_prefix>_<variable>_<time index>.png'.
    """
    # Get min/max of field.  This might be a configuration option in the future.
    minValue = min(var.flat)
    maxValue = max(var.flat)

    # iterPlotTimes() draws each slice into the *current* figure before
    # yielding, so a fresh figure must exist before the loop starts and
    # again after every save.
    ax = _init_plot(xlim, ylim)
    try:
        for i in renderer.iterPlotTimes(var, minValue, maxValue):
            bm.drawcoastlines()
            logger.info('Plotted index %d of %s', i, variable)
            figname = '%s_%s_%s.png' % (output_prefix, variable, i)
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            pylab.savefig(figname, dpi=dpi)
            logger.info('Saved figure to %s', figname)
            # Release the finished frame; pylab otherwise keeps every
            # figure open for the lifetime of the process.
            pylab.close()
            ax = _init_plot(xlim, ylim)
    finally:
        # Close the figure prepared for a frame that never came (or the
        # partially drawn one if plotting failed).
        pylab.close()


def _render_variables(f, g, bm, xlim, ylim, output_prefix, dpi, logger):
    """
    Plot every gridded, single-level variable in the open file f, using
    the grid corner variables from the open grid file g.
    """
    renderer = None
    gridshape = None

    for variable in f.listvariables():
        fvar = f[variable]

        # Only do variables with a grid (i.e. skip the axes themselves)
        if fvar.getGrid() is None:
            continue

        # Only do variables on a single level for now
        level = fvar.getLevel()
        if level is not None and level.shape != (1,):
            logger.warning('Skipping variable %s: more than one level', variable)
            continue

        logger.info('Processing variable %s', variable)
        # Read the data, squeezing out the level dimension.
        var = fvar(squeeze=1)

        # Rebuild the renderer only when the grid shape changes.
        if renderer is None or gridshape != var.getGrid().shape:
            gridshape = var.getGrid().shape
            logger.info('Initialising gridbox coordinates for grid %s', gridshape)
            renderer = RotatedGridRenderer(gridshape,
                                           g['grid_corner_lat'],
                                           g['grid_corner_lon'],
                                           basemap=bm)

        _plot_variable(renderer, var, variable, bm, xlim, ylim,
                       output_prefix, dpi, logger)


def main(filename, grid_filename, output_prefix,
         bbox=(-16.5, 47.0, 4.5, 61.5), dpi=100):
    """
    Render each gridded, single-level variable in filename to PNG files
    on a Mercator map, one file per time step.

    @param filename: cdms-readable file containing the variables to plot.
    @param grid_filename: cdms-readable file providing the
        'grid_corner_lat' and 'grid_corner_lon' variables.
    @param output_prefix: Output files are named
        '<output_prefix>_<variable>_<time index>.png'.
    @param bbox: Map extent as (lon_min, lat_min, lon_max, lat_max).
    @param dpi: Resolution passed to pylab.savefig().
    """
    # Initialise logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('mpl_render')

    # Create the map projection
    logger.info('Initialising Basemap')
    bm = Basemap(projection='merc',
                 resolution='i',
                 llcrnrlon=bbox[0],
                 llcrnrlat=bbox[1],
                 urcrnrlon=bbox[2],
                 urcrnrlat=bbox[3],
                 lat_ts=0,
                 )

    # Map extent in projection coordinates; used for the axis limits and
    # the figure aspect ratio.
    xlim, ylim = bm(bbox[::2], bbox[1::2])

    # Make sure both data files are closed even if rendering fails.
    f = cdms.open(filename)
    try:
        g = cdms.open(grid_filename)
        try:
            _render_variables(f, g, bm, xlim, ylim, output_prefix, dpi, logger)
        finally:
            g.close()
    finally:
        f.close()
if __name__ == '__main__':
    import sys

    # Expect exactly: <input_file> <grid_file> <output_prefix>.  Report a
    # usage line instead of a bare tuple-unpacking ValueError.
    if len(sys.argv) != 4:
        sys.stderr.write('usage: %s <input_file> <grid_file> <output_prefix>\n'
                         % sys.argv[0])
        sys.exit(2)

    (input_file, grid_file, out_prefix) = sys.argv[1:]

    main(input_file, grid_file, out_prefix)
Note: See TracBrowser for help on using the repository browser.