Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/spit.py: 86%
344 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-03-07 11:54 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-03-07 11:54 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7N-dimensional space partitioning multi-linear interpolator.
8'''
10import struct
11import logging
12import numpy as num
# Python 2 compatibility: use the lazy ``xrange`` as ``range`` where it
# exists. On Python 3 ``xrange`` is undefined and the builtin ``range`` is
# left in place.
try:
    xrange  # noqa: F821
except NameError:
    pass
else:
    range = xrange  # noqa: F821
logger = logging.getLogger('pyrocko.spit')

# Short aliases for the numpy logical helpers used throughout this module.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any
class OutOfBounds(Exception):
    '''
    Raised by :py:meth:`SPTree.interpolate` when the requested point lies
    outside the tree's ``xbounds``.
    '''
class Cell(object):
    '''
    Axis-aligned hyper-rectangular cell of an :py:class:`SPTree`.

    A cell stores the function values at its ``2**ndim`` corners (``f``) and
    may be subdivided into child cells. Its position is encoded in ``index``:
    per dimension, ``index == 2**depth + i`` where ``i`` is the cell's
    position among the ``2**depth`` equal subdivisions of the tree's bounds.
    '''

    __slots__ = (
        "tree", "index", "depths", "bad", "children", "xbounds", "deepen",
        "a", "b", "f")

    def __init__(self, tree, index, f=None):
        '''
        :param tree: owning :py:class:`SPTree`; its ``xbounds`` are used and,
            when ``f`` is not given, its ``_f_cached`` is called
        :param index: integer array of shape ``(ndim,)`` encoding depth and
            position of the cell (see class docstring)
        :param f: corner values, shape ``(2,)*ndim``; if ``None`` they are
            evaluated through ``tree._f_cached``
        '''
        self.tree = tree
        self.index = index
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []

        n = 2**self.depths
        i = self.index - n
        xmin = self.tree.xbounds[:, 0]
        delta = (self.tree.xbounds[:, 1] - xmin)/n
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Per-dimension linear weights are (x - a)/b: the lower corner gets
        # (x - hi)/(lo - hi), the upper corner gets (x - lo)/(hi - lo).
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions would divide by zero; shift a and b so that
        # both corners get weight 0.5 at the (single) coordinate value.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate the function on the outer product of the
            # per-dimension bounds, i.e. at all 2**ndim corners.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Multi-linearly interpolate at a single point.

        Descends into the first child whose bounds contain ``x``. If there is
        no such child (or no children at all), this cell's corner values are
        used.

        :param x: point, float array of shape ``(ndim,)``
        :returns: interpolated value, or ``None`` if any corner value of the
            interpolating cell is non-finite
        '''
        for cell in self.children:
            if all_(and_(cell.xbounds[:, 0] <= x, x <= cell.xbounds[:, 1])):
                return cell.interpolate(x)

        if not all_(num.isfinite(self.f)):
            return None

        # Outer product of the per-dimension weights gives the weight of
        # each corner; same evaluation whether or not the cell has children.
        ws = (x[:, num.newaxis] - self.a)/self.b
        wn = self.tree.ones_cell.copy()
        for ws_ in num.ix_(*ws):
            wn *= ws_

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Multi-linearly interpolate at many points at once.

        :param x: points, float array of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where the interpolating
            cell has non-finite corner values or where no child covers the
            point
        '''
        ndim = self.tree.ndim
        npoints = x.shape[0]
        if self.children:
            result = num.full(npoints, fill_value=num.nan)
            for cell in self.children:
                inside = all_(and_(
                    cell.xbounds[:, 0] <= x,
                    x <= cell.xbounds[:, 1]), axis=-1)
                indices = num.where(inside)[0]
                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if not all_(num.isfinite(self.f)):
            return num.full(npoints, fill_value=num.nan)

        # Per-dimension linear weights, shape (npoints, ndim, 2).
        ws = (x[..., num.newaxis] - self.a)/self.b

        # Reshape the weights of dimension idim to (npoints, 1, .., 2, .., 1)
        # so that their product broadcasts to (npoints, 2, ..., 2).
        factors = []
        for idim in range(ndim):
            shape = [npoints] + [1] * ndim
            shape[1 + idim] = 2
            factors.append(ws[:, idim, :].reshape(shape))

        wn = factors[0]
        for factor in factors[1:]:
            wn = wn * factor

        # Weighted sum over the ndim corner axes.
        result = wn * self.f
        for _ in range(ndim):
            result = num.sum(result, axis=-1)

        return result

    def slice(self, x):
        '''
        Get the children intersecting a (possibly lower-dimensional) slice.

        :param x: coordinates, array-like of size ``ndim``; non-finite
            entries (``None``/NaN) mark free dimensions
        :returns: list of children whose bounds contain ``x`` in every fixed
            dimension
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x_ = num.where(x_mask, 0.0, x)
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x_,
                    x_ <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells intersecting the slice ``x``,
        projected onto dimensions ``dims[1]`` (horizontal) and ``dims[0]``
        (vertical).
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any(child.check_holes() for child in self.children)
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot cell outlines and an image of the interpolated values on a 2D
        slice.

        :param axes: matplotlib axes to draw into
        :param x: coordinates of the slice, size ``ndim``; entries for
            ``dims`` are ignored
        :param dims: the two plotted dimensions, ``(vertical, horizontal)``
        '''
        dims = list(dims)
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim in dims:
                points[:, idim] = g[dims.index(idim)].ravel()
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points).reshape(
            coords[0].size, coords[1].size)

        if num.any(num.isnan(fi)):
            logger.warning(
                'plot_2d: interpolated values contain NaNs; affected pixels '
                'are left blank')

        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot interpolated values along dimension ``dim``, with the other
        coordinates fixed to the values in ``x``.
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''
        Iterate over this cell and all descendants, depth-first pre-order.
        '''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write ``index`` (little-endian int32) and ``f`` (little-endian
        float64) of this cell, then of all children, depth-first pre-order.
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)
def bread(f, fmt):
    '''
    Read one binary record with :py:mod:`struct` format ``fmt`` from the
    binary file object ``f`` and return the unpacked tuple.
    '''
    record = struct.Struct(fmt)
    return record.unpack(f.read(record.size))
class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    The domain is recursively bisected (per dimension, only where needed)
    until multi-linear interpolation from the cell corners reproduces ``f``
    at the cell's edge midpoints, face centres, etc. to within ``ftol``, or
    until cells are as small as ``xtols``.

    :param f: callable ``f(x, *addargs)`` where ``x`` is a tuple of ``n``
        floats; a ``None`` return is treated as NaN (undefined)
    :param ftol: target accuracy ``|f_interp(x) - f(x)| < ftol`` at the test
        points
    :param xbounds: bounds of ``x``, shape ``(n, 2)``
    :param xtols: target coarsenesses in ``x``, vector of size ``n``
    :param filename: if given, load a tree written by :py:meth:`dump`
        instead of building one; the other arguments are then ignored
    :param addargs: additional arguments to pass to ``f``
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}  # cache: tuple of floats -> f value
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Number of bisections per dimension after which a cell is no
            # wider than xtols.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # Weights mapping a cell's (lo, hi) bounds onto its 3**ndim grid:
            # along each dimension c = 0, 1, 2 selects lo, (lo+hi)/2, hi.
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w
            # Test points: every grid point with at least one midpoint
            # coordinate, i.e. all grid points except the cell corners.
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []
            # Refine level by level; cells whose refinement would exceed the
            # current clipdepth are deferred via cells_to_continue.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

        self.ones_cell = num.ones((2,) * self.ndim, dtype=float)

    def status(self):
        '''
        Log refinement level, covered fraction and number of cells.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Do not claim full coverage when bad cells remain.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info(
            'at level %2i: %s covered, %6i cell%s',
            self.clipdepth, s, self.ncells,
            '' if self.ncells == 1 else 's')

    def __iter__(self):
        return iter(self.root)

    def __len__(self):
        return self.ncells

    def dump(self, filename):
        '''
        Save the tree to ``filename``.

        Layout: ``b'SPITREE '``, version/ndim/ncells (``<Q``), ftol (``<d``),
        ``xbounds`` and ``xtols`` (``<f8``), then all cells depth-first as
        written by :py:meth:`Cell.dump`.
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Restore a tree written by :py:meth:`dump`.

        Only the data needed for interpolation and plotting is restored; the
        function ``f`` is not, so a loaded tree cannot be refined further.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            # Cells are stored depth-first; ``path`` holds the current chain
            # of ancestors. A child's index shifted right by one equals its
            # parent's index in every dimension that was bisected, so pop
            # until such a cell is on top.
            path = []
            for icell in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                        [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''
        Evaluate ``f`` at ``x``, memoised on the tuple of float coordinates.
        '''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, array-like of size ``ndim``
        :returns: interpolated value, or ``None`` where undefined
        :raises OutOfBounds: if ``x`` lies outside ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        Points are not bounds-checked.

        :param x: points, array-like of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where no value is
            available
        '''
        return self.root.interpolate_many(num.asarray(x, dtype=float))

    def _continue_fill(self):
        '''
        Refine all cells deferred by the previous level.
        '''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test ``cell`` against ``f`` and refine it if needed.

        A test point "works" if ``f`` and the interpolant are both undefined
        there, or both defined and closer than ``ftol``. A cell is refined if
        a test point fails, if only some of its test points and corners are
        undefined, or while no working test point has been seen yet.
        '''
        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                # Skip bisecting a dimension if all midpoints of the edges
                # parallel to it work.
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

            if not any_(deepen):
                deepen = self.ones_int

            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    self.cells_to_continue.append(cell)

                cell.bad = True
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Split ``cell`` along the dimensions flagged in ``cell.deepen`` and
        fill the new children.
        '''
        if cell.bad:
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a 2D slice of the tree.

        :param axes: matplotlib axes; if ``None``, plot into the current
            pyplot axes and show the figure
        :param x: slice coordinates, size ``ndim``; by default the last two
            dimensions are free and the rest are zero
        :param dims: the two plotted dimensions; by default the entries of
            ``x`` which are ``None``/NaN
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        dims = list(dims)
        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt is not None:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot a 1D slice of the tree.

        :param axes: matplotlib axes; if ``None``, plot into the current
            pyplot axes and show the figure
        :param x: slice coordinates, size ``ndim``; by default the last
            dimension is free and the rest are zero
        :param dims: single-element sequence with the plotted dimension; by
            default the entry of ``x`` which is ``None``/NaN
        '''
        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        dims = list(dims)
        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt is not None:
            plt.show()
def getset(d, k, f, addargs):
    '''
    Return ``d[k]``, computing and storing ``f(k, *addargs)`` on a miss.
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value
def nditer_outer(x):
    '''
    Create a :py:class:`numpy.nditer` over the outer product of the 1D
    arrays in ``x``.

    Operand ``i`` is mapped onto iteration axis ``i``. If the last entry of
    ``x`` is ``None``, an output operand spanning all axes is allocated by
    the iterator.
    '''
    has_output = x[-1] is None
    ninputs = len(x) - 1 if has_output else len(x)

    # Row i: axis i of operand i on iteration axis i, -1 (broadcast) elsewhere.
    op_axes = (num.identity(ninputs, dtype=int) - 1).tolist()
    if has_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)
if __name__ == '__main__':
    import time
    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Demo function: defined only inside a sphere of radius 0.5 centred
        # in the unit cube; ``None`` marks points where it is undefined.
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:

            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    # Compare vectorised and point-wise interpolation and time both.
    n = 10000
    xs = num.random.random((n, 3))
    t0 = time.time()
    vs = tree.interpolate_many(xs)
    t1 = time.time()
    for i in range(n):
        v = tree.interpolate(xs[i])
        assert vs[i] == v or (v is None and num.isnan(vs[i]))
    t2 = time.time()

    print(t1 - t0, t2 - t1)

    # Round-trip through a file. mkstemp returns an open descriptor which
    # must be closed; dump() reopens the file by name.
    import tempfile
    import os
    fd, fn = tempfile.mkstemp()
    os.close(fd)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()