source: qesdi/geoplot/trunk/lib/geoplot/tests/unit/test_grid_builder_base.py @ 6308

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/qesdi/geoplot/trunk/lib/geoplot/tests/unit/test_grid_builder_base.py@6308
Revision 6308, 12.5 KB checked in by pnorton, 11 years ago (diff)

Applied the fix to _mergeBounds so that it will cope with reversed bounds.

Line 
1#/urs/bin/env python
2"""
3test_grid_builder_base.py
4"""
5import logging
6
7import nose
8
9import cdms2 as cdms
10import numpy as N
11import numpy.ma as MA
12
13import geoplot.grid_builder_base
14from geoplot.grid_builder_base import GridBuilderBase
15from geoplot.grid import Grid
16
17class Test_GridBuidlerBase(object):
18
19    def setUp(self):
20        self.buildAxes()
21        self.buildCdmsVar()
22        self.gridBuilder = SimpleGridBuilder(self.tempVar)
23   
24    def buildAxes(self):
25        self.axisX = cdms.createAxis(N.array([10,20,30,40,50]))
26        self.axisX.axis = 'X'
27        self.axisX.id = 'X'
28        self.axisY = cdms.createAxis(N.array([4,5,6,7,8,9]))
29        self.axisY.axis = 'Y'
30        self.axisY.id = 'Y'
31       
32        self.xLimits = (self.axisX[0],self.axisX[-1])
33        self.yLimits = (self.axisY[0], self.axisY[-1])
34           
35    def buildCdmsVar(self):
36        self.a = N.array([[1,2,3,4,5],
37                     [6,7,8,9,10],
38                     [11,12,13,14,15],
39                     [16,17,18,19,20],
40                     [21,22,23,24,25],
41                     [26,27,28,29,30]])
42       
43        self.tempVar = cdms.createVariable(self.a,id='temp')
44        self.tempVar.setAxisList([self.axisY, self.axisX])
45        self.tempVar.setMissing(5)
46   
47    def tearDown(self):
48        pass
49   
50    def test_001_checBuildsGrid(self):
51       
52        resultingGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits )
53       
54        nose.tools.assert_true(resultingGrid.__class__ == Grid)
55       
56        #chekc the X bounds and midpoints
57        for row in resultingGrid.boundsX:
58            nose.tools.assert_equal(row.tolist(),
59                   [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
60           
61        for row in resultingGrid.midpointsX:
62            nose.tools.assert_equal(row.tolist(),
63                    [10.0, 20.0 , 30.0, 40.0, 50.0])
64       
65        #check the Y bounds and midpoints
66        for row in zip(*resultingGrid.boundsY):
67            nose.tools.assert_equal(list(row), 
68                    [3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5])
69
70        for row in zip(*resultingGrid.midpointsY):
71            nose.tools.assert_equal(list(row), 
72                    [4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
73       
74        #check the values
75        xlim = 5
76        ylim = 6
77        Z=[[0]*xlim for i in range(ylim)]
78        for j in range(ylim):
79            for i in range(xlim):           
80                Z[j][i] = j*xlim + i               
81        Z = MA.array(Z)
82        Z = Z + 1
83        Z[0,4] = MA.masked
84
85        nose.tools.assert_equal(Z.tolist(999), resultingGrid.values.tolist(999))
86           
87    def test_002_checkBuiltGridHasMissingValues(self):
88        resultingGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits)
89        nose.tools.assert_true(resultingGrid.values.mask[0,4])
90       
91    def test_008_checkUsesBoundsFormListWhenGetBoundsReturnsNone(self):
92        originalGrid = self.gridBuilder.buildGrid(self.xLimits, self.yLimits)
93       
94        self.axisY.getBounds = lambda :None
95        self.axisX.getBounds = lambda :None
96        self.buildCdmsVar()
97        builder = SimpleGridBuilder(self.tempVar)
98        newGrid = builder.buildGrid(self.xLimits, self.yLimits) 
99       
100        nose.tools.assert_equal(originalGrid.midpointsX.tolist(),
101                                newGrid.midpointsX.tolist())
102       
103        nose.tools.assert_equal(originalGrid.midpointsY.tolist(),
104                                newGrid.midpointsY.tolist())
105           
106    def test_009_checkIncorrectCdmsAxisWarningsAreWritten(self):
107     
108        oldWarningFunction = logging.Logger.__dict__['warning']
109       
110        logging.Logger.__dict__['warning'] = addCallsCounter(logging.Logger.__dict__['warning'])
111        countBefore =  geoplot.grid_builder_base.log.warning.callCount
112       
113        axisZ = cdms.createAxis(N.array([100]))
114        axisZ.id = 'z'
115
116        b = N.array([self.a])       
117        cdmsVar = cdms.createVariable(b,id='temp')
118        cdmsVar.setAxisList([axisZ, self.axisY, self.axisX])
119       
120        for type in ['Z', 'T', 'another']:
121            axisZ.axis = type
122
123            cdmsVar.setAxisList([axisZ, self.axisY, self.axisX])
124
125            countBefore =  geoplot.grid_builder_base.log.warning.callCount
126
127            builder = SimpleGridBuilder(cdmsVar)
128            newGrid = builder.buildGrid(self.xLimits, self.yLimits)
129                         
130       
131            countAfter =  geoplot.grid_builder_base.log.warning.callCount
132       
133            #you get an addiotnal warning if the axis isLevel or isTime
134            if type in ['Z', 'T']:
135                nose.tools.assert_true((countAfter - countBefore) == 2)
136            else:
137                nose.tools.assert_true((countAfter - countBefore) == 1)
138       
139        #put the old warning function back
140        logging.Logger.__dict__['warning'] = oldWarningFunction
141   
142    def test_010_createBoundsFromList(self):
143        vals = [1, 1.5, 2, 2.5]
144        bounds = GridBuilderBase._createBoundsFormList(vals)
145        nose.tools.assert_equal(bounds.tolist(), [0.75, 1.25, 1.75, 2.25, 2.75])
146       
147    def test_011_mergeBounds(self):
148        unmergedBounds = N.array([[0.75, 1.25], [1.25, 1.75], [1.75, 2.25], [2.25, 2.75]])
149       
150        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
151        nose.tools.assert_equal(mergedBounds.tolist(), [0.75, 1.25, 1.75, 2.25, 2.75]) 
152       
153        # ascending, in order [low, high]
154        unmergedBounds = N.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
155        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
156        nose.tools.assert_equal(mergedBounds.tolist(), [1.0, 2.0, 3.0, 4.0])
157   
158        # ascending in order [high, low]
159        unmergedBounds = N.array([[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]])
160        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
161        nose.tools.assert_equal(mergedBounds.tolist(), [1.0, 2.0, 3.0, 4.0])
162       
163        # decending in order [low, high]
164        unmergedBounds = N.array([[3.0, 4.0], [2.0, 3.0], [1.0, 2.0]])
165        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
166        nose.tools.assert_equal(mergedBounds.tolist(), [4.0, 3.0, 2.0, 1.0])
167       
168        # decending in order [high, low]
169        unmergedBounds = N.array([[4.0, 3.0], [3.0, 2.0], [2.0, 1.0]])
170        mergedBounds = GridBuilderBase._mergeBounds(unmergedBounds)
171        nose.tools.assert_equal(mergedBounds.tolist(), [4.0, 3.0, 2.0, 1.0])
172           
173    def test_012_getBoundsFromAxis(self):
174        axis =  cdms.createAxis(N.array([10.0,20.0,30.0,40.0,50.0]))
175        #if the axis isnt set as lat/lon/level/time no bounds will be generated
176        nose.tools.assert_true(axis.getBounds() == None)
177        bounds = GridBuilderBase._getBoundsFromAxis(axis)
178        nose.tools.assert_equal(bounds.tolist(), 
179                                [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
180        axis.axis = 'X'
181        nose.tools.assert_true(axis.getBounds() != None)
182        bounds = GridBuilderBase._getBoundsFromAxis(axis)
183        nose.tools.assert_equal(bounds.tolist(), 
184                                [5.0, 15.0, 25.0, 35.0, 45.0, 55.0])
185       
186   
187    def test_013_fillMissingLimitsFromArray(self):
188               
189        array = N.array([7.5, 12.5, 17.5, 22.5, 27.5])
190        arrayMin = 7.5
191        arrayMax = 27.5
192               
193        #test both
194        limits = (None, None)
195        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
196        nose.tools.assert_true(type(newlimits) == tuple)
197        nose.tools.assert_equal(newlimits, (arrayMin, arrayMax))
198       
199        #test lower only
200        limits = (None, 25.0)
201        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
202        nose.tools.assert_equal(newlimits, (arrayMin, 25.0))
203       
204        #test upper only
205        limits = (17.5, None)
206        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
207        nose.tools.assert_equal(newlimits, (17.5, arrayMax))       
208       
209        #test none
210        limits = (20.0, 22.5)
211        newlimits = GridBuilderBase._fillMissingLimitsFromArray(limits, array)
212        nose.tools.assert_equal(newlimits, limits)
213
214    def test__014__replaceNoneInLimitsWithMaxMin(self):
215        midX, midY = self.gridBuilder._buildGridMidpoints(self.tempVar)
216       
217        (minX, maxX) = (midX.min(), midX.max())
218        (minY, maxY) = (midY.min(), midY.max())
219       
220        fullXlimits = (20,40)
221        fullYlimits = (5,7)
222       
223        #check with no None's
224        xLimits = (fullXlimits[0], fullXlimits[1])
225        yLimits = (fullYlimits[0], fullYlimits[1])
226        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
227        nose.tools.assert_equal(resX, xLimits)
228        nose.tools.assert_equal(resY, yLimits)
229       
230        #check with all None
231        xLimits = (None, None)
232        yLimits = (None, None)
233        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
234        nose.tools.assert_equal(resX, (minX, maxX))
235        nose.tools.assert_equal(resY, (minY, maxY))     
236
237        #check with some None
238        xLimits = (fullXlimits[0], None)
239        yLimits = (None, fullYlimits[1])
240        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
241        nose.tools.assert_equal(resX, (xLimits[0], maxX))
242        nose.tools.assert_equal(resY, (minY, yLimits[1]))
243       
244        #check with all None
245        xLimits = (fullXlimits[0], fullXlimits[1])
246        yLimits = (None, fullYlimits[1])
247        resX, resY = self.gridBuilder._replaceNoneInLimitsWithMaxMin(xLimits, yLimits)
248        nose.tools.assert_equal(resX, xLimits)
249        nose.tools.assert_equal(resY, (minY, yLimits[1]))
250   
251    def test__015__getResizedVar(self):
252       
253        xLimits = (20,40)
254        yLimits = (5, 7)
255       
256        midX, midY = self.gridBuilder._buildGridMidpoints(self.tempVar)
257        xLimitsAuto = (midX.min(), midX.max())
258        yLimitsAuto = (midY.min(), midY.max())
259       
260        for xLimits in [(20,40), (None, 40), (20, None), (None, None)]:
261            for yLimits in [(5,7), (None, 7), (5, None), (None, None)]:
262                resizedVar = self.gridBuilder._getResizedVar(xLimits, yLimits)
263               
264                resXMin = resizedVar.getLongitude().getValue().min()
265                resXMax = resizedVar.getLongitude().getValue().max()
266                resYMin = resizedVar.getLatitude().getValue().min()
267                resYMax = resizedVar.getLatitude().getValue().max()   
268               
269                if xLimits[0] == None:
270                    nose.tools.assert_equal(resXMin, xLimitsAuto[0])
271                else:
272                    nose.tools.assert_equal(resXMin, xLimits[0])
273
274                if yLimits[0] == None:
275                    nose.tools.assert_equal(resYMin, yLimitsAuto[0])
276                else:
277                    nose.tools.assert_equal(resYMin, yLimits[0])
278                   
279                if xLimits[1] == None:
280                    nose.tools.assert_equal(resXMax, xLimitsAuto[1])
281                else:
282                    nose.tools.assert_equal(resXMax, xLimits[1])
283                   
284                if yLimits[1] == None:
285                    nose.tools.assert_equal(resYMax, yLimitsAuto[1])
286                else:
287                    nose.tools.assert_equal(resYMax, yLimits[1])       
288   
289
290class SimpleGridBuilder(GridBuilderBase):
291    """
292    A simple implementation of the abstract GridBuilderBase class to allow it to
293    be tested.
294    """
295    def __init__(self, cdmsVar):
296        GridBuilderBase.__init__(self, cdmsVar)
297
298    def _resizeVar(self, xLimits, yLimits):
299        return self.cdmsVar(longitude=(xLimits[0], xLimits[1]),
300                            latitude=(yLimits[0], yLimits[1]))
301   
302    def _buildGridBounds(self, cdmsVar):
303       
304       
305        lonBounds = GridBuilderBase._getBoundsFromAxis(cdmsVar.getLongitude())
306        latBounds = GridBuilderBase._getBoundsFromAxis(cdmsVar.getLatitude()) 
307
308        return N.meshgrid(lonBounds, latBounds)
309
310    def _buildGridMidpoints(self, cdmsVar):       
311        midX, midY = N.meshgrid(cdmsVar.getLongitude().getValue(),
312                                cdmsVar.getLatitude().getValue())
313        return (midX, midY)
314
315    def _buildGridValues(self,cdmsVar):
316        return MA.masked_values(cdmsVar.getValue(), cdmsVar.getMissing())
317
318def addCallsCounter(f):
319   
320    def wrapped(*args, **kwargs):
321        wrapped.callCount += 1
322        return f(*args, **kwargs)
323    wrapped.callCount = 0
324   
325    return wrapped
326
327if __name__ == '__main__':
328
329    import geoplot.log_util
330    geoplot.log_util.setGeoplotHandlerToStdOut()
331
332    nose.runmodule()
Note: See TracBrowser for help on using the repository browser.