source: qesdi/geoplot/trunk/lib/geoplot/colour_bar.py @ 5735

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/colour_bar.py@5735
Revision 5735, 11.4 KB checked in by pnorton, 10 years ago (diff)

Added the ability to change the coastline colour in the coastline layer drawer. Also fixed a minor problem with the colour bar that was causing problems when unicode is used rather than strings.

Line 
1
2import math
3import logging
4import numpy
5
6from matplotlib.colors import Normalize, ListedColormap
7
8import matplotlib.ticker
9from matplotlib.figure import Figure
10from matplotlib.colorbar import ColorbarBase
11import matplotlib.cm as cm
12import operator
13
14
15from geoplot.fixed_boundary_norm import FixedBoundaryNorm
16import geoplot.config as geoplot_config
17import geoplot.utils
18
19
20config = geoplot_config.getConfig()
21
22log = logging.getLogger(__name__)
23
24FONTS_SECTION = 'Fonts'
25MAX_CBAR_TICKS = 10
26ADJUSTED_TICK_FORMAT = "%1.2f"
27
28class ColourBar(object):
29
30    def __init__(self, colourBarLabel="",colourBarPosition='horizontal', 
31                       cmap=None, colourBarMin=None, colourBarMax=None,
32                       intervalColourbar=False, intervalNames=None,
33                       hideOutsideBounds=False):
34
35        self._position = None
36        self._range = None
37        self._cmap = None
38
39        self.cmap = cmap
40        self.colourBarLabel = colourBarLabel
41        self.colourBarPosition = colourBarPosition
42        self.colourBarMin = colourBarMin
43        self.colourBarMax = colourBarMax
44        self.hideOutsideBounds = hideOutsideBounds
45        self.intervalColourbar = intervalColourbar
46        self.intervalNames = intervalNames
47       
48        self.labelFont = config[FONTS_SECTION]['ColourBarLabel']
49
50    def draw(self, colourBarAxis, fontSize='medium', grid=None):
51        """
52        Adds the colour bar to the (and optionally a label) to the figure.
53
54        @param sm: the scalar mappable generated by applying the grid to the axis
55        @type sm: an instance of matplotlib.cm.ScalarMappable
56        @param units: the units of the values on the mesh, these will be used as a label
57            if the colourBarLabel property is not set.
58        @type units: string
59        """
60        log.debug("drawing colour bar")
61        kwargs = {}
62       
63       
64       
65        kwargs['cmap'] = self.getColourMap(grid)
66        kwargs['norm'] = self.getNormalize(grid)
67        kwargs['orientation'] = self.colourBarPosition
68       
69        if self.intervalColourbar:
70           
71            cbMin, cbMax = self._getCbarMinMax()
72            kwargs['ticks'] = grid.getUniqueValues(cbMin, cbMax)
73           
74            if self.intervalNames != None:
75               
76                if len(self.intervalNames) == len(kwargs['ticks']):
77                    kwargs['format'] = matplotlib.ticker.FixedFormatter(self.intervalNames)
78                else:
79                    log.warning("Incorrect number of interval names, expected %s but recieved %s" % (len(kwargs['ticks']), len(self.intervalNames)))
80           
81            kwargs['spacing'] = 'proportional'
82           
83            log.debug("kwargs['cmap'].N = %s" % (kwargs['cmap'].N,))
84       
85        cb = ColorbarBase(colourBarAxis, **kwargs)
86
87        if cb.cmap.__class__ == ListedColormap \
88           and not self.intervalColourbar:
89            ColourBar._repositionColourBarTicks(cb)
90           
91        if self.colourBarLabel != None:
92            labelDictionary = self.labelFont.getDict(fontSize)
93            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
94       
95        log.debug("finished drawing colour bar")
96
97    def getColourMap(self, grid=None):
98       
99        if self.cmap == None:
100            cmap = matplotlib.cm.get_cmap()
101            cmap.set_bad("w")       
102       
103        elif type(self.cmap) in [str, unicode]:
104            cmap = cm.get_cmap(self.cmap)
105            cmap.set_bad("w")     
106        else:
107            cmap = self.cmap
108   
109        if self.intervalColourbar and  cmap.__class__ != ListedColormap:
110           
111            cbMin, cbMax = self._getCbarMinMax()
112           
113            oldCmap = cmap
114            uniqueVals = grid.getUniqueValues(cbMin, cbMax)
115            n = Normalize(uniqueVals[0], uniqueVals[-1])
116            colours = [oldCmap(n(x)) for x in uniqueVals]
117            cmap = ListedColormap(colours)
118               
119        if self.hideOutsideBounds:
120            log.debug("self.hideOutsideBounds = %s" % (self.hideOutsideBounds,))
121            cmap.set_under('0.25', alpha=0.0)
122            cmap.set_over('0.75', alpha=0.0)
123           
124        return cmap
125
126    @staticmethod
127    def _repositionColourBarTicks(cb):
128        """
129        reposition the ticks of a ListedColormap so that they appear at the
130        """
131       
132        log.debug("Repositioning colour bar ticks")
133        span = cb.vmax - cb.vmin
134
135        # Define flag for whether or not zero should be added
136        useZero = False
137        if cb.vmin < 0 and cb.vmax > 0:
138            useZero = True
139
140        numColours = len(cb.cmap.colors)
141        interval = float(span) / float(numColours)
142
143        showEvery = 1
144        while float(numColours)/float(showEvery) > float(MAX_CBAR_TICKS):
145            showEvery += 1
146
147        newLocs = []
148        for i in range(0, numColours + 1, showEvery):
149            newLocs.append(cb.vmin + i * interval)
150
151        # If need to add a zero then do so
152        if useZero == True and 0 not in newLocs:
153            newLocsWithZero = []
154         
155            zeroInserted = False 
156            for newLoc in newLocs:
157                if newLoc > 0 and zeroInserted == False:
158                    newLocsWithZero.append(0)
159                    zeroInserted = True                 
160                newLocsWithZero.append(newLoc)
161
162            newLocs = newLocsWithZero
163
164        #change the locator and the formatter (as the locations now have a high number of dp)
165        cb.locator = matplotlib.ticker.FixedLocator(newLocs)
166
167        # Decide the float formatting of the tick labels based on the span
168        if span < 100:
169            # In Excel this worked: =IF(E3<100,(LOG(E3)*-1)+2,0)
170            tickFormatDecPoints = int((math.log(span, 10) * -1) + 2)
171            tickFormatString = "%." + str(tickFormatDecPoints) + "f"
172        else: 
173            tickFormatString = "%d"
174           
175        cb.formatter = matplotlib.ticker.FormatStrFormatter(tickFormatString)  ###ADJUSTED_TICK_FORMAT)
176
177        # The next line removes axis artists as draw_all() adds new ones
178        cb.ax.artists = []
179        cb.draw_all()    # cause the colourbar to be redrawn, otherwise not changes will hapen
180
181        # Hard code line width of colour bar outline
182        cb.outline.set_linewidth(0.5)
183
184        # this can sometimes cause the tick positions to become unknown so
185        # set it back to default.
186        if cb.ax.yaxis.get_ticks_position() == 'unknown':
187            log.debug("Resetting yaxis ticks position to default")
188            cb.ax.yaxis.set_ticks_position('default')
189       
190        if cb.ax.xaxis.get_ticks_position() == 'unknown':
191            log.debug("Resetting xaxis ticks position to default")
192            cb.ax.xaxis.set_ticks_position('default')
193   
194    def getNormalize(self, grid=None):
195       
196        cbMin, cbMax = self._getCbarMinMax()
197       
198        if self.intervalColourbar:
199            bounds = grid.getUniqueValueBounds(cbMin, cbMax)
200            n = len(grid.getUniqueValues(cbMin, cbMax))
201            norm = FixedBoundaryNorm(bounds, n)
202        else:
203            norm = Normalize(cbMin,  cbMax)
204            # this should work event if data is none (as N.ma.maximum(None) = None)
205            data = None
206            if grid != None:
207                data = grid.values
208           
209            norm.autoscale_None(data)
210           
211           
212            # check for masked values in vmin and vmax, can occur when data is completly masked
213            if norm.vmin.__class__ == numpy.ma.MaskedArray and norm.vmin.mask == True:
214                norm.vmin = None
215               
216            if norm.vmax.__class__ == numpy.ma.MaskedArray and norm.vmax.mask == True:
217                norm.vmax = None
218               
219        return norm
220   
221    def _getCbarMinMax(self):
222        cbMin = self.colourBarMin
223        cbMax = self.colourBarMax
224       
225        if cbMin is not None and cbMax is not None and cbMin > cbMax:
226            log.warning("min(=%s) > max(=%s) reversing values" % (cbMin, cbMax))
227            cbMax, cbMin = cbMin, cbMax
228           
229        return cbMin, cbMax       
230   
231    def __get_position(self): return self._position
232
233    def __set_position(self, value):
234        if value not in ['horizontal', 'vertical', None]:
235            raise ValueError("ColourBar position value must be 'horizontal'" + \
236                             " or 'vertical, value recieved :" + str(value))
237        self._position = value
238
239    colourBarPosition = property(__get_position, __set_position, None,
240                    "colour bar position, 'horizontal' or 'vertical'")
241
242    def __get_cmap(self): 
243        if self._cmap == None:
244            cmap = cm.jet
245            cmap.set_bad('w')
246            self.cmap = cmap
247           
248        return self.cmap
249   
250    def __set_cmap(self, value):
251        self._cmap == value;
252       
253
254def getColourBarImage(width=600, height=100,
255                      label=None, 
256                      cmap=cm.jet, 
257                      cmapRange=(0,1), 
258                      orientation='horizontal',
259                      dpi=100):
260   
261    figsize=(width / float(dpi), height / float(dpi))
262    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
263   
264    #for agg bakcend
265    #need about 40px at the bottom of the axes to draw the labels
266    #x = 40.0
267    #cbMin = 0.5
268   
269    #for cairo backend
270    x = 50.0
271    cbMin = 0.6
272   
273    cbBottom = x/height
274   
275    if cbBottom < 0.1:
276        cbBottom = 0.1
277   
278    if cbBottom > cbMin:
279        cbBottom = cbMin
280       
281    cbHeight = 0.9 - cbBottom
282    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False)
283   
284    log.debug("cbBottom = %s, cbHeight = %s" % (cbBottom, cbHeight,))
285   
286    if cmapRange[0] > cmapRange[1]:
287        log.warning("cmapRange[0] > cmapRange[1], swapping values")
288        cmapRange = (cmapRange[1], cmapRange[0])
289   
290    cb = ColourBar(colourBarLabel=label, 
291                   colourBarPosition=orientation, 
292                   cmap=cmap, 
293                   colourBarMin=cmapRange[0], 
294                   colourBarMax=cmapRange[1] )
295   
296    cb.draw(axes)   
297   
298    return geoplot.utils.figureToImage(fig)
299
300def getIntervalColourBarImage(grid, width=600, height=100,
301                      label=None, 
302                      cmap=cm.jet, 
303                      cmapRange=(None, None), 
304                      orientation='horizontal',
305                      intervalNames=None,
306                      dpi=100):
307   
308    figsize=(width / float(dpi), height / float(dpi))
309    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
310   
311    #for agg bakcend
312    #need about 40px at the bottom of the axes to draw the labels
313    #x = 40.0
314    #cbMin = 0.5
315   
316    #for cairo backend
317    x = 50.0
318    cbMin = 0.6
319   
320    cbBottom = x/height
321   
322    if cbBottom < 0.1:
323        cbBottom = 0.1
324   
325    if cbBottom > cbMin:
326        cbBottom = cbMin
327       
328    cbHeight = 0.9 - cbBottom
329    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False)
330   
331    log.debug("cbBottom = %s, cbHeight = %s" % (cbBottom, cbHeight,))
332   
333    if cmapRange[0] > cmapRange[1]:
334        log.warning("cmapRange[0] > cmapRange[1], swapping values")
335        cmapRange = (cmapRange[1], cmapRange[0])
336   
337    cb = ColourBar(colourBarLabel=label, 
338                   colourBarPosition=orientation, 
339                   cmap=cmap, 
340                   colourBarMin=cmapRange[0], 
341                   colourBarMax=cmapRange[1],
342                   intervalNames=intervalNames,
343                   intervalColourbar=True)
344   
345    cb.draw(axes, grid=grid)   
346   
347    return geoplot.utils.figureToImage(fig)
Note: See TracBrowser for help on using the repository browser.