source: qesdi/geoplot/trunk/lib/geoplot/tests/unit/test_grid_builder_base.py @ 5403

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/tests/unit/test_grid_builder_base.py@5403
Revision 5403, 11.5 KB checked in by pnorton, 11 years ago (diff)

Moved QESDI tree from DCIP repository
 http://proj.badc.rl.ac.uk/svn/dcip/qesdi@3900.

Line 
1#/urs/bin/env python
2"""
3test_grid_builder_base.py
4"""
5import logging
6
7import nose
8
9import cdms2 as cdms
10import numpy as N
11import numpy.ma as MA
12
13import geoplot.grid_builder_base
14from geoplot.grid_builder_base import GridBuilderBase
15from geoplot.grid import Grid
16
17class Test_GridBuidlerBase(object):
18
19    def setUp(self):
20        self.buildAxes()
21        self.buildCdmsVar()
22        self.gridBuilder = SimpleGridBuilder(self.tempVar)
23   
24    def buildAxes(self):
25        self.axisX = cdms.createAxis(N.array([10,20,30,40,50]))
26        self.axisX.axis = 'X'
27        self.axisX.id = 'X'
28        self.axisY = cdms.createAxis(N.array([4,5,6,7,8,9]))
29        self.axisY.axis = 'Y'
30        self.axisY.id = 'Y'
31       
32        self.xLimits = (self.axisX[0],self.axisX[-1])
33        self.yLimits = (self.axisY[0], self.axisY[-1])
34           
35    def buildCdmsVar(self):
36        self.a = N.array([[1,2,3,4,5],
37                     [6,7,8,9,10],
38                     [11,12,13,14,15],
39                     [16,17,18,19,20],
40                     [21,22,23,24,25],
41                     [26,27,28,29,30]])
42       
43        self.tempVar = cdms.createVariable(self.a,id='temp')
44        self.tempVar.setAxisList([self.axisY, self.axisX])
45        self.tempVar.setMissing(5)
46   
47    def tearDown(self):
48        pass
49   
50    def test_001_checBuildsGrid(self):
51       
52        resultingGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits )
53       
54        nose.tools.assert_true(resultingGrid.__class__ == Grid)
55       
56        #chekc the X bounds and midpoints
57        for row in resultingGrid.boundsX:
58            nose.tools.assert_equal(row.tolist(),
59                   [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
60           
61        for row in resultingGrid.midpointsX:
62            nose.tools.assert_equal(row.tolist(),
63                    [10.0, 20.0 , 30.0, 40.0, 50.0])
64       
65        #check the Y bounds and midpoints
66        for row in zip(*resultingGrid.boundsY):
67            nose.tools.assert_equal(list(row), 
68                    [3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5])
69
70        for row in zip(*resultingGrid.midpointsY):
71            nose.tools.assert_equal(list(row), 
72                    [4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
73       
74        #check the values
75        xlim = 5
76        ylim = 6
77        Z=[[0]*xlim for i in range(ylim)]
78        for j in range(ylim):
79            for i in range(xlim):           
80                Z[j][i] = j*xlim + i               
81        Z = MA.array(Z)
82        Z = Z + 1
83        Z[0,4] = MA.masked
84
85        nose.tools.assert_equal(Z.tolist(999), resultingGrid.values.tolist(999))
86           
87    def test_002_checkBuiltGridHasMissingValues(self):
88        resultingGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits)
89        nose.tools.assert_true(resultingGrid.values.mask[0,4])
90       
91    def test_008_checkUsesBoundsFormListWhenGetBoundsReturnsNone(self):
92        originalGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits)
93       
94        self.axisY.getBounds = lambda :None
95        self.axisX.getBounds = lambda :None
96        self.buildCdmsVar()
97        builder = SimpleGridBuilder(self.tempVar)
98        newGrid = builder.buildGrid(self.xLimits, self.yLimits) 
99       
100        nose.tools.assert_equal(originalGrid.midpointsX.tolist(),
101                                newGrid.midpointsX.tolist())
102       
103        nose.tools.assert_equal(originalGrid.midpointsY.tolist(),
104                                newGrid.midpointsY.tolist())
105           
106    def test_009_checkIncorrectCdmsAxisWarningsAreWritten(self):
107     
108        oldWarningFunction = logging.Logger.__dict__['warning']
109       
110        logging.Logger.__dict__['warning'] = addCallsCounter(logging.Logger.__dict__['warning'])
111        countBefore =  geoplot.grid_builder_base.log.warning.callCount
112       
113        axisZ = cdms.createAxis(N.array([100]))
114        axisZ.id = 'z'
115
116        b = N.array([self.a])       
117        cdmsVar = cdms.createVariable(b,id='temp')
118        cdmsVar.setAxisList([axisZ, self.axisY, self.axisX])
119       
120        for type in ['Z', 'T', 'another']:
121            axisZ.axis = type
122
123            cdmsVar.setAxisList([axisZ, self.axisY, self.axisX])
124
125            countBefore =  geoplot.grid_builder_base.log.warning.callCount
126
127            builder = SimpleGridBuilder(cdmsVar)
128            newGrid = builder.buildGrid(self.xLimits, self.yLimits)
129                         
130       
131            countAfter =  geoplot.grid_builder_base.log.warning.callCount
132       
133            #you get an addiotnal warning if the axis isLevel or isTime
134            if type in ['Z', 'T']:
135                nose.tools.assert_true((countAfter - countBefore) == 2)
136            else:
137                nose.tools.assert_true((countAfter - countBefore) == 1)
138       
139        #put the old warning function back
140        logging.Logger.__dict__['warning'] = oldWarningFunction
141   
142    def test_010_createBoundsFromList(self):
143        vals = [1, 1.5, 2, 2.5]
144        bounds = GridBuilderBase._createBoundsFormList(vals)
145        nose.tools.assert_equal(bounds.tolist(), [0.75, 1.25, 1.75, 2.25, 2.75])
146       
147    def test_011_mergeBounds(self):
148        unmergedBounds = N.array([[0.75, 1.25], [1.25, 1.75], 
149                                  [1.75, 2.25], [2.25, 2.75]])
150       
151        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
152        nose.tools.assert_equal(mergedBounds.tolist(), [0.75, 1.25, 1.75, 2.25, 2.75]) 
153   
154    def test_012_getBoundsFromAxis(self):
155        axis =  cdms.createAxis(N.array([10.0,20.0,30.0,40.0,50.0]))
156        #if the axis isnt set as lat/lon/level/time no bounds will be generated
157        nose.tools.assert_true(axis.getBounds() == None)
158        bounds = GridBuilderBase._getBoundsFromAxis(axis)
159        nose.tools.assert_equal(bounds.tolist(), 
160                                [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
161        axis.axis = 'X'
162        nose.tools.assert_true(axis.getBounds() != None)
163        bounds = GridBuilderBase._getBoundsFromAxis(axis)
164        nose.tools.assert_equal(bounds.tolist(), 
165                                [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
166       
167   
168    def test_013_fillMissingLimitsFromArray(self):
169               
170        array = N.array([7.5, 12.5, 17.5, 22.5, 27.5])
171        arrayMin = 7.5
172        arrayMax = 27.5
173               
174        #test both
175        limits = (None, None)
176        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
177        nose.tools.assert_true(type(newlimits) == tuple)
178        nose.tools.assert_equal(newlimits, (arrayMin, arrayMax))
179       
180        #test lower only
181        limits = (None, 25.0)
182        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
183        nose.tools.assert_equal(newlimits, (arrayMin, 25.0))
184       
185        #test upper only
186        limits = (17.5, None)
187        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
188        nose.tools.assert_equal(newlimits, (17.5, arrayMax))       
189       
190        #test none
191        limits = (20.0, 22.5)
192        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
193        nose.tools.assert_equal(newlimits, limits)
194
195    def test__014__replaceNoneInLimitsWithMaxMin(self):
196        midX, midY = self.gridBuilder._buildGridMidpoints(self.tempVar)
197       
198        (minX, maxX) = (midX.min(), midX.max())
199        (minY, maxY) = (midY.min(), midY.max())
200       
201        fullXlimits = (20,40)
202        fullYlimits = (5,7)
203       
204        #check with no None's
205        xLimits = (fullXlimits[0], fullXlimits[1])
206        yLimits = (fullYlimits[0], fullYlimits[1])
207        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
208        nose.tools.assert_equal(resX, xLimits)
209        nose.tools.assert_equal(resY, yLimits)
210       
211        #check with all None
212        xLimits = (None, None)
213        yLimits = (None, None)
214        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
215        nose.tools.assert_equal(resX, (minX, maxX))
216        nose.tools.assert_equal(resY, (minY, maxY))     
217
218        #check with some None
219        xLimits = (fullXlimits[0], None)
220        yLimits = (None, fullYlimits[1])
221        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
222        nose.tools.assert_equal(resX, (xLimits[0], maxX))
223        nose.tools.assert_equal(resY, (minY, yLimits[1]))
224       
225        #check with all None
226        xLimits = (fullXlimits[0], fullXlimits[1])
227        yLimits = (None, fullYlimits[1])
228        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
229        nose.tools.assert_equal(resX, xLimits)
230        nose.tools.assert_equal(resY, (minY, yLimits[1]))
231   
232    def test__015__getResizedVar(self):
233       
234        xLimits = (20,40)
235        yLimits = (5, 7)
236       
237        midX, midY = self.gridBuilder._buildGridMidpoints(self.tempVar)
238        xLimitsAuto = (midX.min(), midX.max())
239        yLimitsAuto = (midY.min(), midY.max())
240       
241        for xLimits in [(20,40), (None, 40), (20, None), (None, None)]:
242            for yLimits in [(5,7), (None, 7), (5, None), (None, None)]:
243                resizedVar = self.gridBuilder._getResizedVar(xLimits, yLimits)
244               
245                resXMin = resizedVar.getLongitude().getValue().min()
246                resXMax = resizedVar.getLongitude().getValue().max()
247                resYMin = resizedVar.getLatitude().getValue().min()
248                resYMax = resizedVar.getLatitude().getValue().max()   
249               
250                if xLimits[0] == None:
251                    nose.tools.assert_equal(resXMin, xLimitsAuto[0])
252                else:
253                    nose.tools.assert_equal(resXMin, xLimits[0])
254
255                if yLimits[0] == None:
256                    nose.tools.assert_equal(resYMin, yLimitsAuto[0])
257                else:
258                    nose.tools.assert_equal(resYMin, yLimits[0])
259                   
260                if xLimits[1] == None:
261                    nose.tools.assert_equal(resXMax, xLimitsAuto[1])
262                else:
263                    nose.tools.assert_equal(resXMax, xLimits[1])
264                   
265                if yLimits[1] == None:
266                    nose.tools.assert_equal(resYMax, yLimitsAuto[1])
267                else:
268                    nose.tools.assert_equal(resYMax, yLimits[1])       
269   
270
271class SimpleGridBuilder(GridBuilderBase):
272    """
273    A simple implementation of the abstract GridBuilderBase class to allow it to
274    be tested.
275    """
276    def __init__(self, cdmsVar):
277        GridBuilderBase.__init__(self, cdmsVar)
278
279    def _resizeVar(self, xLimits, yLimits):
280        return self.cdmsVar(longitude=(xLimits[0], xLimits[1]),
281                            latitude=(yLimits[0], yLimits[1]))
282   
283    def _buildGridBounds(self, cdmsVar):
284       
285       
286        lonBounds = GridBuilderBase._getBoundsFromAxis(cdmsVar.getLongitude())
287        latBounds = GridBuilderBase._getBoundsFromAxis(cdmsVar.getLatitude()) 
288
289        return N.meshgrid(lonBounds, latBounds)
290
291    def _buildGridMidpoints(self, cdmsVar):       
292        midX, midY = N.meshgrid(cdmsVar.getLongitude().getValue(),
293                                cdmsVar.getLatitude().getValue())
294        return (midX, midY)
295
296    def _buildGridValues(self,cdmsVar):
297        return MA.masked_values(cdmsVar.getValue(), cdmsVar.getMissing())
298
299def addCallsCounter(f):
300   
301    def wrapped(*args, **kwargs):
302        wrapped.callCount += 1
303        return f(*args, **kwargs)
304    wrapped.callCount = 0
305   
306    return wrapped
307
308if __name__ == '__main__':
309
310    import geoplot.log_util
311    geoplot.log_util.setGeoplotHandlerToStdOut()
312
313    nose.runmodule()
Note: See TracBrowser for help on using the repository browser.