source: qesdi/geoplot/trunk/lib/geoplot/colour_bar.py @ 6118

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/colour_bar.py@6118
Revision 6118, 13.8 KB checked in by pnorton, 10 years ago (diff)

Fixed lots of small bugs and tidied up the layer drawer code a bit.

Line 
1
2import math
3import logging
4
5from matplotlib.colors import Normalize, ListedColormap
6
7import matplotlib.ticker
8from matplotlib.figure import Figure
9from matplotlib.colorbar import ColorbarBase
10from matplotlib.patches import Rectangle
11import matplotlib.cm as cm
12import geoplot.config as geoplot_config
13import geoplot
14
15config = geoplot_config.getConfig()
16
17log = logging.getLogger(__name__)
18
19FONTS_SECTION = 'Fonts'
20MAX_CBAR_TICKS = 10
21ADJUSTED_TICK_FORMAT = "%1.2f"
22
23class COLOUR_BAR_STYLES:
24    LEGEND='legend'
25    CONTINUOUS = 'continuous'
26    INTERVAL = 'interval' 
27    LINE = 'line'
28
29    @staticmethod
30    def all():
31        return [COLOUR_BAR_STYLES.CONTINUOUS, COLOUR_BAR_STYLES.LEGEND, COLOUR_BAR_STYLES.LINE, COLOUR_BAR_STYLES.INTERVAL]
32
33class COLOUR_SCHEME_SCALE:
34    LINEAR='linear'
35    LOG='log'
36   
37    @staticmethod
38    def all():
39        return [COLOUR_SCHEME_SCALE.LINEAR, COLOUR_SCHEME_SCALE.LOG]
40   
41
42class ColourBar(object):
43
44    def __init__(self, colourBarLabel="",colourBarPosition='horizontal', colourBarStyle=COLOUR_BAR_STYLES.CONTINUOUS):
45
46        self._position = None
47
48        self.colourBarLabel = colourBarLabel
49        self.colourBarPosition = colourBarPosition
50        self.colourBarStyle = colourBarStyle
51       
52        self.labelFont = config[FONTS_SECTION]['ColourBarLabel']
53
54    def draw(self, colourBarAxis, colourScheme, fontSize='medium', intervalNames=None):
55        """
56        Adds the colour bar to the (and optionally a label) to the figure.
57
58        @param sm: the scalar mappable generated by applying the grid to the axis
59        @type sm: an instance of matplotlib.cm.ScalarMappable
60        @param units: the units of the values on the mesh, these will be used as a label
61            if the colourBarLabel property is not set.
62        @type units: string
63        """
64        log.debug("drawing colour bar")
65       
66        if self.colourBarStyle == COLOUR_BAR_STYLES.LEGEND:
67            self._drawLegendColourBar(colourBarAxis, colourScheme, fontSize, intervalNames)
68           
69        elif self.colourBarStyle == COLOUR_BAR_STYLES.LINE:
70            self._drawLineColourBar(colourBarAxis, colourScheme, fontSize)
71           
72        elif self.colourBarStyle == COLOUR_BAR_STYLES.INTERVAL:
73            self._drawIntervalColourBar(colourBarAxis, colourScheme, fontSize)
74           
75        else:
76            self._drawContiunousColourBar(colourBarAxis, colourScheme, fontSize)
77               
78        log.debug("finished drawing colour bar")
79
80    def _drawContiunousColourBar(self, axes, colourScheme, fontSize):
81       
82        kwargs = {}
83        kwargs['cmap'] = colourScheme.cmap
84        kwargs['norm'] = colourScheme.norm
85       
86        kwargs['orientation'] = self.colourBarPosition
87       
88#        if intervalColourbar:
89#            kwargs['ticks'] = colourScheme.midpoints
90#            kwargs['format'] = matplotlib.ticker.FixedFormatter(colourScheme.labels)
91#            kwargs['spacing'] = 'proportional'
92       
93        cb = ColorbarBase(axes, **kwargs)
94
95        if cb.cmap.__class__ == ListedColormap:
96            ColourBar._repositionColourBarTicks(cb)
97           
98        if self.colourBarLabel != None:
99            labelDictionary = self.labelFont.getDict(fontSize)
100            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
101       
102        return cb
103   
104
105    def _drawIntervalColourBar(self, axes, colourScheme, fontSize):
106       
107        kwargs = {}
108        kwargs['cmap'] = colourScheme.getListedCmap()
109        kwargs['norm'] = colourScheme.getBoundaryNorm()
110        kwargs['orientation'] = self.colourBarPosition
111       
112        kwargs['ticks'] = colourScheme.intervals.bounds
113 
114        labelFormat = "%.2f"
115        labels  = [ labelFormat % (bound,) for bound in colourScheme.intervals.bounds]
116       
117        kwargs['format'] = matplotlib.ticker.FixedFormatter(labels)
118        kwargs['spacing'] = 'proportional'
119       
120        cb = ColorbarBase(axes, **kwargs)
121           
122        if self.colourBarLabel != None:
123            labelDictionary = self.labelFont.getDict(fontSize)
124            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
125       
126        return cb   
127   
128    def _drawLegendColourBar(self, colourBarAxis, colourScheme, fontSize, intervalNames):
129        """
130       
131        """
132        kwargs = {}
133       
134        cmap = colourScheme.cmap
135        norm = colourScheme.norm
136       
137        bounds = colourScheme.intervals.bounds
138        labels = None
139       
140        log.debug("intervalNames = %s" % (intervalNames,))
141       
142        if not intervalNames is None:
143           
144            labels = filter(lambda x: x.strip() != '', intervalNames.split(','))
145           
146            if len(labels) != len(colourScheme.intervals.midpoints):
147               
148                log.warning("Number of labels found (%s) is != to the number of midpoints (%s), using default labels instead" \
149                                % (len(labels), len(colourScheme.intervals.midpoints)))
150                labels = None                 
151       
152        if labels is None:
153            if colourScheme.scale == COLOUR_SCHEME_SCALE.LOG:
154                labelFormat = "%.1e"
155            else:
156                labelFormat = "%.2f"
157               
158            lfString = labelFormat + " - " + labelFormat
159            labels  = [ lfString % (bounds[index], bounds[index + 1]) for index in range(len(bounds)-1)]       
160       
161        locations = colourScheme.intervals.midpoints
162
163        kwargs['orientation'] = self.colourBarPosition
164 
165        handles = [Rectangle((0,0), 1, 1, fc=cmap(norm(i))) for i in locations]
166        labels = labels
167       
168        log.debug("locations = %s" % ([i for i in locations],))
169        log.debug("norm = %s" % ([norm(i) for i in locations],))
170        log.debug("legend colours = %s" % ([cmap(norm(i)) for i in locations],))
171       
172        x = 1.5e-3
173        log.debug(" %s, norm = %s, colour = %s" % (x, norm(x), cmap(norm(x))))
174       
175        if self.colourBarPosition == 'horizontal' :
176            if colourScheme.scale == COLOUR_SCHEME_SCALE.LINEAR:
177                numCols = 3
178            else:
179                numCols = 2
180               
181            if len(handles) < numCols:
182                numCols = len(handles)
183        else:
184            numCols = 1
185
186        leg = colourBarAxis.legend(handles, labels, loc=10, mode='expand', 
187                                   ncol=numCols, borderaxespad=0)
188       
189        colourBarAxis.set_xticks([])
190        colourBarAxis.set_yticks([])
191       
192        if self.colourBarLabel != None:
193            labelDictionary = self.labelFont.getDict(fontSize)
194           
195            if self.colourBarPosition == 'horizontal':
196                colourBarAxis.set_xlabel(self.colourBarLabel, fontdict=labelDictionary)     
197            else:
198                colourBarAxis.set_ylabel(self.colourBarLabel, fontdict=labelDictionary)
199               
200        return leg
201   
202    def _drawLineColourBar(self, colourBarAxis, colourScheme, fontSize):
203        """
204       
205        """
206       
207        cmap = colourScheme.cmap
208        norm = colourScheme.norm
209        levels = colourScheme.intervals.bounds
210        colours = [cmap(norm(x)) for x in levels]
211        widths = [1.0 for x in levels]
212       
213        # to stop the lines being mixed up with the edges of the colour
214        # bar extend the normalisation limits a bit.
215        downShift = (levels[1] - levels[0])/4.0
216        upShift = (levels[-1] - levels[-2])/4.0
217       
218        n2 =norm.__class__(norm.vmin - downShift, norm.vmax + upShift)
219       
220        kwargs = {}
221        kwargs['norm'] = n2
222        kwargs['ticks'] =  colourScheme.intervals.bounds
223 
224        labelFormat = "%.2f"
225        labels  = [ labelFormat % (bound,) for bound in colourScheme.intervals.bounds]
226       
227        kwargs['orientation'] = 'horizontal'
228        kwargs['filled'] = False
229        kwargs['extend'] = 'neither'
230        kwargs['format'] = matplotlib.ticker.FixedFormatter(labels)
231       
232        if colourScheme.scale == COLOUR_SCHEME_SCALE.LOG:
233            kwargs['spacing'] = 'uniform'
234        else:
235            kwargs['spacing'] = 'proportional'
236        log.debug("kwargs = %s" % (kwargs,))
237       
238        log.debug("norm.__class__ = %s" % (norm.__class__,))
239        log.debug("n2.__class__ = %s" % (n2.__class__,))
240        log.debug("colourScheme.scale = %s" % (colourScheme.scale,))
241       
242        cb = ColorbarBase(colourBarAxis, **kwargs)
243       
244        log.debug("cb.dividers = %s" % (cb.dividers,))
245       
246        cb.add_lines(levels, colours, widths)
247       
248        if self.colourBarLabel != None:
249            labelDictionary = self.labelFont.getDict(fontSize)
250            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
251       
252        #get rid of any additional tick on the colourBarAxis
253        colourBarAxis.minorticks_off()
254        for ax in (colourBarAxis.xaxis, colourBarAxis.yaxis):
255            ax.set_ticks_position('none')
256
257        return cb
258       
259    @staticmethod
260    def _repositionColourBarTicks(cb):
261        """
262        reposition the ticks of a ListedColormap so that they appear at the
263        """
264       
265        log.debug("Repositioning colour bar ticks")
266        span = cb.vmax - cb.vmin
267
268        # Define flag for whether or not zero should be added
269        useZero = False
270        if cb.vmin < 0 and cb.vmax > 0:
271            useZero = True
272
273        numColours = len(cb.cmap.colors)
274        interval = float(span) / float(numColours)
275
276        showEvery = 1
277        while float(numColours)/float(showEvery) > float(MAX_CBAR_TICKS):
278            showEvery += 1
279
280        newLocs = []
281        for i in range(0, numColours + 1, showEvery):
282            newLocs.append(cb.vmin + i * interval)
283
284        # If need to add a zero then do so
285        if useZero == True and 0 not in newLocs:
286            newLocsWithZero = []
287         
288            zeroInserted = False 
289            for newLoc in newLocs:
290                if newLoc > 0 and zeroInserted == False:
291                    newLocsWithZero.append(0)
292                    zeroInserted = True                 
293                newLocsWithZero.append(newLoc)
294
295            newLocs = newLocsWithZero
296
297        #change the locator and the formatter (as the locations now have a high number of dp)
298        cb.locator = matplotlib.ticker.FixedLocator(newLocs)
299
300        # Decide the float formatting of the tick labels based on the span
301        if span < 100:
302            # In Excel this worked: =IF(E3<100,(LOG(E3)*-1)+2,0)
303            tickFormatDecPoints = int((math.log(span, 10) * -1) + 2)
304            tickFormatString = "%." + str(tickFormatDecPoints) + "f"
305        else: 
306            tickFormatString = "%d"
307           
308        cb.formatter = matplotlib.ticker.FormatStrFormatter(tickFormatString)  ###ADJUSTED_TICK_FORMAT)
309
310        # The next line removes axis artists as draw_all() adds new ones
311        cb.ax.artists = []
312        cb.draw_all()    # cause the colourbar to be redrawn, otherwise not changes will hapen
313
314        # Hard code line width of colour bar outline
315        cb.outline.set_linewidth(0.5)
316
317        # this can sometimes cause the tick positions to become unknown so
318        # set it back to default.
319        if cb.ax.yaxis.get_ticks_position() == 'unknown':
320            log.debug("Resetting yaxis ticks position to default")
321            cb.ax.yaxis.set_ticks_position('default')
322       
323        if cb.ax.xaxis.get_ticks_position() == 'unknown':
324            log.debug("Resetting xaxis ticks position to default")
325            cb.ax.xaxis.set_ticks_position('default')
326   
327 
328    def __get_position(self): return self._position
329
330    def __set_position(self, value):
331        if value not in ['horizontal', 'vertical', None]:
332            raise ValueError("ColourBar position value must be 'horizontal'" + \
333                             " or 'vertical, value recieved :" + str(value))
334        self._position = value
335
336    colourBarPosition = property(__get_position, __set_position, None,
337                    "colour bar position, 'horizontal' or 'vertical'")
338
339       
340
341def getColourBarImage(width=600, height=100,
342                      label=None, 
343                      cmap=cm.jet, 
344                      colourBarMin=None, 
345                      colourBarMax=None, 
346                      colourBarScale=COLOUR_SCHEME_SCALE.LINEAR,
347                      orientation='horizontal',
348                      intervals=None,
349                      intervalNames=None,
350                      numIntervals=None,
351                      colourBarStyle='continuous',
352                      dpi=100):
353   
354    figsize=(width / float(dpi), height / float(dpi))
355    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
356   
357    #build a colour scheme
358    log.debug("intervalNames = %s" % (intervalNames,))
359    schemeBuilder = geoplot.colour_scheme.ColourSchemeBuilder(cmap=cmap, 
360                 colourBarMin=colourBarMin, 
361                 colourBarMax=colourBarMax, 
362                 colourBarScale=colourBarScale,
363                 numIntervals=numIntervals,
364                 intervals=intervals)
365   
366    colourScheme = schemeBuilder.buildScheme(colourBarStyle)   
367    log.debug("colourScheme.norm.__class__ = %s" % (colourScheme.norm.__class__,))
368    #for agg bakcend
369    #need about 40px at the bottom of the axes to draw the labels
370    #x = 40.0
371    #cbMin = 0.5
372   
373    #for cairo backend
374    if colourBarStyle == 'legend':
375        x = 40.0
376        cbMin = 0.6
377    else:
378        x = 70.0
379        cbMin = 0.6
380    cbBottom = x/height
381   
382    if cbBottom < 0.1:
383        cbBottom = 0.1
384   
385    if cbBottom > cbMin:
386        cbBottom = cbMin
387       
388    cbHeight = 0.9 - cbBottom
389    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False, xticks=[], yticks=[])
390   
391    cb = ColourBar(colourBarLabel=label, 
392                   colourBarPosition=orientation,
393                   colourBarStyle=colourBarStyle)
394   
395    cb.draw(axes, colourScheme,
396            intervalNames=intervalNames)   
397   
398    return geoplot.utils.figureToImage(fig)
Note: See TracBrowser for help on using the repository browser.