source: qesdi/geoplot/trunk/lib/geoplot/layer_drawer.py @ 5735

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/layer_drawer.py@5735
Revision 5735, 7.4 KB checked in by pnorton, 10 years ago (diff)

Added the ability to change the coastline colour in the coastline layer drawer. Also fixed a minor colour bar bug that caused failures when unicode objects were used instead of plain strings.

Line 
"""
Draws just the layer (map + grid) of a plot. LayerDrawerBase.makeImage
renders the layer onto a frameless matplotlib figure and returns it as an
image object.
"""
5
6
7import logging
8import time
9
10from matplotlib.figure import Figure
11
12import geoplot.utils as geoplot_utils
13from geoplot.colour_bar import ColourBar
14from geoplot.grid_factory import GridFactory
15from geoplot.map_factory import MapFactory
16
17
18
# Logger shared by everything in this module.
log = logging.getLogger(__name__)

# Values accepted by the LayerDrawerBase.gridType property setter.
VALID_GRID_TYPES = 'latlon national rotated'.split()

# Values accepted by the LayerDrawerBase.projection property setter.
VALID_PROJECTIONS = 'latlon national'.split()
24
class LayerDrawerBase(object):
    """
    Draws only the layer section (map + grid) of a plot and returns it as an
    image object (see makeImage).

    Subclasses must implement _drawToAxes.
    """

    def __init__(self,
                 gridType='latlon',
                 transparent=False,
                 projection='latlon',
                 resolution=None,
                 cmap=None,
                 cmapRange=(None, None),
                 intervalColourbar=False,
                 intervalNames=None):
        """
        :param gridType: data type handed to the GridFactory
        :param transparent: when True the figure is created without a frame
            and the axes background alpha is set to 0
        :param projection: projection handed to the MapFactory
        :param resolution: resolution handed to the MapFactory
        :param cmap: colour map, stored on the internal ColourBar
        :param cmapRange: (min, max) pair for the colour bar; either element
            may be None. A reversed pair is swapped by the cmapRange setter.
        :param intervalColourbar: stored on the internal ColourBar
        :param intervalNames: stored on the internal ColourBar
        """
        # The ColourBar must exist before the cmap / cmapRange / interval*
        # properties are assigned, because their setters write into it.
        self._cb = ColourBar()
        self.transparent = transparent
        self.cmap = cmap

        log.debug("cmapRange = %s", cmapRange)

        # The property setter swaps a reversed range (with a warning).
        self.cmapRange = cmapRange
        self.intervalColourbar = intervalColourbar
        self.intervalNames = intervalNames

        self._gridFactory = GridFactory(dataType=gridType)
        self._mapFactory = MapFactory(projection, drawCoast=True,
                                      drawRivers=False, resolution=resolution)

    def makeImage(self, xLimits=None, yLimits=None, width=800, height=600, dpi=200):
        """
        Draw the selected area of the layer and return it as an image.

        :param xLimits: optional (min, max) x range; re-applied to the axes
            after drawing (transformed to map units for non-latlon projections)
        :param yLimits: optional (min, max) y range, handled like xLimits
        :param width: output width in pixels (figure width is width / dpi inches)
        :param height: output height in pixels (figure height is height / dpi inches)
        :param dpi: figure resolution
        :returns: the object produced by geoplot.utils.figureToImage
            (described by this class as a PIL image)
        """
        start = time.time()

        fig = self._getFigure(width, height, dpi)
        axes = self._addAxes(fig)

        self._drawToAxes(axes, xLimits, yLimits)

        # Drawing (e.g. via basemap) may alter aspect/limits, so restore them.
        self._resetAxes(axes, xLimits, yLimits)

        im = geoplot_utils.figureToImage(fig)

        log.debug("drawn layer in %s", time.time() - start)

        return im

    def _drawToAxes(self, *args, **kwargs):
        """
        Draw the layer onto the axes; must be overridden by all subclasses.

        makeImage calls this as _drawToAxes(axes, xLimits, yLimits).

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def _getFigure(self, width, height, dpi):
        """
        Return a new Figure sized to width x height pixels at the given dpi.

        The figure only has a frame when self.transparent is False.
        """
        figsize = (width / float(dpi), height / float(dpi))

        fig = Figure(figsize=figsize, dpi=dpi, facecolor='w',
                     frameon=(not self.transparent))

        log.debug("fig.frameon = %s", fig.frameon)

        return fig

    def _addAxes(self, figure):
        """
        Add an axes to the given figure and return it.

        The axes has no border, no ticks and spans the whole figure
        ([0, 0, 1, 1]), so anything drawn on it completely covers the figure.

        When self.transparent is True the axes and its background patch get
        alpha 0. When it is False the opaque background comes from the
        Figure's frame (frameon), not from the axes.
        """
        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0],
                               xticks=[], yticks=[], frameon=False)

        if self.transparent:
            axes.set_alpha(0.0)
            axes.patch.set_alpha(0.0)

        return axes

    def _buildGrid(self, cdmsVar, xLimits, yLimits):
        """
        Build a new grid object from the data in cdmsVar, restricted to the
        given limits, using the shared GridFactory.
        """
        self._gridFactory.cdmsVar = cdmsVar

        grid = self._gridFactory.getGrid(xLimits, yLimits)

        return grid

    def _resetAxes(self, axes, xLimits=None, yLimits=None):
        """
        Reset the axes aspect, ticks and limits after drawing.

        This is needed because some drawing methods (notably basemap) change
        these properties. For non-latlon projections the limits are converted
        to map units with the map's basemap before being applied.
        """
        axes.set_aspect('auto')

        axes.set_xticks([])
        axes.set_yticks([])

        # No limits to restore. Returning here also avoids building a map and
        # transforming None limits for projected (non-latlon) maps.
        if xLimits is None and yLimits is None:
            return

        if self.projection == 'latlon':
            xLimitsMapUnits, yLimitsMapUnits = xLimits, yLimits
        else:
            # NOTE(review): the transform is given both limits; if only one of
            # them is supplied this call will presumably fail - confirm usage.
            layerMap = self._getMap(xLimits, yLimits)
            xLimitsMapUnits, yLimitsMapUnits = layerMap.basemap(xLimits, yLimits)

        # Re-apply the limits after drawing the grid. Use "is not None" so
        # array-valued limits are not compared elementwise.
        if xLimits is not None:
            axes.set_xlim(float(xLimitsMapUnits[0]), float(xLimitsMapUnits[1]))

        if yLimits is not None:
            axes.set_ylim(float(yLimitsMapUnits[0]), float(yLimitsMapUnits[1]))

    def _getMap(self, xLimits, yLimits):
        """
        Return a map object for the current projection and limits.

        Map objects can be used for coordinate transformation or for drawing.
        """
        self._mapFactory.xLimits = xLimits
        self._mapFactory.yLimits = yLimits

        layerMap = self._mapFactory.buildMap()

        return layerMap

    ### properties ###

    def __set_gridType(self, value):
        # ValueError subclasses Exception, so existing "except Exception"
        # callers still catch it.
        if value not in VALID_GRID_TYPES:
            raise ValueError(
                "Invalid value of '%s' for LayerDrawer.gridType property, must be one of %s"
                % (value, VALID_GRID_TYPES,))

        self._gridFactory.dataType = value

    def __get_gridType(self):
        return self._gridFactory.dataType

    gridType = property(__get_gridType, __set_gridType, None, None)

    # NOTE(review): self._gridDrawer is never assigned in this class, so the
    # showGridLines/outline properties presumably rely on a subclass setting
    # it; otherwise they raise AttributeError.
    def __set_showGridLines(self, value):
        self._gridDrawer.showGridLines = value

    def __get_showGridLines(self):
        return self._gridDrawer.showGridLines

    showGridLines = property(__get_showGridLines, __set_showGridLines)

    def __set_outline(self, value):
        self._gridDrawer.outline = value

    def __get_outline(self):
        return self._gridDrawer.outline

    outline = property(__get_outline, __set_outline)

    def __set_projection(self, value):
        if value not in VALID_PROJECTIONS:
            raise ValueError(
                "Invalid value of '%s' for projection property, must be one of %s"
                % (value, VALID_PROJECTIONS,))

        self._mapFactory.projection = value

    def __get_projection(self):
        return self._mapFactory.projection

    projection = property(__get_projection, __set_projection, None, None)

    def __set_cmap(self, value):
        self._cb.cmap = value

    def __get_cmap(self):
        return self._cb.cmap

    cmap = property(__get_cmap, __set_cmap)

    def __set_cmapRange(self, value):
        low, high = value[0], value[1]

        # Normalise a reversed range here so that both the constructor and
        # later assignments to cmapRange behave the same way.
        if low is not None and high is not None and low > high:
            log.warning("cmapRange[0] > cmapRange[1], swapping values")
            low, high = high, low

        self._cb.colourBarMin = low
        self._cb.colourBarMax = high

    def __get_cmapRange(self):
        return (self._cb.colourBarMin, self._cb.colourBarMax)

    cmapRange = property(__get_cmapRange, __set_cmapRange)

    def __set_intervalColourbar(self, value):
        self._cb.intervalColourbar = value

    def __get_intervalColourbar(self):
        return self._cb.intervalColourbar

    intervalColourbar = property(__get_intervalColourbar, __set_intervalColourbar)

    def __set_intervalNames(self, value):
        self._cb.intervalNames = value

    def __get_intervalNames(self):
        return self._cb.intervalNames

    intervalNames = property(__get_intervalNames, __set_intervalNames)
243   
Note: See TracBrowser for help on using the repository browser.