source: qesdi/linplot/trunk/src/linplot/plot.py @ 6329

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/linplot/trunk/src/linplot/plot.py@6329
Revision 6329, 7.0 KB checked in by pnorton, 11 years ago (diff)

Changed how linplot sets up the legend, it now tries to resize the axis to fit in additional legend items.

Line 
1'''
2Created on 4 Nov 2009
3
4@author: pnorton
5'''
6
7import logging
8import matplotlib
9import numpy
10import math
11from linplot import utils
12from linplot.range import Range
13
14log = logging.getLogger(__name__)
15
16
17
18#AXIS_POSITION = [0.1, 0.15, 0.7, 0.7]
19#LEGEND_POSITION = [0.82, 0.15, 0.15, 0.7]
20
21# one row
22BASE_AXIS_POSITION =  [0.1, 0.20, 0.8, 0.7]
23BASE_LEGEND_POSITION = [0.1, 0.05, 0.8, 0.05]
24
25# two rows
26#AXIS_POSITION = [0.1, 0.25, 0.8, 0.65]
27#LEGEND_POSITION = [0.1, 0.05, 0.8, 0.1]
28
29##three rows
30#AXIS_POSITION = [0.1, 0.30, 0.8, 0.6]
31#LEGEND_POSITION = [0.1, 0.05, 0.8, 0.15]
32
33COLORS = ['red', 'blue', 'yellow', 'green', 'orange', 'purple', 'cyan', 'black', 'brown', 'pink', 'grey', 'lightblue', 'darkblue', 'lightgreen', 'darkgreen']
34
35class Plot(object):
36   
37    def __init__(self, size=(800,600), dpi=100):
38       
39        self.size = size
40        self.dpi = dpi
41       
42        self._xRange = None
43        self._yRange = None
44        self._fig = self._makeFigure()
45        self._ax = self._makeAxes()
46        self._legendAx = self._makeLegendAxes()
47        self._xIndexLabels = None
48   
49    def draw(self, xdata, ydata, **kwargs):
50       
51        nlines = len(self._ax.lines)
52        kwargs.setdefault('label', 'Line #%s' % (nlines,))
53        kwargs.setdefault('color', COLORS[nlines % len(COLORS)])
54               
55        if utils.isString(xdata[0]):
56            self._plotIndexXAxis(xdata, ydata, kwargs)
57        else:
58            self._plotNumericXAxis(xdata, ydata, kwargs)
59   
60    def addXAxisIndexValues(self, indexLabels):
61        """
62        Sets the index values for the x axis, these will be used when string
63        values are passed as xdata
64        """ 
65        self._setXIndexLabels(indexLabels)
66       
67   
68    def _plotIndexXAxis(self, xdata, ydata, kwargs):
69        self._setXIndexLabels(xdata)
70
71        #work out what the xdata is for this axis
72        xIndex = [self._xIndexLabels.index(d) for d in xdata]
73       
74        self._ax.add_line(matplotlib.lines.Line2D(xIndex, ydata, **kwargs))
75   
76    def _setXIndexLabels(self, newLabels):
77       
78        if self._xIndexLabels is None:
79            self._xIndexLabels = [d for d in newLabels]
80        else:
81            for d in newLabels:
82                if d not in self._xIndexLabels:
83                    self._xIndexLabels.append(d)
84       
85        log.debug("self._xIndexLabels = %s" % (self._xIndexLabels,))
86                   
87        # (re) set the x axis labels
88        locations = numpy.arange(len(self._xIndexLabels))
89        loc = matplotlib.ticker.FixedLocator(locations)
90        self._ax.get_xaxis().set_major_locator(loc)
91       
92        self._ax.set_xticklabels(self._xIndexLabels)       
93       
94   
95    def _plotNumericXAxis(self, xdata, ydata, kwargs):
96       
97        if self._xIndexLabels is not None:
98            raise Exception("Can't mix index and non-index values in the x axis")
99       
100        self._ax.add_line(matplotlib.lines.Line2D(xdata, ydata, **kwargs))
101   
102    def setXLabel(self, label):
103        self._ax.set_xlabel(label)
104   
105    def setYLabel(self, label):
106        self._ax.set_ylabel(label)
107       
108    def saveImage(self, outputFile):
109       
110        if len(self._ax.lines) == 0:
111            raise Exception("Trying to create a plot with no lines.")
112       
113        self._scaleAxes()
114        self._populateLegend()
115        im = utils.figureToImage(self._fig)
116        im.save(outputFile)
117   
118    def getImage(self):
119        if len(self._ax.lines) == 0:
120            raise Exception("Trying to create a plot with no lines.")
121       
122        self._scaleAxes()
123        self._populateLegend()
124        im = utils.figureToImage(self._fig)
125        return im
126   
127    def _makeFigure(self):
128        figsize=(self.size[0] / self.dpi, self.size[1] / self.dpi)
129        return matplotlib.figure.Figure(figsize=figsize, dpi=self.dpi, facecolor='w')       
130       
131    def _makeAxes(self):
132        return self._fig.add_axes(BASE_AXIS_POSITION,  frameon=True)
133   
134    def _makeLegendAxes(self):
135        return self._fig.add_axes(BASE_LEGEND_POSITION, xticks=[], yticks=[], frameon=False)
136   
137    def _scaleAxes(self):
138       
139        from matplotlib.transforms import blended_transform_factory
140       
141        self._ax.autoscale_view(tight=True)
142       
143        xlim = self._ax.get_xlim(); ylim = self._ax.get_ylim()
144
145        # for x=0 line
146        trans = blended_transform_factory(self._ax.transAxes, self._ax.transData)
147        self._ax.plot([0,1], [0,0], "-", transform=trans, color="0.6", linewidth=1, zorder=-1)
148       
149        log.debug("ylim = %s" % (ylim,))
150       
151        # reset the limits to avoid including the value from the line above
152        self._ax.set_xlim(xlim); self._ax.set_ylim(ylim)
153               
154        if not self._xRange is None:
155            self._ax.set_xlim(float(self._xRange.minimum), float(self._xRange.maximum))
156       
157        if not self._yRange is None:
158            self._ax.set_ylim(float(self._yRange.minimum), float(self._yRange.maximum))       
159   
160    def _populateLegend(self):
161       
162        log.debug("self._ax.lines = %s" % (self._ax.lines,))
163       
164        handles, labels = self._ax.get_legend_handles_labels()
165        log.debug("len(handles) = %s" % (len(handles),))
166
167        ncol = len(handles) if len(handles) < 3 else 3
168        maxLabelLen = max([len(x) for x in labels])
169        log.debug("labels = %s" % (labels,))
170        log.debug("maxLabelLen = %s" % (maxLabelLen,))
171       
172        if maxLabelLen > 15 and ncol > 2:
173            ncol = 2
174       
175        if maxLabelLen > 30 and ncol > 1:
176            ncol = 1
177       
178        log.debug("ncol = %s" % (ncol,))
179       
180        leg = self._legendAx.legend(handles, labels, ncol=ncol,
181                                    loc=2, mode='expand', borderaxespad=0)
182
183        nlines = len(self._ax.lines) - 1
184        nrow = math.ceil(float(nlines)/ncol)
185
186        axisPos, legPos = self.getAxisPositions(nrow)
187
188        self._ax.set_position(axisPos)
189        self._legendAx.set_position(legPos)
190
191        vp = leg._legend_box._children[-1]._children[0] 
192        log.debug("vp = %s" % (vp,))
193
194        return leg
195   
196   
197    def getAxisPositions(self, nrow):
198
199        rowHeight = 0.05
200        shift = rowHeight * (nrow - 1)
201       
202        axisPos = [0.0] * 4 
203        axisPos[0] = BASE_AXIS_POSITION[0]
204        axisPos[1] = BASE_AXIS_POSITION[1] + shift
205        axisPos[2] = BASE_AXIS_POSITION[2]
206        axisPos[3] = BASE_AXIS_POSITION[3] - shift
207       
208        legPos = [0.0] * 4
209        legPos[0] = BASE_LEGEND_POSITION[0]
210        legPos[1] = BASE_LEGEND_POSITION[1]
211        legPos[2] = BASE_LEGEND_POSITION[2]
212        legPos[3] = BASE_LEGEND_POSITION[3] + shift
213       
214        return axisPos, legPos
215       
216
217    def setXRange(self, minVal, maxVal=None):
218        self._xRange = self._buildRange(minVal, maxVal)
219
220    def setYRange(self, minVal, maxVal=None):
221        self._yRange = self._buildRange(minVal, maxVal)
222
223    def setTitle(self, title):
224        self._ax.set_title(title)
225
226    def _buildRange(self, minVal, maxVal):
227        if minVal is None:
228            return None
229        else:
230            return Range(minVal, maxVal)
231       
Note: See TracBrowser for help on using the repository browser.