1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5from __future__ import division 

6import struct 

7import logging 

8import numpy as num 

9 

try:
    # Python 2 compatibility: use the lazy xrange under the name range.
    range = xrange
except NameError:
    # Python 3: xrange does not exist; the builtin range is used.
    pass

logger = logging.getLogger('pyrocko.spit')

# Short aliases for the numpy logic functions and reductions used below.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any

22 

23 

class OutOfBounds(Exception):
    '''
    Raised when a query point lies outside the bounds of an
    :py:class:`SPTree` (see :py:meth:`SPTree.interpolate`).
    '''

26 

27 

class Cell(object):
    '''
    Node of a :py:class:`SPTree`: an axis-aligned box in n dimensions.

    The function values at the ``2**ndim`` corners of the box are kept in
    ``f`` (shape ``[2]*ndim``); inside the box the function is approximated
    by multilinear interpolation between them. A refined cell holds child
    cells in ``children`` which subdivide its box.
    '''

    def __init__(self, tree, index, f=None):
        '''
        :param tree: owning tree; its ``xbounds`` (shape ``(ndim, 2)``) are
            read, and its ``_f_cached`` is called when ``f`` is ``None``
        :param index: integer array with one entry per dimension; the
            refinement depth is ``floor(log2(index))`` and the position at
            that depth is ``index - 2**depth``
        :param f: corner values, shape ``[2]*ndim``; evaluated through the
            tree if ``None``
        '''
        self.tree = tree
        self.index = index
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []

        # Split the tree's bounds into 2**depth slabs per dimension and
        # select slab number i.
        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Per dimension, the linear weights of the lower and upper corner
        # are (x - a)/b, i.e. (xmax - x)/width and (x - xmin)/width.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions: avoid dividing by zero; at the single
        # valid coordinate both corners then get a weight of 0.5.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate the (cached) function at all corners of the box.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, array of shape ``(ndim,)``
        :returns: interpolated value, or ``None`` if the containing leaf has
            non-finite corner values or no child contains ``x``
        '''
        if self.children:
            for cell in self.children:
                if all_(and_(cell.xbounds[:, 0] <= x,
                             x <= cell.xbounds[:, 1])):
                    return cell.interpolate(x)

            return None

        if not all_(num.isfinite(self.f)):
            return None

        # ws[i] holds the (lower, upper) corner weights in dimension i.
        ws = (x[:, num.newaxis] - self.a)/self.b

        # The outer product of the per-dimension weight pairs gives the
        # multilinear corner weights with shape [2]*ndim, like self.f.
        # (Done explicitly: the former object-array trick relied on
        # num.object, which no longer exists in NumPy >= 1.24.)
        wn = ws[0]
        for w in ws[1:]:
            wn = num.multiply.outer(wn, w)

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        :param x: points, array of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where the containing
            leaf has non-finite corner values and, in refined cells, for
            points not contained in any child
        '''
        if self.children:
            result = num.full(x.shape[0], num.nan)
            for cell in self.children:
                # Points inside the child's closed box in every dimension.
                indices = num.where(
                    self.tree.ndim == num.sum(and_(
                        cell.xbounds[:, 0] <= x,
                        x <= cell.xbounds[:, 1]), axis=-1))[0]

                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if not all_(num.isfinite(self.f)):
            return num.full(x.shape[0], num.nan)

        # ws[p, i, :] are the lower/upper corner weights of point p in
        # dimension i.
        ws = (x[..., num.newaxis] - self.a)/self.b
        npoints = ws.shape[0]
        ndim = self.tree.ndim

        # Reshape each dimension's weights to (npoints, 1, .., 2, .., 1)
        # so that their product broadcasts to (npoints,) + (2,)*ndim.
        ws_pimped = []
        for i in range(ndim):
            s = [npoints] + [1] * ndim
            s[1+i] = 2
            ws_pimped.append(ws[:, i, :].reshape(tuple(s)))

        wn = ws_pimped[0]
        for idim in range(1, ndim):
            wn = wn * ws_pimped[idim]

        result = wn * self.f
        for i in range(ndim):
            result = num.sum(result, axis=-1)

        return result

    def slice(self, x):
        '''
        Get the children whose boxes contain ``x``.

        :param x: point of size ndim; non-finite entries (NaN, or ``None``
            which converts to NaN) match any coordinate in that dimension
        :returns: list of matching child cells
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x_ = x.copy()
        x_[x_mask] = 0.0
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x_,
                    x_ <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells matching the slice ``x`` (see
        :py:meth:`slice`) in the plane of ``dims``; ``dims[1]`` is drawn
        horizontally, ``dims[0]`` vertically.
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any([child.check_holes() for child in self.children])
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot the interpolant on a 2D slice through this cell.

        :param axes: matplotlib axes to draw into
        :param x: point of size ndim providing the coordinates of the
            dimensions not listed in ``dims``
        :param dims: list of two dimension indices; ``dims[1]`` is drawn
            horizontally, ``dims[0]`` vertically
        '''
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)

        # Sampling grid with a spacing of about xtols per plotted dimension.
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(
                self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            try:
                idimout = dims.index(idim)
                points[:, idim] = g[idimout].ravel()
            except ValueError:
                points[:, idim] = x[idim]

        fi = num.empty((coords[0].size, coords[1].size), dtype=float)
        fi_r = fi.ravel()
        fi_r[...] = self.interpolate_many(points)

        if num.any(num.isnan(fi)):
            logger.warning('interpolated slice contains NaN values')
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot the interpolant along dimension ``dim`` through this cell.

        :param axes: matplotlib axes to draw into
        :param x: point of size ndim providing the other coordinates
        :param dim: index of the plotted dimension
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''
        Iterate over this cell and all descendants, depth-first, parents
        before children.
        '''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write this cell and, depth-first, all descendants to ``file``: per
        cell the index as little-endian int32 followed by the corner values
        as little-endian float64.
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)

214 

215 

def bread(f, fmt):
    '''
    Read exactly one record of struct format ``fmt`` from the binary file
    object ``f`` and return the unpacked tuple.
    '''
    record = struct.Struct(fmt)
    return record.unpack(f.read(record.size))

219 

220 

class SPTree(object):
    '''
    n-dimensional space partitioning interpolator.

    The domain ``xbounds`` is recursively subdivided into :py:class:`Cell`
    boxes until multilinear interpolation between the cell corners matches
    ``f`` within ``ftol`` at the tested points, or until the cells reach
    the size given by ``xtols``.
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):
        '''
        Create n-dimensional space partitioning interpolator.

        :param f: callable, called as ``f(x, *addargs)`` where ``x`` is a
            tuple of n floats; it may return ``None`` or NaN where the
            function is undefined
        :param ftol: target accuracy |f_interp(x) - f(x)| <= ftol
        :param xbounds: bounds of x, shape (n, 2)
        :param xtols: target coarsenesses in x, vector of size n
        :param filename: if given, load a tree previously written with
            :py:meth:`dump` instead of building a new one
        :param addargs: additional arguments to pass to f
        '''
        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Smallest depth at which the cell width is <= xtols, per
            # dimension.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # Weights mapping a cell's xbounds to the 3**ndim points of the
            # grid {xmin, centre, xmax}**ndim: w[..., i, :] is (1, 0),
            # (0.5, 0.5) or (0, 1) for dimension i.
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w
            # Test points: grid points that are not cell corners, i.e.
            # those at the centre coordinate in at least one dimension.
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []

            # Refine level by level: cells which need splitting beyond the
            # current clipdepth are queued in cells_to_continue and are
            # processed in the next iteration.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

    def status(self):
        '''
        Log the current refinement level, the covered fraction of the
        domain and the number of cells.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Avoid reporting 100% when some (tiny) volume is still bad.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info('at level %2i: %s covered, %6i cell%s' % (
            self.clipdepth, s, self.ncells, ['s', ''][self.ncells == 1]))

    def __iter__(self):
        return iter(self.root)

    def __len__(self):
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to ``filename``: an 8-byte marker, a header with
        version, ndim, ncells and ftol, then xbounds, xtols and all cells
        depth-first.
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Load a tree written by :py:meth:`dump`.

        Cells are stored depth-first; each cell is attached to the nearest
        cell on the current path whose index equals the new cell's index
        shifted right by one in at least one dimension.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            path = []
            for icell in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                        [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''
        Evaluate f at ``x``, memoized in ``f_values`` under the tuple of
        float coordinates.
        '''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point of size ndim
        :returns: interpolated value or ``None`` where undefined
        :raises OutOfBounds: if ``x`` lies outside ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points, array of shape ``(npoints, ndim)``.
        No bounds check is done here.
        '''
        return self.root.interpolate_many(x)

    def _continue_fill(self):
        '''
        Deepen the cells queued at the previous level.
        '''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test ``cell`` and refine it where the interpolation is not good
        enough.

        The interpolant is compared with f at the non-corner points of the
        cell's 3**ndim grid. If they disagree (beyond ftol, or in being
        finite), if f is only partially defined on the cell, or if no
        passing point has been seen yet, the cell is split along the
        dimensions that need it: immediately when the new depth is within
        the current clipdepth, otherwise it is queued and counted as bad.
        '''
        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point passes if interpolant and function are both
        # non-finite, or both finite and within ftol of each other.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        # f is undefined somewhere, but not everywhere, on this cell.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                # Do not split along a dimension if all test points at the
                # centre of that dimension (and at corners of the others)
                # pass.
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

                if not any_(deepen):
                    deepen = self.ones_int

            # Never split beyond maxdepths.
            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    self.cells_to_continue.append(cell)

                cell.bad = True
                # Volume fraction of this cell relative to the whole domain.
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Split ``cell`` according to ``cell.deepen`` and fill the children.

        Child indices are ``(index << deepen) + iadd`` for every
        combination with ``iadd[i]`` in ``range(deepen[i] + 1)``.
        '''
        if cell.bad:
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a 2D slice of the interpolant.

        :param axes: matplotlib axes; if ``None``, the current pyplot axes
            are used and the figure is shown
        :param x: point of size ndim; non-finite entries mark the plotted
            dimensions. Default: last two dimensions, others at 0
        :param dims: explicit pair of dimensions to plot
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot a 1D slice of the interpolant.

        :param axes: matplotlib axes; if ``None``, the current pyplot axes
            are used and the figure is shown
        :param x: point of size ndim; the non-finite entry marks the
            plotted dimension. Default: last dimension, others at 0
        :param dims: explicit one-element list with the dimension to plot
        '''
        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt:
            plt.show()

497 

498 

def getset(d, k, f, addargs):
    '''
    Memoized lookup: return ``d[k]``; on a miss compute
    ``f(k, *addargs)``, store it under ``k`` and return it.
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value

505 

506 

def nditer_outer(x):
    '''
    Make a :py:class:`numpy.nditer` over the outer product of 1-D arrays.

    :param x: sequence of 1-D arrays; the i-th array is spread along
        iterator axis i. If the last element is ``None``, numpy allocates
        an output operand spanning all axes.
    :returns: the iterator; an allocated output is ``it.operands[-1]``
    '''
    with_output = x[-1] is None
    inputs = x[:-1] if with_output else x

    # Row i of (identity - 1) is 0 at position i and -1 (broadcast)
    # elsewhere: operand i's only axis maps to iterator axis i.
    op_axes = (num.identity(len(inputs), dtype=int) - 1).tolist()
    if with_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)

518 

519 

if __name__ == '__main__':
    import os
    import tempfile

    logging.basicConfig(level=logging.INFO)

    def f(x):
        '''
        Demo function: ``x[2]**4 + x[1]`` inside the ball of radius 0.5
        around the centre of the unit cube, undefined (``None``) outside.
        '''
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:
            return x[2]**4 + x[1]

        return None

    tree = SPTree(
        f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    # Round-trip the tree through a temporary file to exercise dump/_load.
    # mkstemp() returns an open descriptor; only the path is needed, so it
    # is closed right away instead of leaking (an open handle would also
    # block reopening/unlinking on Windows).
    fid, fn = tempfile.mkstemp()
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()