source: qesdi/geoplot/trunk/lib/geoplot/colour_bar.py @ 5835

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/colour_bar.py@5835
Revision 5835, 11.4 KB checked in by pnorton, 10 years ago (diff)

Improved the performance of the getUniqueValues method on the grid, also cached the result to avoid the cost of re-calculation.

Line 
1
2import math
3import logging
4import numpy
5
6from matplotlib.colors import Normalize, ListedColormap
7
8import matplotlib.ticker
9from matplotlib.figure import Figure
10from matplotlib.colorbar import ColorbarBase
11import matplotlib.cm as cm
12import operator
13import time
14
15
16from geoplot.fixed_boundary_norm import FixedBoundaryNorm
17import geoplot.config as geoplot_config
18import geoplot.utils
19
20
21config = geoplot_config.getConfig()
22
23log = logging.getLogger(__name__)
24
25FONTS_SECTION = 'Fonts'
26MAX_CBAR_TICKS = 10
27ADJUSTED_TICK_FORMAT = "%1.2f"
28
29class ColourBar(object):
30
31    def __init__(self, colourBarLabel="",colourBarPosition='horizontal', 
32                       cmap=None, colourBarMin=None, colourBarMax=None,
33                       intervalColourbar=False, intervalNames=None,
34                       hideOutsideBounds=False):
35
36        self._position = None
37        self._range = None
38        self._cmap = None
39
40        self.cmap = cmap
41        self.colourBarLabel = colourBarLabel
42        self.colourBarPosition = colourBarPosition
43        self.colourBarMin = colourBarMin
44        self.colourBarMax = colourBarMax
45        self.hideOutsideBounds = hideOutsideBounds
46        self.intervalColourbar = intervalColourbar
47        self.intervalNames = intervalNames
48       
49        self.labelFont = config[FONTS_SECTION]['ColourBarLabel']
50
51    def draw(self, colourBarAxis, fontSize='medium', grid=None):
52        """
53        Adds the colour bar to the (and optionally a label) to the figure.
54
55        @param sm: the scalar mappable generated by applying the grid to the axis
56        @type sm: an instance of matplotlib.cm.ScalarMappable
57        @param units: the units of the values on the mesh, these will be used as a label
58            if the colourBarLabel property is not set.
59        @type units: string
60        """
61        log.debug("drawing colour bar")
62        kwargs = {}
63       
64        kwargs['cmap'] = self.getColourMap(grid)
65        kwargs['norm'] = self.getNormalize(grid)
66       
67        kwargs['orientation'] = self.colourBarPosition
68       
69        if self.intervalColourbar:
70           
71            cbMin, cbMax = self._getCbarMinMax()
72           
73            kwargs['ticks'] = grid.getUniqueValues(cbMin, cbMax)
74           
75            if self.intervalNames != None:
76               
77                if len(self.intervalNames) == len(kwargs['ticks']):
78                    kwargs['format'] = matplotlib.ticker.FixedFormatter(self.intervalNames)
79                else:
80                    log.warning("Incorrect number of interval names, expected %s but recieved %s" % (len(kwargs['ticks']), len(self.intervalNames)))
81           
82            kwargs['spacing'] = 'proportional'
83           
84            log.debug("kwargs['cmap'].N = %s" % (kwargs['cmap'].N,))
85       
86        cb = ColorbarBase(colourBarAxis, **kwargs)
87
88        if cb.cmap.__class__ == ListedColormap \
89           and not self.intervalColourbar:
90            ColourBar._repositionColourBarTicks(cb)
91           
92        if self.colourBarLabel != None:
93            labelDictionary = self.labelFont.getDict(fontSize)
94            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
95       
96        log.debug("finished drawing colour bar")
97
98    def getColourMap(self, grid=None):
99       
100        if self.cmap == None:
101            cmap = matplotlib.cm.get_cmap()
102            cmap.set_bad("w")       
103       
104        elif type(self.cmap) in [str, unicode]:
105            cmap = cm.get_cmap(self.cmap)
106            cmap.set_bad("w")     
107        else:
108            cmap = self.cmap
109   
110        if self.intervalColourbar and  cmap.__class__ != ListedColormap:
111           
112            cbMin, cbMax = self._getCbarMinMax()
113           
114            oldCmap = cmap
115            uniqueVals = grid.getUniqueValues(cbMin, cbMax)
116            n = Normalize(uniqueVals[0], uniqueVals[-1])
117            colours = [oldCmap(n(x)) for x in uniqueVals]
118            cmap = ListedColormap(colours)
119               
120        if self.hideOutsideBounds:
121            log.debug("self.hideOutsideBounds = %s" % (self.hideOutsideBounds,))
122            cmap.set_under('0.25', alpha=0.0)
123            cmap.set_over('0.75', alpha=0.0)
124           
125        return cmap
126
127    @staticmethod
128    def _repositionColourBarTicks(cb):
129        """
130        reposition the ticks of a ListedColormap so that they appear at the
131        """
132       
133        log.debug("Repositioning colour bar ticks")
134        span = cb.vmax - cb.vmin
135
136        # Define flag for whether or not zero should be added
137        useZero = False
138        if cb.vmin < 0 and cb.vmax > 0:
139            useZero = True
140
141        numColours = len(cb.cmap.colors)
142        interval = float(span) / float(numColours)
143
144        showEvery = 1
145        while float(numColours)/float(showEvery) > float(MAX_CBAR_TICKS):
146            showEvery += 1
147
148        newLocs = []
149        for i in range(0, numColours + 1, showEvery):
150            newLocs.append(cb.vmin + i * interval)
151
152        # If need to add a zero then do so
153        if useZero == True and 0 not in newLocs:
154            newLocsWithZero = []
155         
156            zeroInserted = False 
157            for newLoc in newLocs:
158                if newLoc > 0 and zeroInserted == False:
159                    newLocsWithZero.append(0)
160                    zeroInserted = True                 
161                newLocsWithZero.append(newLoc)
162
163            newLocs = newLocsWithZero
164
165        #change the locator and the formatter (as the locations now have a high number of dp)
166        cb.locator = matplotlib.ticker.FixedLocator(newLocs)
167
168        # Decide the float formatting of the tick labels based on the span
169        if span < 100:
170            # In Excel this worked: =IF(E3<100,(LOG(E3)*-1)+2,0)
171            tickFormatDecPoints = int((math.log(span, 10) * -1) + 2)
172            tickFormatString = "%." + str(tickFormatDecPoints) + "f"
173        else: 
174            tickFormatString = "%d"
175           
176        cb.formatter = matplotlib.ticker.FormatStrFormatter(tickFormatString)  ###ADJUSTED_TICK_FORMAT)
177
178        # The next line removes axis artists as draw_all() adds new ones
179        cb.ax.artists = []
180        cb.draw_all()    # cause the colourbar to be redrawn, otherwise not changes will hapen
181
182        # Hard code line width of colour bar outline
183        cb.outline.set_linewidth(0.5)
184
185        # this can sometimes cause the tick positions to become unknown so
186        # set it back to default.
187        if cb.ax.yaxis.get_ticks_position() == 'unknown':
188            log.debug("Resetting yaxis ticks position to default")
189            cb.ax.yaxis.set_ticks_position('default')
190       
191        if cb.ax.xaxis.get_ticks_position() == 'unknown':
192            log.debug("Resetting xaxis ticks position to default")
193            cb.ax.xaxis.set_ticks_position('default')
194   
195    def getNormalize(self, grid=None):
196       
197        cbMin, cbMax = self._getCbarMinMax()
198       
199        if self.intervalColourbar:
200            bounds = grid.getUniqueValueBounds(cbMin, cbMax)
201            n = len(grid.getUniqueValues(cbMin, cbMax))
202            norm = FixedBoundaryNorm(bounds, n)
203        else:
204            norm = Normalize(cbMin,  cbMax)
205            # this should work event if data is none (as N.ma.maximum(None) = None)
206            data = None
207            if grid != None:
208                data = grid.values
209           
210            norm.autoscale_None(data)
211           
212           
213            # check for masked values in vmin and vmax, can occur when data is completly masked
214            if norm.vmin.__class__ == numpy.ma.MaskedArray and norm.vmin.mask == True:
215                norm.vmin = None
216               
217            if norm.vmax.__class__ == numpy.ma.MaskedArray and norm.vmax.mask == True:
218                norm.vmax = None
219               
220        return norm
221   
222    def _getCbarMinMax(self):
223        cbMin = self.colourBarMin
224        cbMax = self.colourBarMax
225       
226        if cbMin is not None and cbMax is not None and cbMin > cbMax:
227            log.warning("min(=%s) > max(=%s) reversing values" % (cbMin, cbMax))
228            cbMax, cbMin = cbMin, cbMax
229           
230        return cbMin, cbMax       
231   
232    def __get_position(self): return self._position
233
234    def __set_position(self, value):
235        if value not in ['horizontal', 'vertical', None]:
236            raise ValueError("ColourBar position value must be 'horizontal'" + \
237                             " or 'vertical, value recieved :" + str(value))
238        self._position = value
239
240    colourBarPosition = property(__get_position, __set_position, None,
241                    "colour bar position, 'horizontal' or 'vertical'")
242
243    def __get_cmap(self): 
244        if self._cmap == None:
245            cmap = cm.jet
246            cmap.set_bad('w')
247            self.cmap = cmap
248           
249        return self.cmap
250   
251    def __set_cmap(self, value):
252        self._cmap = value;
253       
254
255def getColourBarImage(width=600, height=100,
256                      label=None, 
257                      cmap=cm.jet, 
258                      cmapRange=(0,1), 
259                      orientation='horizontal',
260                      dpi=100):
261   
262    figsize=(width / float(dpi), height / float(dpi))
263    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
264   
265    #for agg bakcend
266    #need about 40px at the bottom of the axes to draw the labels
267    #x = 40.0
268    #cbMin = 0.5
269   
270    #for cairo backend
271    x = 50.0
272    cbMin = 0.6
273   
274    cbBottom = x/height
275   
276    if cbBottom < 0.1:
277        cbBottom = 0.1
278   
279    if cbBottom > cbMin:
280        cbBottom = cbMin
281       
282    cbHeight = 0.9 - cbBottom
283    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False)
284   
285    log.debug("cbBottom = %s, cbHeight = %s" % (cbBottom, cbHeight,))
286   
287    if cmapRange[0] > cmapRange[1]:
288        log.warning("cmapRange[0] > cmapRange[1], swapping values")
289        cmapRange = (cmapRange[1], cmapRange[0])
290   
291    cb = ColourBar(colourBarLabel=label, 
292                   colourBarPosition=orientation, 
293                   cmap=cmap, 
294                   colourBarMin=cmapRange[0], 
295                   colourBarMax=cmapRange[1] )
296   
297    cb.draw(axes)   
298   
299    return geoplot.utils.figureToImage(fig)
300
301def getIntervalColourBarImage(grid, width=600, height=100,
302                      label=None, 
303                      cmap=cm.jet, 
304                      cmapRange=(None, None), 
305                      orientation='horizontal',
306                      intervalNames=None,
307                      dpi=100):
308   
309    figsize=(width / float(dpi), height / float(dpi))
310    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
311   
312    #for agg bakcend
313    #need about 40px at the bottom of the axes to draw the labels
314    #x = 40.0
315    #cbMin = 0.5
316   
317    #for cairo backend
318    x = 50.0
319    cbMin = 0.6
320   
321    cbBottom = x/height
322   
323    if cbBottom < 0.1:
324        cbBottom = 0.1
325   
326    if cbBottom > cbMin:
327        cbBottom = cbMin
328       
329    cbHeight = 0.9 - cbBottom
330    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False)
331   
332    log.debug("cbBottom = %s, cbHeight = %s" % (cbBottom, cbHeight,))
333   
334    if cmapRange[0] > cmapRange[1]:
335        log.warning("cmapRange[0] > cmapRange[1], swapping values")
336        cmapRange = (cmapRange[1], cmapRange[0])
337   
338    cb = ColourBar(colourBarLabel=label, 
339                   colourBarPosition=orientation, 
340                   cmap=cmap, 
341                   colourBarMin=cmapRange[0], 
342                   colourBarMax=cmapRange[1],
343                   intervalNames=intervalNames,
344                   intervalColourbar=True)
345   
346    cb.draw(axes, grid=grid)   
347   
348    return geoplot.utils.figureToImage(fig)
Note: See TracBrowser for help on using the repository browser.