Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/spit.py: 88%
338 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-04 09:52 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-04 09:52 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7N-dimensional space partitioning multi-linear interpolator.
8'''
10import struct
11import logging
12import numpy as num
try:
    # Python 2 compatibility: use the lazy xrange where it exists. On
    # Python 3 the name is undefined and the builtin range is kept.
    range = xrange  # noqa: F821
except NameError:
    pass

logger = logging.getLogger('pyrocko.spit')

# Short aliases for the element-wise numpy helpers used throughout the module.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any
class OutOfBounds(Exception):
    '''
    Raised when an interpolation point lies outside the bounds of the tree.
    '''
class Cell(object):
    '''
    Node of the space partitioning tree.

    A cell covers an axis-aligned box of the tree's domain and stores the
    function values at the ``2**ndim`` corners of that box. Leaf cells
    interpolate multi-linearly between these corner values; inner cells
    delegate to the child containing the query point.

    :param tree: owning tree; its ``xbounds``, ``ndim``, ``xtols`` and
        ``_f_cached`` are used
    :param index: integer array of size ``ndim``; per dimension the index of
        the cell in a binary-heap numbering (``1`` for the root; refining a
        cell in a dimension maps index ``i`` to ``2*i`` and ``2*i+1``)
    :param f: optional corner values of shape ``(2,)*ndim``; if ``None``,
        they are evaluated through ``tree._f_cached``
    '''

    def __init__(self, tree, index, f=None):
        self.tree = tree
        self.index = index
        # Refinement depth per dimension: floor(log2(index)).
        self.depths = num.log2(index).astype(int)
        # Set by the tree builder for cells which fail the accuracy test but
        # are not refined (further) in the current pass.
        self.bad = False
        self.children = []

        # Number of cells per dimension at this depth and the position of
        # this cell among them.
        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Per dimension, the two corner weights are (x - a) / b, with
        # a = (upper, lower) bound and b = (-width, +width): the weight of the
        # lower corner falls from 1 to 0 across the cell, that of the upper
        # corner rises from 0 to 1.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions: avoid division by zero and give both
        # corners a weight of 0.5.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate the function on all 2**ndim corners of the cell.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: coordinates, array of size ``ndim``
        :returns: interpolated value, or ``None`` if the leaf containing
            ``x`` has non-finite corner values or no child contains ``x``
        '''
        if self.children:
            for cell in self.children:
                if all_(and_(cell.xbounds[:, 0] <= x,
                             x <= cell.xbounds[:, 1])):
                    return cell.interpolate(x)

            return None

        if not all_(num.isfinite(self.f)):
            return None

        ws = (x[:, num.newaxis] - self.a)/self.b
        # Outer product of the per-dimension corner weights, shape
        # (2,)*ndim. Multiplying the open-mesh arrays from num.ix_ directly
        # broadcasts them and avoids building a ragged object array.
        grids = num.ix_(*ws)
        wn = grids[0]
        for grid in grids[1:]:
            wn = wn * grid

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        :param x: coordinates, array of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where no value is
            available
        '''
        if self.children:
            result = num.full(x.shape[0], num.nan)
            for cell in self.children:
                # Points inside the child's box in every dimension.
                indices = num.where(
                    self.tree.ndim == num.sum(and_(
                        cell.xbounds[:, 0] <= x,
                        x <= cell.xbounds[:, 1]), axis=-1))[0]

                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if not all_(num.isfinite(self.f)):
            return num.full(x.shape[0], num.nan)

        # Corner weights per point and dimension, shape (npoints, ndim, 2).
        ws = (x[..., num.newaxis] - self.a)/self.b
        npoints = ws.shape[0]
        ndim = self.tree.ndim

        # Put the weights of dimension i on axis 1+i so that their product
        # broadcasts to the (npoints, 2, ..., 2) array of corner weights.
        ws_pimped = []
        for i in range(ndim):
            s = [npoints] + [1] * ndim
            s[1+i] = 2
            ws_pimped.append(ws[:, i, :].reshape(s))

        wn = ws_pimped[0]
        for w in ws_pimped[1:]:
            wn = wn * w

        result = wn * self.f
        for _ in range(ndim):
            result = num.sum(result, axis=-1)

        return result

    def slice(self, x):
        '''
        Get the children cut by an axis-aligned (hyper-)plane.

        :param x: coordinates, array-like of size ``ndim``; non-finite
            entries (e.g. NaN) mark dimensions which are left free
        :returns: list of children containing ``x`` in all fixed dimensions
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x_ = x.copy()
        x_[x_mask] = 0.0
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x_,
                    x_ <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells cut by the slice ``x``.

        :param axes: matplotlib axes to draw into
        :param x: slice position (see :py:meth:`slice`)
        :param dims: the two plotted dimensions; ``dims[1]`` is drawn along
            the horizontal, ``dims[0]`` along the vertical axis
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any([child.check_holes() for child in self.children])
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot the interpolated function on a 2D slice through the cell.

        :param axes: matplotlib axes to draw into
        :param x: slice position, array of size ``ndim``; the non-plotted
            entries fix the slice, the plotted ones should be NaN so that
            :py:meth:`slice` leaves them free
        :param dims: list of the two plotted dimensions; ``dims[1]`` is drawn
            along the horizontal, ``dims[0]`` along the vertical axis
        '''
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)
        # Sampling grid with a spacing of about xtols (never finer).
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            try:
                idimout = dims.index(idim)
                points[:, idim] = g[idimout].ravel()
            except ValueError:
                # Not a plotted dimension: hold it fixed at the slice.
                points[:, idim] = x[idim]

        fi = num.empty((coords[0].size, coords[1].size), dtype=float)
        fi_r = fi.ravel()
        fi_r[...] = self.interpolate_many(points)

        if num.any(num.isnan(fi)):
            logger.warning(
                'Interpolated values contain NaNs (undefined regions or '
                'holes in the tree).')

        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot the interpolated function along one dimension of the cell.

        :param axes: matplotlib axes to draw into
        :param x: position of the line, array of size ``ndim``; entry ``dim``
            is replaced by the sampling coordinates
        :param dim: plotted dimension
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''
        Iterate depth-first over this cell and all its descendants.
        '''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write this cell and its descendants depth-first to a binary file.

        Per cell: ``index`` as little-endian int32, then the corner values
        as little-endian float64.
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)
def bread(f, fmt):
    '''
    Read one :py:mod:`struct` record from a binary file and unpack it.

    :param f: file-like object opened in binary mode
    :param fmt: :py:mod:`struct` format string of the record
    :returns: tuple of the unpacked values
    '''
    record = struct.Struct(fmt)
    return record.unpack(f.read(record.size))
class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    :param f: callable function ``f(x)`` where ``x`` is a vector of size ``n``
    :param ftol: target accuracy ``|f_interp(x) - f(x)| <= ftol``
    :param xbounds: bounds of ``x``, shape ``(n, 2)``
    :param xtols: target coarsenesses in ``x``, vector of size ``n``
    :param filename: if given, load a tree previously written with
        :py:meth:`dump` instead of building one from ``f``
    :param addargs: additional arguments to pass to ``f``
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Maximum refinement depth per dimension: number of halvings
            # needed to bring the cell width down to xtols.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # Weights mapping a cell's bounds to the 3**ndim grid of its
            # corners and midpoints (lower, middle, upper per dimension).
            # The masked variant drops the pure corners, whose values are
            # already stored in the cell, leaving the test points.
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []
            # Refine level by level: at each pass cells may only be deepened
            # up to ``clipdepth``; deeper refinements are postponed to the
            # next pass via ``cells_to_continue``.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

    def status(self):
        '''
        Log refinement progress: level, covered fraction and cell count.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Do not report a rounded-up 100% while bad cells remain.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info(
            'at level %2i: %s covered, %6i cell%s',
            self.clipdepth, s, self.ncells,
            '' if self.ncells == 1 else 's')

    def __iter__(self):
        '''
        Iterate depth-first over all cells of the tree.
        '''
        return iter(self.root)

    def __len__(self):
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to a binary file.

        Layout: 8-byte marker ``b'SPITREE '``, then ``version``, ``ndim``,
        ``ncells`` (little-endian uint64) and ``ftol`` (float64), then
        ``xbounds`` and ``xtols`` (float64), then all cells depth-first.
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Read a tree written by :py:meth:`dump`.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            # Cells are stored depth-first, parents before their children.
            # ``path`` holds the chain from the root to the most recently
            # read cell; the parent of a new cell is the deepest cell on
            # that chain whose index equals ``index >> 1`` in at least one
            # (i.e. a refined) dimension.
            path = []
            for _ in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                    [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''
        Evaluate ``f`` at ``x``, memoised in ``f_values`` by coordinates.
        '''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: coordinates, vector of size ``ndim``
        :returns: interpolated value or ``None`` where undefined
        :raises: :py:exc:`OutOfBounds` if ``x`` lies outside ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        :param x: coordinates, array-like of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where no value is
            available
        '''
        # Accept array-likes, consistent with interpolate().
        x = num.asarray(x, dtype=float)
        assert x.ndim == 2 and x.shape[1] == self.ndim
        return self.root.interpolate_many(x)

    def _continue_fill(self):
        '''
        Resume refinement of the cells postponed in the previous pass.
        '''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test a cell against ``f`` and refine it where interpolation fails.

        The interpolant is compared with ``f`` at the cell's non-corner grid
        points. Dimensions in which all test points are fine are not split,
        unless some values are undefined or no good point has been found
        yet. Cells which cannot be refined in this pass are marked bad.
        '''
        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point works if interpolant and function are both undefined
        # or both defined and within ftol.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        # Cell straddles the border of the region where f is defined.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                # Skip splitting a dimension if all test points on the
                # midplane orthogonal to it (at corner positions in the
                # other dimensions) work.
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

                if not any_(deepen):
                    deepen = self.ones_int

            # Never refine beyond the resolution given by xtols.
            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    self.cells_to_continue.append(cell)

                cell.bad = True
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Split a cell along the dimensions flagged in ``cell.deepen``.
        '''
        if cell.bad:
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a 2D slice through the interpolator.

        :param axes: matplotlib axes; if ``None``, the current pyplot axes
            are used and the figure is shown
        :param x: slice position, vector of size ``ndim`` with ``None``/NaN
            at the two plotted dimensions; defaults to zeros with the last
            two dimensions free
        :param dims: the two plotted dimensions; default: the non-finite
            entries of ``x``
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot the interpolator along one dimension.

        :param axes: matplotlib axes; if ``None``, the current pyplot axes
            are used and the figure is shown
        :param x: line position, vector of size ``ndim`` with ``None``/NaN at
            the plotted dimension; defaults to zeros with the last dimension
            free
        :param dims: one-element list with the plotted dimension; default:
            the non-finite entry of ``x``
        '''
        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt:
            plt.show()
def getset(d, k, f, addargs):
    '''
    Memoised evaluation: return ``d[k]``, first computing and storing
    ``f(k, *addargs)`` if ``k`` is not yet a key of ``d``.
    '''
    if k in d:
        return d[k]

    value = f(k, *addargs)
    d[k] = value
    return value
def nditer_outer(x):
    '''
    Create an :py:class:`numpy.nditer` over the outer product of 1D arrays.

    Operand ``i`` is mapped onto iterator axis ``i`` (and broadcast along
    all others). If the last element of ``x`` is ``None``, it is left to
    the iterator to allocate as an output spanning all axes.

    :param x: tuple of 1D arrays, optionally followed by ``None``
    '''
    with_output = x[-1] is None
    inputs = x[:-1] if with_output else x
    ndim = len(inputs)

    op_axes = [
        [0 if iaxis == iop else -1 for iaxis in range(ndim)]
        for iop in range(ndim)]

    if with_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)
if __name__ == '__main__':
    import os
    import tempfile

    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Demo function, defined only inside a sphere of radius 0.5 around
        # the centre of the unit cube and undefined (None -> NaN) elsewhere.
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:
            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    # Round-trip the tree through a temporary file. SPTree.dump() reopens
    # the file by name, so the descriptor from mkstemp is closed right away
    # (it would otherwise leak, and on Windows block the unlink); the file
    # is removed even if dumping or loading fails.
    fd, fn = tempfile.mkstemp()
    os.close(fd)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()