source: qesdi/geoplot/trunk/lib/geoplot/layer_drawer.py @ 5704

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/layer_drawer.py@5704
Revision 5704, 7.6 KB checked in by pnorton, 11 years ago (diff)

Added a module to fix the import Image / from PIL import Image problem.

Line 
1"""
2An object to draw just the layer (map + grid) from the plot,
3can draw to a file, a string or create an Image object.
4"""
5
6
7import logging
8import StringIO
9import thread
10import time
11
12from matplotlib.figure import Figure
13
14from geoplot.image_import import Image
15from geoplot.grid_builder_lat_lon import GridBuilderLatLon
16from geoplot.grid_builder_national import GridBuilderNational
17from geoplot.grid_builder_rotated import GridBuilderRotated
18
19from geoplot.grid_factory import GridFactory
20from geoplot.map_factory import MapFactory
21
22import geoplot.utils as geoplot_utils
23
24from geoplot.colour_bar import ColourBar
25
# Logger for this module.
log = logging.getLogger(__name__)

# Values accepted by LayerDrawerBase's gridType property setter
# (the accepted value is forwarded to GridFactory as its dataType).
VALID_GRID_TYPES = [
    'latlon',
    'national',
    'rotated',
]

# Values accepted by LayerDrawerBase's projection property setter
# (the accepted value is forwarded to MapFactory as its projection).
VALID_PROJECTIONS = [
    'latlon',
    'national',
]
31
class LayerDrawerBase(object):
    """
    Draws only the layer section of the plot (map + grid) and returns it as
    an image object.

    This is an abstract base: subclasses must implement ``_drawToAxes``.

    The ``showGridLines`` and ``outline`` properties delegate to
    ``self._gridDrawer``, which this base class never creates.
    NOTE(review): presumably subclasses set ``_gridDrawer``; on an instance
    without it those two properties raise AttributeError.
    """

    def __init__(self,
                 gridType='latlon',
                 transparent=False,
                 projection='latlon',
                 resolution=None,
                 cmap=None,
                 cmapRange=(None, None),
                 intervalColourbar=False,
                 intervalNames=None):
        """
        :param gridType: grid type, handed to GridFactory as ``dataType``.
        :param transparent: if True the figure frame is switched off and the
            axes background is made fully transparent.
        :param projection: projection, handed to MapFactory.
        :param resolution: map resolution, handed to MapFactory.
        :param cmap: colour map, stored on the ColourBar.
        :param cmapRange: (min, max) pair stored on the ColourBar as
            colourBarMin / colourBarMax.
        :param intervalColourbar: stored on the ColourBar.
        :param intervalNames: stored on the ColourBar.

        NOTE(review): ``gridType`` and ``projection`` are passed straight to
        the factories here, so they are NOT checked against
        VALID_GRID_TYPES / VALID_PROJECTIONS; only the property setters
        perform that validation.
        """
        # The colour bar must exist before the cmap / interval properties
        # below are assigned, because their setters write through to it.
        self._cb = ColourBar()
        self.transparent = transparent
        self.cmap = cmap
        self.cmapRange = cmapRange
        self.intervalColourbar = intervalColourbar
        self.intervalNames = intervalNames

        self._gridFactory = GridFactory(dataType=gridType)
        self._mapFactory = MapFactory(projection, drawCoast=True,
                                      drawRivers=False, resolution=resolution)

    def makeImage(self, xLimits=None, yLimits=None, width=800, height=600, dpi=100):
        """
        Creates an image of the selected area of the layer.

        :param xLimits: (min, max) x extent, or None. For the 'latlon'
            projection it is used directly as axis limits; otherwise it is
            converted to map units (presumably it is longitude - confirm).
        :param yLimits: (min, max) y extent, or None (see ``xLimits``).
        :param width: image width in pixels.
        :param height: image height in pixels.
        :param dpi: resolution used to convert pixels to figure inches.
        :returns: the image produced by ``geoplot.utils.figureToImage``.
        """
        startTime = time.time()

        fig = self._getFigure(width, height, dpi)
        axes = self._addAxes(fig)

        self._drawToAxes(axes, xLimits, yLimits)

        # Drawing (notably with basemap) can change the axes limits and
        # aspect, so restore them before rendering.
        self._resetAxes(axes, xLimits, yLimits)

        im = geoplot_utils.figureToImage(fig)

        log.debug("drawn layer in %s", time.time() - startTime)

        return im

    def _drawToAxes(self, *args, **kwargs):
        """
        Draw the layer onto the axes; must be overridden by all subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _getFigure(self, width, height, dpi):
        """
        Returns a new matplotlib Figure of ``width`` x ``height`` pixels at
        ``dpi``, ready to be drawn on.

        The figure has a white face colour; its frame is switched off when
        ``self.transparent`` is set.
        """
        # matplotlib sizes figures in inches, so convert from pixels.
        figsize = (width / float(dpi), height / float(dpi))

        fig = Figure(figsize=figsize, dpi=dpi, facecolor='w',
                     frameon=(not self.transparent))

        log.debug("fig.frameon = %s", fig.frameon)

        return fig

    def _addAxes(self, figure):
        """
        Adds an axes to ``figure`` that has no ticks and no frame and covers
        the whole figure, so anything drawn on it covers the complete figure.

        When ``self.transparent`` is set the axes background is made fully
        transparent. For a non-transparent plot the visible background comes
        from the figure's frame (see ``_getFigure``), not from the axes.
        """
        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0],
                               xticks=[], yticks=[], frameon=False)

        if self.transparent:
            axes.set_alpha(0.0)
            axes.patch.set_alpha(0.0)

        return axes

    def _buildGrid(self, cdmsVar, xLimits, yLimits):
        """
        Builds a new grid object from the data in ``cdmsVar``, restricted to
        the given limits, using the grid factory.
        """
        self._gridFactory.cdmsVar = cdmsVar
        return self._gridFactory.getGrid(xLimits, yLimits)

    def _resetAxes(self, axes, xLimits=None, yLimits=None):
        """
        Resets the axes' aspect, ticks and limits after they have been drawn
        on. This is needed because some drawing methods (notably basemap)
        change these properties.

        :param axes: the axes that was drawn on.
        :param xLimits: (min, max) x limits, or None to leave x untouched.
        :param yLimits: (min, max) y limits, or None to leave y untouched.

        For the 'latlon' projection the limits are applied directly; for any
        other projection they are first converted to map units with the
        map's ``basemap`` transform.
        """
        axes.set_aspect('auto')

        axes.set_xticks([])
        axes.set_yticks([])

        # Nothing to restore. Returning here also avoids building a map and
        # asking the projection to transform None values.
        if xLimits is None and yLimits is None:
            return

        if self.projection == 'latlon':
            xLimitsMapUnits, yLimitsMapUnits = xLimits, yLimits
        else:
            # NOTE(review): the transform receives both limits together; if
            # only one of them is None it is still passed through, as before.
            mapObj = self._getMap(xLimits, yLimits)
            xLimitsMapUnits, yLimitsMapUnits = mapObj.basemap(xLimits, yLimits)

        # reset the limits after drawing the grid
        if xLimits is not None:
            axes.set_xlim(float(xLimitsMapUnits[0]), float(xLimitsMapUnits[1]))

        if yLimits is not None:
            axes.set_ylim(float(yLimitsMapUnits[0]), float(yLimitsMapUnits[1]))

    def _getMap(self, xLimits, yLimits):
        """
        Returns a map object for the current projection covering the given
        limits, built by the map factory. Map objects can be used for
        coordinate transformation or for drawing data.
        """
        self._mapFactory.xLimits = xLimits
        self._mapFactory.yLimits = yLimits

        return self._mapFactory.buildMap()

    ### properties ###

    def __set_gridType(self, value):
        if value not in VALID_GRID_TYPES:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers still catch it.
            raise ValueError(
                "Invalid value of '%s' for LayerDrawer.gridType property, must be one of %s"
                % (value, VALID_GRID_TYPES,))

        self._gridFactory.dataType = value

    def __get_gridType(self):
        return self._gridFactory.dataType

    # Grid type stored on the grid factory; the setter validates it against
    # VALID_GRID_TYPES.
    gridType = property(__get_gridType, __set_gridType, None, None)

    def __set_showGridLines(self, value):
        self._gridDrawer.showGridLines = value

    def __get_showGridLines(self):
        return self._gridDrawer.showGridLines

    # Delegates to self._gridDrawer (not created by this base class).
    showGridLines = property(__get_showGridLines, __set_showGridLines)

    def __set_outline(self, value):
        self._gridDrawer.outline = value

    def __get_outline(self):
        return self._gridDrawer.outline

    # Delegates to self._gridDrawer (not created by this base class).
    outline = property(__get_outline, __set_outline)

    def __set_projection(self, value):
        if value not in VALID_PROJECTIONS:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers still catch it.
            raise ValueError(
                "Invalid value of '%s' for projection property, must be one of %s"
                % (value, VALID_PROJECTIONS,))

        self._mapFactory.projection = value

    def __get_projection(self):
        return self._mapFactory.projection

    # Projection stored on the map factory; the setter validates it against
    # VALID_PROJECTIONS.
    projection = property(__get_projection, __set_projection, None, None)

    def __set_cmap(self, value):
        self._cb.cmap = value

    def __get_cmap(self):
        return self._cb.cmap

    # Colour map, stored on the ColourBar.
    cmap = property(__get_cmap, __set_cmap)

    def __set_cmapRange(self, value):
        self._cb.colourBarMin = value[0]
        self._cb.colourBarMax = value[1]

    def __get_cmapRange(self):
        return (self._cb.colourBarMin, self._cb.colourBarMax)

    # (min, max) colour range, stored on the ColourBar as
    # colourBarMin / colourBarMax.
    cmapRange = property(__get_cmapRange, __set_cmapRange)

    def __set_intervalColourbar(self, value):
        self._cb.intervalColourbar = value

    def __get_intervalColourbar(self):
        return self._cb.intervalColourbar

    # Stored on the ColourBar.
    intervalColourbar = property(__get_intervalColourbar, __set_intervalColourbar)

    def __set_intervalNames(self, value):
        self._cb.intervalNames = value

    def __get_intervalNames(self):
        return self._cb.intervalNames

    # Stored on the ColourBar.
    intervalNames = property(__get_intervalNames, __set_intervalNames)
250   
Note: See TracBrowser for help on using the repository browser.