source: qesdi/geoplot/trunk/lib/geoplot/colour_bar.py @ 6089

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/colour_bar.py@6089
Revision 6089, 9.1 KB checked in by pnorton, 12 years ago (diff)

Imroved the colour bar code so that a legend colour bar can be used without specifying any intervals.

Line 
1
2import math
3import logging
4
5from matplotlib.colors import Normalize, ListedColormap
6
7import matplotlib.ticker
8from matplotlib.figure import Figure
9from matplotlib.colorbar import ColorbarBase
10from matplotlib.patches import Rectangle
11import matplotlib.cm as cm
12import geoplot.config as geoplot_config
13import geoplot
14
15config = geoplot_config.getConfig()
16
17log = logging.getLogger(__name__)
18
19FONTS_SECTION = 'Fonts'
20MAX_CBAR_TICKS = 10
21ADJUSTED_TICK_FORMAT = "%1.2f"
22
23class COLOUR_BAR_STYLES:
24    LEGEND='legend'
25    CONTINUOUS = 'continuous' 
26
27    @staticmethod
28    def all():
29        return [COLOUR_BAR_STYLES.CONTINUOUS, COLOUR_BAR_STYLES.LEGEND]
30
31class COLOUR_SCHEME_SCALE:
32    LINEAR='linear'
33    LOG='log'
34
35class ColourBar(object):
36
37    def __init__(self, colourBarLabel="",colourBarPosition='horizontal', colourBarStyle=COLOUR_BAR_STYLES.CONTINUOUS):
38
39        self._position = None
40
41        self.colourBarLabel = colourBarLabel
42        self.colourBarPosition = colourBarPosition
43        self.colourBarStyle = colourBarStyle
44       
45        self.labelFont = config[FONTS_SECTION]['ColourBarLabel']
46
47    def draw(self, colourBarAxis, colourScheme, fontSize='medium'):
48        """
49        Adds the colour bar to the (and optionally a label) to the figure.
50
51        @param sm: the scalar mappable generated by applying the grid to the axis
52        @type sm: an instance of matplotlib.cm.ScalarMappable
53        @param units: the units of the values on the mesh, these will be used as a label
54            if the colourBarLabel property is not set.
55        @type units: string
56        """
57        log.debug("drawing colour bar")
58       
59        if self.colourBarStyle == 'legend':
60            self._drawLegendColourBar(colourBarAxis, colourScheme, fontSize)
61        else:
62            self._drawContiunousColourBar(colourBarAxis, colourScheme, fontSize)
63               
64        log.debug("finished drawing colour bar")
65
66    def _drawContiunousColourBar(self, axes, colourScheme, fontSize):
67       
68        intervalColourbar = isinstance(colourScheme, geoplot.colour_scheme.IntervalColourScheme)
69         
70        kwargs = {}
71        kwargs['cmap'] = colourScheme.colourMap
72        kwargs['norm'] = colourScheme.norm
73       
74        kwargs['orientation'] = self.colourBarPosition
75       
76        if intervalColourbar:
77            kwargs['ticks'] = colourScheme.labelLocations
78            kwargs['format'] = matplotlib.ticker.FixedFormatter(colourScheme.labels)
79            kwargs['spacing'] = 'proportional'
80       
81        cb = ColorbarBase(axes, **kwargs)
82
83        if cb.cmap.__class__ == ListedColormap \
84           and not intervalColourbar:
85            ColourBar._repositionColourBarTicks(cb)
86           
87        if self.colourBarLabel != None:
88            labelDictionary = self.labelFont.getDict(fontSize)
89            cb.set_label(self.colourBarLabel, fontdict=labelDictionary)
90       
91        return cb
92   
93    def _drawLegendColourBar(self, colourBarAxis, colourScheme, fontSize):
94        """
95       
96        """
97        kwargs = {}
98       
99        cmap = colourScheme.colourMap
100        norm = colourScheme.norm
101       
102        locations, labels = colourScheme.labelLocations, colourScheme.labels
103
104        kwargs['orientation'] = self.colourBarPosition
105 
106        handles = [Rectangle((0,0), 1, 1, fc=cmap(norm(i))) for i in locations]
107        labels = labels
108       
109        if self.colourBarPosition == 'horizontal':
110            if len(handles) < 3:
111                ncol = len(handles)
112            else:
113                ncol = 3
114        else:
115            ncol = 1
116
117        leg = colourBarAxis.legend(handles, labels, loc=10, mode='expand', 
118                                   ncol=ncol, borderaxespad=0)
119       
120        colourBarAxis.set_xticks([])
121        colourBarAxis.set_yticks([])
122       
123        if self.colourBarLabel != None:
124            labelDictionary = self.labelFont.getDict(fontSize)
125           
126            if self.colourBarPosition == 'horizontal':
127                colourBarAxis.set_xlabel(self.colourBarLabel, fontdict=labelDictionary)     
128            else:
129                colourBarAxis.set_ylabel(self.colourBarLabel, fontdict=labelDictionary)
130               
131        return leg
132       
133    @staticmethod
134    def _repositionColourBarTicks(cb):
135        """
136        reposition the ticks of a ListedColormap so that they appear at the
137        """
138       
139        log.debug("Repositioning colour bar ticks")
140        span = cb.vmax - cb.vmin
141
142        # Define flag for whether or not zero should be added
143        useZero = False
144        if cb.vmin < 0 and cb.vmax > 0:
145            useZero = True
146
147        numColours = len(cb.cmap.colors)
148        interval = float(span) / float(numColours)
149
150        showEvery = 1
151        while float(numColours)/float(showEvery) > float(MAX_CBAR_TICKS):
152            showEvery += 1
153
154        newLocs = []
155        for i in range(0, numColours + 1, showEvery):
156            newLocs.append(cb.vmin + i * interval)
157
158        # If need to add a zero then do so
159        if useZero == True and 0 not in newLocs:
160            newLocsWithZero = []
161         
162            zeroInserted = False 
163            for newLoc in newLocs:
164                if newLoc > 0 and zeroInserted == False:
165                    newLocsWithZero.append(0)
166                    zeroInserted = True                 
167                newLocsWithZero.append(newLoc)
168
169            newLocs = newLocsWithZero
170
171        #change the locator and the formatter (as the locations now have a high number of dp)
172        cb.locator = matplotlib.ticker.FixedLocator(newLocs)
173
174        # Decide the float formatting of the tick labels based on the span
175        if span < 100:
176            # In Excel this worked: =IF(E3<100,(LOG(E3)*-1)+2,0)
177            tickFormatDecPoints = int((math.log(span, 10) * -1) + 2)
178            tickFormatString = "%." + str(tickFormatDecPoints) + "f"
179        else: 
180            tickFormatString = "%d"
181           
182        cb.formatter = matplotlib.ticker.FormatStrFormatter(tickFormatString)  ###ADJUSTED_TICK_FORMAT)
183
184        # The next line removes axis artists as draw_all() adds new ones
185        cb.ax.artists = []
186        cb.draw_all()    # cause the colourbar to be redrawn, otherwise not changes will hapen
187
188        # Hard code line width of colour bar outline
189        cb.outline.set_linewidth(0.5)
190
191        # this can sometimes cause the tick positions to become unknown so
192        # set it back to default.
193        if cb.ax.yaxis.get_ticks_position() == 'unknown':
194            log.debug("Resetting yaxis ticks position to default")
195            cb.ax.yaxis.set_ticks_position('default')
196       
197        if cb.ax.xaxis.get_ticks_position() == 'unknown':
198            log.debug("Resetting xaxis ticks position to default")
199            cb.ax.xaxis.set_ticks_position('default')
200   
201 
202    def __get_position(self): return self._position
203
204    def __set_position(self, value):
205        if value not in ['horizontal', 'vertical', None]:
206            raise ValueError("ColourBar position value must be 'horizontal'" + \
207                             " or 'vertical, value recieved :" + str(value))
208        self._position = value
209
210    colourBarPosition = property(__get_position, __set_position, None,
211                    "colour bar position, 'horizontal' or 'vertical'")
212
213       
214
215def getColourBarImage(width=600, height=100,
216                      label=None, 
217                      cmap=cm.jet, 
218                      colourBarMin=None, 
219                      colourBarMax=None, 
220                      colourBarScale=COLOUR_SCHEME_SCALE.LINEAR,
221                      orientation='horizontal',
222                      intervals=None,
223                      intervalNames=None,
224                      numIntervals=None,
225                      colourBarStyle='continuous',
226                      dpi=100):
227   
228    figsize=(width / float(dpi), height / float(dpi))
229    fig = Figure(figsize=figsize, dpi=dpi, facecolor='w')
230   
231    #build a colour scheme
232    log.debug("colourBarScale = %s" % (colourBarScale,))
233    schemeBuilder = geoplot.colour_scheme.ColourSchemeBuilder(cmap=cmap, 
234                 colourBarMin=colourBarMin, 
235                 colourBarMax=colourBarMax, 
236                 colourBarScale=colourBarScale,
237                 numIntervals=numIntervals,
238                 intervals=intervals,
239                 intervalNames=intervalNames)
240   
241    colourScheme = schemeBuilder.buildScheme(colourBarStyle)   
242    log.debug("colourScheme.norm.__class__ = %s" % (colourScheme.norm.__class__,))
243    #for agg bakcend
244    #need about 40px at the bottom of the axes to draw the labels
245    #x = 40.0
246    #cbMin = 0.5
247   
248    #for cairo backend
249    if colourBarStyle == 'legend':
250        x = 40.0
251        cbMin = 0.6
252    else:
253        x = 70.0
254        cbMin = 0.6
255    cbBottom = x/height
256   
257    if cbBottom < 0.1:
258        cbBottom = 0.1
259   
260    if cbBottom > cbMin:
261        cbBottom = cbMin
262       
263    cbHeight = 0.9 - cbBottom
264    axes = fig.add_axes([0.05, cbBottom, 0.9, cbHeight], frameon=False)
265   
266    cb = ColourBar(colourBarLabel=label, 
267                   colourBarPosition=orientation,
268                   colourBarStyle=colourBarStyle)
269   
270    cb.draw(axes, colourScheme)   
271   
272    return geoplot.utils.figureToImage(fig)
Note: See TracBrowser for help on using the repository browser.