# http://pyrocko.org - GPLv3 # # The Pyrocko Developers, 21st Century # ---|P------/S----------~Lg----------
# NOTE(review): this copy of the file has lost its line structure and the
# `def`/`class`/`if` header lines around this span, so it is not valid
# Python as it stands.
# Evaluate the cached user function at every point of the outer product of
# the rows of self.xbounds (the trailing `None` operand handed to
# nditer_outer becomes the allocated output), then keep that output grid
# as this cell's corner values self.f.
it = nditer_outer(tuple(self.xbounds) + (None,)) for vvv in it: vvv[-1][...] = self.tree._f_cached(vvv[:-1])
self.f = it.operands[-1] else:
# NOTE(review): residue of other statements whose beginnings are missing
# (an unbalanced bounds comparison, and an `ix_`-based weight product that
# ends in `return None`).
# NOTE(review): the `num.object` alias was removed in NumPy 1.24; the
# builtin `object` is the replacement.
x <= cell.xbounds[:, 1])):
else: num.array(num.ix_(*ws), dtype=num.object)) else: return None
# Cell-level batch interpolation of the points in the rows of x (the
# enclosing `def` line is missing from this copy of the file).
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
# Parent cell: start from an all-NaN result (None assigned into a float
# array) and hand each point to the child whose closed box
# [xbounds[:, 0], xbounds[:, 1]] contains it in every dimension. A point on
# a shared child boundary is overwritten by the later child; points no
# child covers stay NaN.
if self.children: result = num.empty(x.shape[0], dtype=num.float) result[:] = None for cell in self.children: indices = num.where( self.tree.ndim == num.sum(and_( cell.xbounds[:, 0] <= x, x <= cell.xbounds[:, 1]), axis=-1))[0]
if indices.size != 0: result[indices] = cell.interpolate_many(x[indices])
return result
# Leaf cell: if all corner values in self.f are finite, compute per-
# dimension weight pairs (x - self.a) / self.b (self.a / self.b are set up
# outside this view), reshape the pair for dimension i so it broadcasts
# along axis 1+i of a [2]*ndim grid, multiply them together and sum the
# weighted corner values over all corner axes.
else: if all_(num.isfinite(self.f)): ws = (x[..., num.newaxis] - self.a)/self.b npoints = ws.shape[0] ndim = self.tree.ndim ws_pimped = [ws[:, i, :] for i in range(ndim)] for i in range(ndim): s = [npoints] + [1] * ndim s[1+i] = 2 ws_pimped[i].shape = tuple(s)
wn = ws_pimped[0] for idim in range(1, ndim): wn = wn * ws_pimped[idim]
result = wn * self.f for i in range(ndim): result = num.sum(result, axis=-1)
# Any non-finite corner value: the cell cannot interpolate, all points NaN.
return result else: result = num.empty(x.shape[0], dtype=num.float) result[:] = None return result
# Return the child cells whose box contains x, treating every non-finite
# component of x as a wildcard that matches any extent in that dimension
# (the enclosing `def` line is missing from this copy of the file).
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
x = num.array(x, dtype=num.float) x_mask = not_(num.isfinite(x)) x_ = x.copy() x_[x_mask] = 0.0 return [ cell for cell in self.children if all_(or_( x_mask, and_( cell.xbounds[:, 0] <= x_, x_ <= cell.xbounds[:, 1])))]
# Draw the outline of every leaf cell hit by the slice x (descending via
# self.slice(x)), projected onto dims: dims[1] on the horizontal axis,
# dims[0] on the vertical axis (the `def` line is missing from this copy).
if self.children: for cell in self.slice(x): cell.plot_rects(axes, x, dims)
# Leaf: trace the four corners of the projected box and close the loop.
else: points = [] for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)): points.append( (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))
points = num.transpose(points) axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))
# Recursively reports whether any leaf cell has a NaN corner value (the
# `def` line is missing from this copy of the file).
''' Check if :py:class:`Cell` or its children contain NaNs''' if self.children: return any([child.check_holes() for child in self.children]) else: return num.any(num.isnan(self.f))
# Plot this cell's interpolant on a 2-D slice: draw the cell outlines, then
# sample a regular grid over the cell's extent in dims (spacing about
# tree.xtols per dimension), holding all other coordinates at x, and show it
# as an image (the `def` line is missing from this copy of the file).
idims = num.array(dims) self.plot_rects(axes, x, dims) coords = [ num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d)) for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]
# Build the (npoints, ndim) sample matrix: plotted dimensions come from the
# meshgrid, all others (dims.index raising ValueError) are fixed to x.
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
npoints = coords[0].size * coords[1].size g = num.meshgrid(*coords[::-1])[::-1] points = num.empty((npoints, self.tree.ndim), dtype=num.float) for idim in range(self.tree.ndim): try: idimout = dims.index(idim) points[:, idim] = g[idimout].ravel() except ValueError: points[:, idim] = x[idim]
# fi_r is a flat view of fi, so this fills the 2-D image in place.
fi = num.empty((coords[0].size, coords[1].size), dtype=num.float) fi_r = fi.ravel() fi_r[...] = self.interpolate_many(points)
# NOTE(review): logs an empty warning when holes are present, and
# Logger.warn is a deprecated alias of Logger.warning.
if num.any(num.isnan(fi)): logger.warn('') if any_(num.isfinite(fi)): fi = num.ma.masked_invalid(fi) axes.imshow( fi, origin='lower', extent=[coords[1].min(), coords[1].max(), coords[0].min(), coords[0].max()], interpolation='nearest', aspect='auto', cmap='RdYlBu')
# Plot this cell's interpolant along dimension dim, sampled over the cell's
# extent with spacing about tree.xtols[dim] (the `def` line is missing from
# this copy of the file).
xb = self.xbounds[dim] d = self.tree.xtols[dim] coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
# Sample matrix: the plotted dimension varies, all others are fixed to x.
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
npoints = coords.size points = num.empty((npoints, self.tree.ndim), dtype=num.float) for idim in range(self.tree.ndim): if idim == dim: points[:, idim] = coords else: points[:, idim] = x[idim]
# Undefined samples are masked so the curve shows gaps instead of NaNs.
fi = self.interpolate_many(points) if any_(num.isfinite(fi)): fi = num.ma.masked_invalid(fi) axes.plot(coords, fi)
# Pre-order traversal: yield this cell, then every descendant, recursively
# (the `def` line is missing from this copy of the file).
yield self for c in self.children: for x in c: yield x
# Serialize this cell, then its children recursively (pre-order): the cell
# index as little-endian int32 followed by the corner values as
# little-endian float64 (the `def` line is missing from this copy).
self.index.astype('<i4').tofile(file) self.f.astype('<f8').tofile(file) for c in self.children: c.dump(file)
filename=None, addargs=()):
'''Create n-dimensional space partitioning interpolator.
:param f: callable function f(x) where x is a vector of size n :param ftol: target accuracy |f_interp(x) - f(x)| <= ftol :param xbounds: bounds of x, shape (n, 2) :param xtols: target coarsenesses in x, vector of size n :param addargs: additional arguments to pass to f '''
# NOTE(review): the opening `def` line of this constructor, and the
# `if`/`else` that chooses between building and loading, are missing from
# this copy of the file; only the signature tail above survives.
# NOTE(review): `num.float` / `num.int` were removed in NumPy 1.24; use the
# builtins `float` / `int`.
# Building from scratch requires all four construction parameters.
assert all(v is not None for v in (f, ftol, xbounds, xtols))
# f_values caches already evaluated function values (see _f_cached).
self.f = f self.ftol = float(ftol) self.f_values = {} self.ncells = 0 self.addargs = addargs
self.xbounds = num.asarray(xbounds, dtype=num.float) assert self.xbounds.ndim == 2 assert self.xbounds.shape[1] == 2 self.ndim = self.xbounds.shape[0]
self.xtols = num.asarray(xtols, dtype=num.float) assert self.xtols.ndim == 1 and self.xtols.size == self.ndim
# Per-dimension maximum refinement depth: the number of halvings needed to
# bring the bounds width down to at most xtols (0 if already below it).
self.maxdepths = num.ceil(num.log2( num.maximum( 1.0, (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols) )).astype(num.int)
self.root = None self.ones_int = num.ones(self.ndim, dtype=num.int)
# Test-point template: for each node of a 3 x ... x 3 grid over a cell,
# per-dimension weights (w[..., i, 0], w[..., i, 1]) on the lower and upper
# bound, so that sum(xbounds * w, axis=-1) yields the node coordinates.
# The mask keeps only nodes with a 0.5 weight in at least one dimension,
# i.e. it drops the cell's own corners.
cc = num.ix_(*[num.arange(3)]*self.ndim) w = num.zeros([3]*self.ndim + [self.ndim, 2]) for i, c in enumerate(cc): w[..., i, 0] = (2-c)*0.5 w[..., i, 1] = c*0.5
self.pointmaker = w self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0 self.pointmaker_masked = w[self.pointmaker_mask]
self.nothing_found_yet = True
self.root = Cell(self, self.ones_int) self.ncells += 1
# Refine level by level: level 0 fills from the root; each later level
# resumes the cells that were postponed because they reached the previous
# clip depth. Stop as soon as no cell is left to resume.
self.fraction_bad = 0.0 self.nbad = 0 self.cells_to_continue = [] for clipdepth in range(0, num.max(self.maxdepths)+1): self.clipdepth = clipdepth self.tested = 0 if self.clipdepth == 0: self._fill(self.root) else: self._continue_fill()
self.status()
if not self.cells_to_continue: break
# NOTE(review): dangling `else:`; its body and the matching `if` are
# missing from this copy of the file.
else:
# Log the share of the domain volume covered by cells that pass the
# accuracy test at the current level (the `def` line is missing from this
# copy of the file).
perc = (1.0-self.fraction_bad)*100 s = '%6.1f%%' % perc
# '~100.0%' flags that bad cells remain although the value rounds to 100.
if self.fraction_bad != 0.0 and s == ' 100.0%': s = '~100.0%'
logger.info('at level %2i: %s covered, %6i cell%s' % ( self.clipdepth, s, self.ncells, ['s', ''][self.ncells == 1]))
# Iterate over every cell of the tree, root first, by delegating to the
# root cell's pre-order iterator (the `def` line is missing from this copy).
return iter(self.root)
# Total number of cells in the tree. The `def` line is missing from this
# copy; presumably this is `__len__` — confirm against upstream.
return self.ncells
# Write the tree to filename: the 8-byte marker b'SPITREE ', a header of
# version, ndim and ncells (little-endian uint64) plus ftol (float64), the
# bounds and tolerances as float64, then every cell in pre-order (the `def`
# line is missing from this copy of the file).
with open(filename, 'wb') as file: version = 1 file.write(b'SPITREE ') file.write(struct.pack( '<QQQd', version, self.ndim, self.ncells, self.ftol)) self.xbounds.astype('<f8').tofile(file) self.xtols.astype('<f8').tofile(file) self.root.dump(file)
# NOTE(review): argument-list residue of the reader for the format written
# by dump(): the marker + header ('<8sQQQd'), the bounds (ndim x 2
# float64), the tolerances (ndim float64), then per cell its index (ndim
# int32) and 2**ndim corner values reshaped to [2]*ndim. The calls these
# arguments belonged to, and the context of the dangling `else:`, are
# missing from this copy of the file.
file, '<8sQQQd') file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2) file, dtype='<f8', count=self.ndim)
file, dtype='<i4', count=self.ndim) file, dtype='<f8', count=2**self.ndim).reshape( [2]*self.ndim)
else:
# Evaluate the user function at point x, memoized in self.f_values. The key
# is the tuple of float coordinates so it is hashable, and self.addargs are
# forwarded to f (via getset). The `def` line is missing from this copy.
return getset( self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)
# NOTE(review): residue of a bounds check; its condition and enclosing
# `def` are missing from this copy, and OutOfBounds must be defined
# elsewhere in the file.
raise OutOfBounds()
# Batch interpolation at the rows of x: delegates to the root cell, which
# yields NaN for points no leaf can interpolate (the `def` line is missing
# from this copy of the file).
return self.root.interpolate_many(x)
# Deepen every cell postponed at the previous clip depth. The list is
# swapped out first because _deepen_cell -> _fill may append newly
# postponed cells to self.cells_to_continue while we iterate (the `def`
# line is missing from this copy of the file).
cells_to_continue, self.cells_to_continue = self.cells_to_continue, [] for cell in cells_to_continue: self._deepen_cell(cell)
# Test one cell and decide whether and how to refine it (the `def` line and
# all indentation are missing from this copy of the file).
# Test points: the non-corner nodes of the cell's 3 x ... x 3 grid.
self.tested += 1 xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)
# Interpolated (fis) vs. true (fes) values; f returning None becomes NaN.
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
fis = cell.interpolate_many(xtestpoints) fes = num.array( [self._f_cached(x) for x in xtestpoints], dtype=num.float)
# A test point "works" if both values are undefined, or both are finite
# and agree to within ftol.
iffes = num.isfinite(fes) iffis = num.isfinite(fis) works = iffes == iffis iif = num.logical_and(iffes, iffis)
works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol
# Count undefined values over test points and corners; the cell is
# partially undefined if some, but not all, of them are non-finite.
nundef = num.sum(not_(num.isfinite(fes))) + \ num.sum(not_(num.isfinite(cell.f)))
some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)
# Until some test point works anywhere, every tested cell is refined.
if any_(works): self.nothing_found_yet = False
# Choose which dimensions to split. Unless the cell is partially undefined,
# a dimension is left unsplit when all midpoints of the cell edges running
# along it (index 1 in that dimension, 0 or 2 in all others) work.
# NOTE(review): `num.bool` was removed in NumPy 1.24; use `bool`.
if not all_(works) or some_undef or self.nothing_found_yet: deepen = self.ones_int.copy() if not some_undef: works_full = num.ones([3]*self.ndim, dtype=num.bool) works_full[self.pointmaker_mask] = works for idim in range(self.ndim): dimcorners = [slice(None, None, 2)] * self.ndim dimcorners[idim] = 1 if all_(works_full[tuple(dimcorners)]): deepen[idim] = 0
# The cell failed the test, so it must be split in at least one dimension.
if not any_(deepen): deepen = self.ones_int
# Never split beyond the per-dimension maximum depth.
deepen = num.where( cell.depths + deepen > self.maxdepths, 0, deepen)
cell.deepen = deepen
# Split now if still within the current clip depth; otherwise postpone the
# cell to a later level (if it can be split at all) and count it as bad.
if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth): self._deepen_cell(cell) else: if any_(deepen): self.cells_to_continue.append(cell)
# Presumably part of the `else` branch above (indentation is lost here): a
# bad cell adds its volume fraction prod(1/2**depths) to fraction_bad.
# NOTE(review): `num.product` was removed in NumPy 2.0; use `num.prod`.
cell.bad = True self.fraction_bad += num.product(1.0/2**cell.depths) self.nbad += 1
# Split a cell according to cell.deepen (the `def` line is missing from
# this copy of the file).
# A previously bad cell is replaced by its children, so remove its volume
# fraction from the bad tally first.
# NOTE(review): `num.product` was removed in NumPy 2.0; use `num.prod`.
if cell.bad: self.fraction_bad -= num.product(1.0/2**cell.depths) self.nbad -= 1 cell.bad = False
# One child per combination of offsets (2 along split dimensions, 1
# otherwise); the child index shifts the parent index left by the split and
# adds the offset. Each child is tested/refined immediately.
for iadd in num.ndindex(*(cell.deepen+1)): index_child = (cell.index << cell.deepen) + iadd child = Cell(self, index_child) self.ncells += 1 cell.children.append(child) self._fill(child)
# Delegates to the root cell's recursive NaN check (the `def` line is
# missing from this copy of the file).
'''Check for NaNs in :py:class:`SPTree`''' return self.root.check_holes()
# Plot a 2-D slice of the tree (the `def` line is missing from this copy of
# the file). Non-finite entries of x mark the two plotted dimensions.
assert self.ndim >= 2
# Default slice: all coordinates 0 except the last two (None -> NaN).
if x is None: x = num.zeros(self.ndim) x[-2:] = None
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
x = num.asarray(x, dtype=num.float) if dims is None: dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]
assert len(dims) == 2
# Only create and later show a pyplot figure when no axes were given.
plt = None if axes is None: from matplotlib import pyplot as plt axes = plt.gca()
self.root.plot_2d(axes, x, dims)
axes.set_xlabel('Dim %i' % dims[1]) axes.set_ylabel('Dim %i' % dims[0])
if plt: plt.show()
# Plot a 1-D slice of the tree (the `def` line is missing from this copy of
# the file). The single non-finite entry of x marks the plotted dimension;
# by default it is the last one and all others are fixed at 0.
if x is None: x = num.zeros(self.ndim) x[-1:] = None
# NOTE(review): `num.float` was removed in NumPy 1.24; use `float`.
x = num.asarray(x, dtype=num.float) if dims is None: dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]
assert len(dims) == 1
# Only create and later show a pyplot figure when no axes were given.
plt = None if axes is None: from matplotlib import pyplot as plt axes = plt.gca()
self.root.plot_1d(axes, x, dims[0])
axes.set_xlabel('Dim %i' % dims[0])
if plt: plt.show()
def getset(d, k, f, addargs):
    '''
    Return ``d[k]``, computing and caching it on a miss.

    :param d: cache dictionary, updated in place on a miss
    :param k: key; also passed as the first argument to ``f``
    :param f: callable invoked as ``f(k, *addargs)`` only when ``k`` is not
        yet in ``d``
    :param addargs: sequence of extra positional arguments for ``f``
    :returns: the cached or freshly computed value (which may be ``None``;
        it is cached all the same)
    '''
    try:
        return d[k]
    except KeyError:
        # Miss: evaluate once and store, so later lookups skip f entirely.
        v = d[k] = f(k, *addargs)
        return v
def nditer_outer(x):
    '''
    Build a :py:class:`numpy.nditer` over the outer product of 1-D arrays.

    Operand ``i`` of ``x`` is mapped onto iterator axis ``i`` and broadcast
    along all other axes, so iterating visits every combination of elements.
    If the last entry of ``x`` is ``None``, NumPy allocates it as a writable
    output operand with the full outer-product shape.

    :param x: sequence of 1-D arrays, optionally followed by ``None``
    :returns: the configured :py:class:`numpy.nditer`
    '''
    add = []
    if x[-1] is None:
        x_ = x[:-1]
        # Default axes for the allocated output: the full iterator shape.
        add = [None]
    else:
        x_ = x

    # Row i of (identity - 1) is 0 at column i and -1 elsewhere: operand i
    # supplies iterator axis i and is broadcast (-1) along the others.
    # The builtin ``int`` replaces the ``num.int`` alias, which was removed
    # in NumPy 1.24.
    op_axes = (num.identity(len(x_), dtype=int) - 1).tolist() + add
    return num.nditer(x, op_axes=op_axes)
if __name__ == '__main__':
    # Demo: build a 3-D tree for a function that is only defined inside a
    # ball, round-trip it through a temporary file, and plot three 2-D
    # slices and one 1-D slice through 0.5.
    logging.basicConfig(level=logging.INFO)

    def f(x):
        '''
        Test function that is defined only inside a ball.

        Returns ``x[2]**4 + x[1]`` when ``x`` lies within distance 0.5 of
        (0.5, 0.5, 0.5), and ``None`` (undefined) otherwise.
        '''
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:
            return x[2]**4 + x[1]

        return None

    tree = SPTree(
        f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    import tempfile
    import os

    # mkstemp() returns an already open OS-level descriptor along with the
    # path. dump() reopens the file by name, so close the descriptor right
    # away instead of leaking it, and remove the file even if dumping or
    # reloading fails.
    fid, fn = tempfile.mkstemp()
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()