1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6import struct
7import logging
8import numpy as num
try:
    range = xrange  # noqa: F821 -- Python 2: use the lazy variant
except NameError:
    # Python 3: the builtin range is already lazy.
    pass

logger = logging.getLogger('pyrocko.spit')

# Short aliases for NumPy's element-wise logic and reductions, used to keep
# the array expressions in this module compact.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any
class OutOfBounds(Exception):
    '''
    Raised by :py:meth:`SPTree.interpolate` when the requested point lies
    outside the bounds of the tree.
    '''
class Cell(object):
    '''
    Node of the space-partitioning tree built by :py:class:`SPTree`.

    A cell covers an axis-aligned box of the tree's domain. A leaf cell holds
    the function values at the ``2**ndim`` corners of its box in ``self.f``
    and interpolates multilinearly between them. An inner cell delegates to
    the cells in ``self.children``.
    '''

    def __init__(self, tree, index, f=None):
        '''
        :param tree: owning tree; ``tree.xbounds`` (shape ``(ndim, 2)``) is
            read here and ``tree._f_cached`` is called when *f* is ``None``
        :param index: integer array of length ``ndim``; at subdivision depth
            ``d`` in a dimension, the index there runs from ``2**d`` to
            ``2**(d+1) - 1``
        :param f: corner values, shape ``[2] * ndim``; if ``None`` they are
            evaluated through ``tree._f_cached`` at the corners of the box
        '''
        self.tree = tree
        self.index = index
        # Subdivision depth per dimension: floor(log2(index)).
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []

        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Interpolation coefficients: along dimension j, the weight of corner
        # k at coordinate x[j] is (x[j] - a[j, k]) / b[j, k], which evaluates
        # to (xmax - x) / width for the lower corner and (x - xmin) / width
        # for the upper corner.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions would divide by zero: use a unit denominator
        # and shift a so that both corners get weight 0.5. The a-updates must
        # come first because they test b before it is modified.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate on the outer product of the per-dimension bounds,
            # i.e. at all corners of the box.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, float array of shape ``(ndim,)``
        :returns: interpolated value, or ``None`` if the leaf containing *x*
            has non-finite corner values or no child contains *x*
        '''
        if self.children:
            for cell in self.children:
                if all_(and_(cell.xbounds[:, 0] <= x,
                             x <= cell.xbounds[:, 1])):
                    return cell.interpolate(x)

            return None

        if not all_(num.isfinite(self.f)):
            return None

        # Per-dimension corner weights, shape (ndim, 2).
        ws = (x[:, num.newaxis] - self.a)/self.b

        # Outer product of the per-dimension weights, shape [2] * ndim.
        # Multiply the open-mesh arrays from ix_ directly instead of packing
        # differently shaped arrays into an object ndarray.
        wn = 1.0
        for w in num.ix_(*ws):
            wn = wn * w

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        :param x: points, float array of shape ``(npoints, ndim)``
        :returns: float array of shape ``(npoints,)``; NaN where no leaf
            with finite corner values contains the point
        '''
        npoints = x.shape[0]

        if self.children:
            result = num.empty(npoints, dtype=float)
            result[:] = num.nan
            for cell in self.children:
                inside = num.sum(and_(
                    cell.xbounds[:, 0] <= x,
                    x <= cell.xbounds[:, 1]), axis=-1) == self.tree.ndim

                indices = num.where(inside)[0]
                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if not all_(num.isfinite(self.f)):
            result = num.empty(npoints, dtype=float)
            result[:] = num.nan
            return result

        ndim = self.tree.ndim

        # Per-point, per-dimension corner weights, shape (npoints, ndim, 2).
        ws = (x[..., num.newaxis] - self.a)/self.b

        # Lay the weights of dimension i along axis 1+i and multiply them up
        # to per-point corner weights of shape (npoints, 2, ..., 2).
        wn = None
        for i in range(ndim):
            s = [npoints] + [1] * ndim
            s[1+i] = 2
            wi = ws[:, i, :].reshape(s)
            wn = wi if wn is None else wn * wi

        # Weighted sum over all corner axes.
        result = wn * self.f
        for i in range(ndim):
            result = num.sum(result, axis=-1)

        return result

    def slice(self, x):
        '''
        Return the children whose bounds contain *x*.

        Non-finite entries of *x* (e.g. ``None`` or NaN) act as wildcards:
        the corresponding dimension is not constrained.
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x[x_mask] = 0.0
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x,
                    x <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells selected by the slice *x*.

        Dimension ``dims[1]`` is drawn horizontally and ``dims[0]``
        vertically.
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any([child.check_holes() for child in self.children])
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot cell outlines and interpolated values on a 2D slice.

        :param axes: matplotlib axes
        :param x: point fixing the coordinates of the dimensions not in
            *dims*
        :param dims: sequence of the two dimensions to plot; ``dims[1]`` is
            drawn horizontally, ``dims[0]`` vertically
        '''
        # list() so that .index works for any sequence, e.g. NumPy arrays.
        dims = list(dims)
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)

        # Sampling grid with spacing of about xtols in the plotted dims.
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim in dims:
                points[:, idim] = g[dims.index(idim)].ravel()
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points).reshape(
            coords[0].size, coords[1].size)

        if num.any(num.isnan(fi)):
            logger.warning(
                'interpolated values contain NaNs (holes); masking them')

        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot interpolated values along dimension *dim*.

        The other coordinates are fixed to the values in *x*.
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''Yield this cell and then, recursively, all descendants.'''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write this cell and its descendants, depth first, to *file*.

        Each cell is stored as its index (little-endian int32) followed by
        its corner values (little-endian float64).
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)
def bread(f, fmt):
    '''
    Read one binary record from *f* and unpack it.

    :param f: file object opened in binary mode
    :param fmt: :py:mod:`struct` format string describing the record
    :returns: tuple of the unpacked values
    '''
    nbytes = struct.calcsize(fmt)
    raw = f.read(nbytes)
    return struct.unpack(fmt, raw)
class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    The domain given by *xbounds* is recursively bisected per dimension.
    Each cell is checked at the non-corner points of its 3**ndim grid
    (edge midpoints, face centers, center). Refinement stops when
    multilinear interpolation from the cell corners reproduces *f* there to
    within *ftol*, or when the cell width reaches *xtols*.
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        '''
        Create n-dimensional space partitioning interpolator.

        :param f: callable function f(x) where x is a vector of size n
        :param ftol: target accuracy |f_interp(x) - f(x)| <= ftol
        :param xbounds: bounds of x, shape (n, 2)
        :param xtols: target coarsenesses in x, vector of size n
        :param filename: if given, load a tree written by :py:meth:`dump`
            instead of building one; the other arguments are then ignored
        :param addargs: additional arguments to pass to f
        '''

        if filename is not None:
            self._load(filename)
            return

        assert all(v is not None for v in (f, ftol, xbounds, xtols))

        self.f = f
        self.ftol = float(ftol)
        self.f_values = {}
        self.ncells = 0
        self.addargs = addargs

        self.xbounds = num.asarray(xbounds, dtype=float)
        assert self.xbounds.ndim == 2
        assert self.xbounds.shape[1] == 2
        self.ndim = self.xbounds.shape[0]

        self.xtols = num.asarray(xtols, dtype=float)
        assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

        # Bisections per dimension needed to make cells no wider than xtols.
        self.maxdepths = num.ceil(num.log2(
            num.maximum(
                1.0,
                (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
        )).astype(int)

        self.root = None
        self.ones_int = num.ones(self.ndim, dtype=int)

        # Weights that map a cell's (ndim, 2) bounds to its 3**ndim grid of
        # lower bound, midpoint and upper bound per dimension.
        cc = num.ix_(*[num.arange(3)]*self.ndim)
        w = num.zeros([3]*self.ndim + [self.ndim, 2])
        for i, c in enumerate(cc):
            w[..., i, 0] = (2-c)*0.5
            w[..., i, 1] = c*0.5

        self.pointmaker = w
        # Test points: every grid point that is at a midpoint in at least one
        # dimension, i.e. all except the corners, where the cell's own
        # samples are stored.
        self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
        self.pointmaker_masked = w[self.pointmaker_mask]

        self.nothing_found_yet = True

        self.root = Cell(self, self.ones_int)
        self.ncells += 1

        self.fraction_bad = 0.0
        self.nbad = 0
        self.cells_to_continue = []
        # Refine level by level: at each clip depth cells are only split up
        # to that depth; cells needing more are deferred to the next level.
        for clipdepth in range(0, num.max(self.maxdepths)+1):
            self.clipdepth = clipdepth
            self.tested = 0
            if self.clipdepth == 0:
                self._fill(self.root)
            else:
                self._continue_fill()

            self.status()

            if not self.cells_to_continue:
                break

    def status(self):
        '''Log the coverage reached at the current clip depth.'''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Don't report a clean 100% while bad cells remain.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info('at level %2i: %s covered, %6i cell%s' % (
            self.clipdepth, s, self.ncells, ['s', ''][self.ncells == 1]))

    def __iter__(self):
        return iter(self.root)

    def __len__(self):
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to *filename*.

        Layout: ``b'SPITREE '`` marker, then a ``'<QQQd'`` header (version,
        ndim, ncells, ftol), then xbounds and xtols as little-endian
        float64, then the cells in depth-first order.
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Read a tree written by :py:meth:`dump`.

        :raises ValueError: if the file is not a version 1 SPITREE file or is
            truncated
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')

            # Explicit checks rather than asserts: this validates external
            # data and must not disappear under ``python -O``.
            if marker != b'SPITREE ':
                raise ValueError('not a SPITREE file: %s' % filename)

            if version != 1:
                raise ValueError(
                    'unsupported SPITREE version %i: %s' % (version, filename))

            xbounds = num.fromfile(file, dtype='<f8', count=self.ndim*2)
            self.xtols = num.fromfile(file, dtype='<f8', count=self.ndim)
            if xbounds.size != self.ndim*2 or self.xtols.size != self.ndim:
                raise ValueError('truncated SPITREE file: %s' % filename)

            self.xbounds = xbounds.reshape(self.ndim, 2)

            # Cells are stored depth first, parent before children. Keep the
            # chain of ancestors of the last cell read and pop until the
            # parent of the new cell is on top. The parent index equals the
            # child index shifted right by one in every split dimension.
            path = []
            nf = 2**self.ndim
            for _ in range(self.ncells):
                index = num.fromfile(file, dtype='<i4', count=self.ndim)
                f = num.fromfile(file, dtype='<f8', count=nf)
                if index.size != self.ndim or f.size != nf:
                    raise ValueError('truncated SPITREE file: %s' % filename)

                cell = Cell(self, index, f.reshape([2]*self.ndim))
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''Evaluate f at *x*, memoized on the tuple of coordinates.'''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, array-like of size ``ndim``
        :returns: interpolated value, or ``None`` where no value is available
        :raises OutOfBounds: if *x* lies outside ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        :param x: points, array-like of shape ``(npoints, ndim)``; lists are
            accepted as well as arrays
        :returns: float array of shape ``(npoints,)``
        '''
        return self.root.interpolate_many(num.asarray(x, dtype=float))

    def _continue_fill(self):
        '''Deepen the cells deferred at the previous clip depth.'''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test *cell* and split it where interpolation is not good enough.

        Splitting happens immediately if the children stay within the current
        clip depth. Otherwise the cell is deferred to ``cells_to_continue``.
        Cells that are not split now are marked bad.
        '''
        self.tested += 1
        # Coordinates of the test points (non-corner points of the grid).
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point "works" if true and interpolated value are both
        # undefined, or both finite and closer than ftol.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        # Partially defined cells (some, but not all, samples undefined) are
        # always split in every dimension.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                # Do not split along dimension idim if interpolation works at
                # all midpoints of the cell edges parallel to that dimension.
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

                if not any_(deepen):
                    deepen = self.ones_int

            # Never split beyond the resolution limit given by xtols.
            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    # Too deep for this pass; retry at the next clip depth.
                    self.cells_to_continue.append(cell)

                # Track the domain fraction covered by bad cells.
                cell.bad = True
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''Split *cell* as given by ``cell.deepen`` and fill the children.'''
        if cell.bad:
            # A deferred cell is no longer counted as bad once it is split.
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        # Child index: shift left by one bit in each split dimension, then
        # add 0 (lower half) or 1 (upper half).
        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a 2D slice of the tree.

        :param axes: matplotlib axes; if ``None``, use the current axes and
            show the figure
        :param x: point whose non-finite entries select the plotted
            dimensions; defaults to zeros with the last two dims free
        :param dims: the two dimensions to plot, inferred from *x* if
            ``None``
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        # Plain list, so that tuples or NumPy arrays work downstream too.
        dims = list(dims)
        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt is not None:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot a 1D slice of the tree.

        :param axes: matplotlib axes; if ``None``, use the current axes and
            show the figure
        :param x: point whose single non-finite entry selects the plotted
            dimension; defaults to zeros with the last dim free
        :param dims: one-element sequence with the dimension to plot,
            inferred from *x* if ``None``
        '''
        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        dims = list(dims)
        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt is not None:
            plt.show()
def getset(d, k, f, addargs):
    '''
    Memoizing lookup: return ``d[k]``, computing and storing it on a miss.

    :param d: cache dictionary
    :param k: key; also passed as the first argument to *f*
    :param f: function called as ``f(k, *addargs)`` when *k* is missing
    :param addargs: extra positional arguments for *f*
    :returns: the cached or freshly computed value
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value
def nditer_outer(x):
    '''
    Create an :py:class:`numpy.nditer` over the outer product of 1D arrays.

    Operand ``i`` of *x* is mapped onto iterator axis ``i``. If the last
    element of *x* is ``None``, nditer allocates an output operand spanning
    all iterator axes. It is available as ``it.operands[-1]``.

    :param x: tuple of 1D arrays, optionally followed by ``None``
    '''
    wants_output = x[-1] is None
    inputs = x[:-1] if wants_output else x
    n = len(inputs)

    # Row i: axis 0 of operand i goes to iterator axis i; other axes are
    # broadcast (-1).
    op_axes = [[0 if j == i else -1 for j in range(n)] for i in range(n)]
    if wants_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)
if __name__ == '__main__':
    # Demo: build a 3D tree, round-trip it through a file and plot slices.
    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Test function that is defined only inside the ball of radius 0.5
        # around the center of the unit cube (None outside).
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:

            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    import tempfile
    import os

    fid, fn = tempfile.mkstemp()
    # mkstemp opens the file; only the path is needed because dump()
    # reopens it, so close the descriptor to avoid leaking it.
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()