source: TI02-CSML/branches/csml-cdms2/tests/test_extract.py @ 3627

Subversion URL: http://proj.badc.rl.ac.uk/svn/ndg/TI02-CSML/branches/csml-cdms2/tests/test_extract.py@3627
Revision 3627, 10.6 KB checked in by spascoe, 12 years ago (diff)

This branch contains CSML converted to use cdat_lite-5.

  • convertcdms was run on the source
  • the MA.set_print_limit call was changed to the numpy equivilent
  • The tests were changed to account for the existence of numpy scalar types.

All tests, except the two known to fail, pass on i686 ubuntu.

Line 
1# Adapted for numpy/ma/cdms2 by convertcdms.py
2"""
3Test CSML subsetToGridSeries against cdms equivilent.
4
5"""
6
7import csml, cdms2 as cdms
8import os
9import tempfile
10
11from unittest import TestCase
12import numpy.oldnumeric as N
13import numpy
14
15# For image dump
16import Image
17from matplotlib import cm, colors
18
19
20here = os.path.dirname(os.path.abspath(__file__))
21data_csml = os.path.join(here, 'data', 'test.csml')
22
23#----------------------------------------------------------------------------
24# Utility functions
25
26def extract_csml(file, field, **sel):
27    d = csml.parser.Dataset()
28    d.parse(file)
29   
30    # Find feature by name
31    for id in d.getFeatureList():
32        f = d.getFeature(id)
33        if f.name.CONTENT == field:
34            break
35    else:
36        raise ValueError, 'No feature found with name %s' % field
37
38    # Extract to a temporary file
39    (fd, tmp) = tempfile.mkstemp('.nc', 'test_csml_'); os.close(fd)
40    try:
41        f.subsetToGridSeries(outputdir=os.path.dirname(tmp),
42                             ncname=os.path.basename(tmp),
43                             **sel)
44
45        # Open temporary file
46        d2 = cdms.open(tmp)
47        var = d2(field, squeeze=1)
48    finally:
49        os.remove(tmp)
50    return var
51
52def dump_img(var, filename, width=100, height=100):
53    norm = colors.normalize(0, 1000.)
54    cmap = cm.get_cmap()
55    a = norm(var)
56    img_buf = (cmap(a)*255).astype('B')
57    (y, x, c) = img_buf.shape
58    img = Image.frombuffer("RGBA", (x, y), img_buf.tostring(), 'raw', 'RGBA',
59                           0, 1)
60    img.save(filename)
61
62#---------------------------------------------------------------------------
63
64class TestExtractAll(TestCase):
65    """
66    Extract the entire lat/lon grid and check for correct shape and border.
67   
68    """
69    def setUp(self):
70        self._testKeys = ['test_0_360p', 'test_0_360m', 'test_m180_180m',
71                          'test_m180_180p']
72        self._selectors = {
73            'test_0_360p': dict(time='1980-01-01T00:00:00.0'),
74            'test_0_360m': dict(time='1980-01-01T00:00:00.0'),
75            'test_m180_180m': dict(time='1980-01-01T00:00:00.0'),
76            'test_m180_180p': dict(time='1980-01-01T00:00:00.0'),
77            }                   
78
79    def _checkValues(self, var):
80        """
81        Values should be the sum of lat and lon
82
83        """
84        lat = var.getLatitude().getValue()
85        lon = var.getLongitude().getValue()
86
87        print lat
88        print lon
89
90        print var[:10, :10]
91
92        self.assertEquals(var[0,5], lat[0] + lon[5])
93        self.assertEquals(var[0,30], lat[0] + lon[30])
94
95        self.assertEquals(var[30,0], lat[30] + lon[0])
96        self.assertEquals(var[5,0], lat[5] + lon[0])
97           
98
99    def _extract(self, i):
100        k = self._testKeys[i]
101        return extract_csml(data_csml, k, **self._selectors[k])
102           
103    def test1(self):
104        var = self._extract(0)
105        self.assertEquals(var.shape, (36,36))
106        self._checkValues(var)
107
108    def test2(self):
109        var = self._extract(1)
110        self.assertEquals(var.shape, (36,36))
111        self._checkValues(var)
112
113
114    def test3(self):
115        var = self._extract(2)
116        self.assertEquals(var.shape, (36,36))
117        self._checkValues(var)
118
119    def test4(self):
120        var = self._extract(3)
121        self.assertEquals(var.shape, (36,36))
122        self._checkValues(var)
123
124
125class TestExtractAll_360(TestExtractAll):
126    """
127    Extract everything with explicit latitude/longitude selection
128
129    """
130    def setUp(self):
131        super(TestExtractAll_360, self).setUp()
132        for k in self._selectors:
133            self._selectors[k].update(
134                dict(longitude=(0, 360), latitude=(-90, 90))
135                )
136
137class TestExtractAll_180(TestExtractAll):
138    """
139    Extract everything with explicit latitude/longitude selection
140
141    """
142    def setUp(self):
143        super(TestExtractAll_180, self).setUp()
144        for k in self._selectors:
145            self._selectors[k].update(
146                dict(longitude=(-180, 180), latitude=(-90, 90))
147                )
148
149
150#---------------------------------------------------------------------------
151
152
153class TestExtractQuadrant(TestCase):
154    """
155    Abstract base class to set things up.
156
157    """
158    def _extract1(self):
159        var = extract_csml(data_csml, 'test_0_360p',
160                            time='1980-01-01T00:00:00.0',
161                            latitude=(0, 90),
162                            longitude=(0,180))
163        return var
164   
165    def _extract2(self):
166        var = extract_csml(data_csml, 'test_0_360m',
167                            time='1980-01-01T00:00:00.0',
168                            latitude=(0, 90),
169                            longitude=(0,180))
170        return var
171   
172    def _extract3(self):
173        var = extract_csml(data_csml, 'test_m180_180m',
174                            time='1980-01-01T00:00:00.0',
175                            latitude=(0, 90),
176                            longitude=(-180,0))
177        return var
178   
179    def _extract4(self):
180        var = extract_csml(data_csml, 'test_m180_180p',
181                            time='1980-01-01T00:00:00.0',
182                            latitude=(0, 90),
183                            longitude=(-180,0))
184        return var
185
186class TestExtractShape(TestExtractQuadrant):
187    """
188    Test the shape of extracted GridSeries.
189
190    """
191    def test1(self):
192        var = self._extract1()
193        self.assertEquals(var.shape, (18,18))
194
195    def test2(self):
196        var = self._extract2()
197        self.assertEquals(var.shape, (18,18))
198
199    def test3(self):
200        var = self._extract3()
201        self.assertEquals(var.shape, (18,18))
202
203    def test4(self):
204        var = self._extract4()
205        self.assertEquals(var.shape, (18,18))
206
207
208class TestLatOrdering(TestExtractQuadrant):
209    """
210    Test the shape of extracted GridSeries.
211
212    """
213    def test1(self):
214        var = self._extract1()
215        lat = var.getLatitude()
216        dlat = lat[1] - lat[0]
217        assert dlat > 0
218       
219
220    def test2(self):
221        var = self._extract2()
222        lat = var.getLatitude()
223        dlat = lat[1] - lat[0]
224        assert dlat > 0
225
226
227    def test3(self):
228        var = self._extract3()
229        lat = var.getLatitude()
230        dlat = lat[1] - lat[0]
231        assert dlat > 0
232
233
234    def test4(self):
235        var = self._extract4()
236        lat = var.getLatitude()
237        dlat = lat[1] - lat[0]
238        assert dlat > 0
239
240class TestLonOrdering(TestExtractQuadrant):
241    """
242    Test the shape of extracted GridSeries.
243
244    """
245    def test1(self):
246        var = self._extract1()
247        lon = var.getLongitude()
248        dlon = lon[1] - lon[0]
249        assert dlon > 0
250       
251
252    def test2(self):
253        var = self._extract2()
254        lon = var.getLongitude()
255        dlon = lon[1] - lon[0]
256        assert dlon > 0
257
258
259    def test3(self):
260        var = self._extract3()
261        lon = var.getLongitude()
262        dlon = lon[1] - lon[0]
263        assert dlon > 0
264
265
266    def test4(self):
267        var = self._extract4()
268        lon = var.getLongitude()
269        dlon = lon[1] - lon[0]
270        assert dlon > 0
271
272#---------------------------------------------------------------------------------------------------------
273
274class TestExtractPoint(TestCase):
275    """
276    Extract a single grid box using min/max lat/lon.
277
278    """
279    def _extract1(self, lon, lat):
280        var = extract_csml(data_csml, 'test_0_360p',
281                            time='1980-01-01T00:00:00.0',
282                            latitude=lat,
283                            longitude=lon)
284        return var
285   
286    def _extract2(self, lon, lat):
287        var = extract_csml(data_csml, 'test_0_360m',
288                            time='1980-01-01T00:00:00.0',
289                            latitude=lat,
290                            longitude=lon)
291        return var
292   
293    def _extract3(self, lon, lat):
294        var = extract_csml(data_csml, 'test_m180_180m',
295                            time='1980-01-01T00:00:00.0',
296                            latitude=lat,
297                            longitude=lon)
298        return var
299   
300    def _extract4(self, lon, lat):
301        var = extract_csml(data_csml, 'test_m180_180p',
302                            time='1980-01-01T00:00:00.0',
303                            latitude=lat,
304                            longitude=lon)
305        return var
306
307    def test1(self):
308        var = self._extract1(85, 42.5)
309        self.assertEquals(type(var), numpy.float64)
310        self.assertEquals(var, 85.0+42.5)
311       
312    def test2(self):
313        var = self._extract2(85, 42.5)
314        self.assertEquals(type(var), numpy.float64)
315        self.assertEquals(var, 85.0+42.5)
316
317    def test3(self):
318        var = self._extract3(85, 42.5)
319        self.assertEquals(type(var), numpy.float64)
320        self.assertEquals(var, 85.0+42.5)
321
322    def test4(self):
323        var = self._extract4(85, 42.5)
324        self.assertEquals(type(var), numpy.float64)
325        self.assertEquals(var, 85.0+42.5)
326       
327
328#---------------------------------------------------------------------------------------------------------
329
330class TestExtractPointWithBounds(TestCase):
331    """
332    Extract a single grid box using min/max lat/lon.
333
334    """
335    def _extract1(self, lon, lat):
336        var = extract_csml(data_csml, 'test_0_360p',
337                            time='1980-01-01T00:00:00.0',
338                            latitude=(lat-2.5, lat+2.5),
339                            longitude=(lon-5, lon+5))
340        return var
341   
342    def _extract2(self, lon, lat):
343        var = extract_csml(data_csml, 'test_0_360m',
344                            time='1980-01-01T00:00:00.0',
345                            latitude=(lat-2.5, lat+2.5),
346                            longitude=(lon-5, lon+5))
347        return var
348   
349    def _extract3(self, lon, lat):
350        var = extract_csml(data_csml, 'test_m180_180m',
351                            time='1980-01-01T00:00:00.0',
352                            latitude=(lat-2.5, lat+2.5),
353                            longitude=(lon-5, lon+5))
354        return var
355   
356    def _extract4(self, lon, lat):
357        var = extract_csml(data_csml, 'test_m180_180p',
358                            time='1980-01-01T00:00:00.0',
359                            latitude=(lat-2.5, lat+2.5),
360                            longitude=(lon-5, lon+5))
361        return var
362
363    def test1(self):
364        var = self._extract1(85, 42.5)
365        self.assertEquals(type(var), numpy.float64)
366       
367    def test2(self):
368        var = self._extract2(85, 42.5)
369        self.assertEquals(type(var), numpy.float64)
370       
371    def test3(self):
372        var = self._extract3(85, 42.5)
373        self.assertEquals(type(var), numpy.float64)
374       
375    def test4(self):
376        var = self._extract4(85, 42.5)
377        self.assertEquals(type(var), numpy.float64)
378               
Note: See TracBrowser for help on using the repository browser.