Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/spit.py: 86%

344 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-03-07 11:54 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7N-dimensional space partitioning multi-linear interpolator. 

8''' 

9 

10import struct 

11import logging 

12import numpy as num 

13 

# Python 2 compatibility: use the lazy ``xrange`` where it exists. On
# Python 3 the name is undefined and the builtin ``range`` stays in effect.
try:
    range = xrange  # noqa: F821
except NameError:
    pass

logger = logging.getLogger('pyrocko.spit')

# Short aliases for the NumPy logical ufuncs and reductions used throughout
# this module.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any

26 

27 

class OutOfBounds(Exception):
    '''
    Raised when a point outside of the interpolator's bounds is requested.
    '''

30 

31 

class Cell(object):
    '''
    Axis-aligned cell of an :py:class:`SPTree`.

    A cell stores function values at its ``2**ndim`` corners and interpolates
    multi-linearly between them. A refined cell additionally holds child
    cells to which lookups are delegated.

    :param tree: owning tree object; ``tree.xbounds`` is read here, other
        attributes (``ndim``, ``xtols``, ``_f_cached``) are read by
        individual methods
    :param index: integer array with one entry per dimension. Per dimension
        the refinement depth is ``floor(log2(index))`` and the position of
        the cell within that level is ``index - 2**depth``
    :param f: corner values, array of shape ``(2,) * ndim``. If ``None``, the
        values are obtained from ``tree._f_cached`` at every corner
    '''

    __slots__ = (
        "tree", "index", "depths", "bad", "children", "xbounds", "deepen",
        "a", "b", "f")

    def __init__(self, tree, index, f=None):
        self.tree = tree
        self.index = index
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []

        # Refinement plan, assigned by the tree while it is being built.
        # Initialised so that the slot is always readable.
        self.deepen = None

        # Cell bounds: the tree's extent is divided into ``2**depth`` equal
        # intervals per dimension, of which this cell covers interval ``i``.
        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Interpolation weights are computed as ``(x - a) / b``: column 0
        # yields the weight of the lower corner, ``(xmax - x) / width``, and
        # column 1 the weight of the upper corner, ``(x - xmin) / width``.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions: avoid division by zero by shifting ``a`` and
        # setting ``b`` to -1/+1, so that both corners get weight 0.5 for a
        # point located on the degenerate interval.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Visit all corner combinations; the trailing ``None`` makes the
            # iterator allocate the array receiving the corner values.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        The first child whose bounds contain ``x`` handles the request. If
        there is no such child, the corner values of this cell are used.

        :param x: coordinates, :py:class:`numpy.ndarray` with one entry per
            dimension
        :returns: interpolated value or ``None`` if any of the corner values
            involved is not finite
        '''
        for cell in self.children:
            if num.all(num.logical_and(
                    cell.xbounds[:, 0] <= x, x <= cell.xbounds[:, 1])):
                return cell.interpolate(x)

        if not num.all(num.isfinite(self.f)):
            return None

        ws = (x[:, num.newaxis] - self.a)/self.b

        # Outer product of the per-dimension weight pairs, giving one weight
        # per corner.
        factors = num.ix_(*ws)
        wn = factors[0]
        for factor in factors[1:]:
            wn = wn * factor

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        :param x: coordinates, :py:class:`numpy.ndarray` of shape
            ``(npoints, ndim)``
        :returns: array of size ``npoints``. Entries are NaN for points
            handled by a cell with non-finite corner values and, in a cell
            with children, for points not contained in any child
        '''
        ndim = self.tree.ndim
        if self.children:
            result = num.full(x.shape[0], fill_value=num.nan)
            for cell in self.children:
                inside = num.all(num.logical_and(
                    cell.xbounds[:, 0] <= x,
                    x <= cell.xbounds[:, 1]), axis=-1)

                indices = num.where(inside)[0]
                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if not num.all(num.isfinite(self.f)):
            return num.full(x.shape[0], fill_value=num.nan)

        ws = (x[..., num.newaxis] - self.a)/self.b
        npoints = ws.shape[0]

        # Give the weight pair of dimension i its own axis (1 + i), so that
        # multiplying all of them broadcasts to one weight per point and
        # corner.
        wn = None
        for idim in range(ndim):
            shape = [npoints] + [1] * ndim
            shape[1+idim] = 2
            w = ws[:, idim, :].reshape(shape)
            wn = w if wn is None else wn * w

        # Sum over the corner axes, one axis at a time.
        result = wn * self.f
        for _ in range(ndim):
            result = num.sum(result, axis=-1)

        return result

    def slice(self, x):
        '''
        Get the children intersected by a slice through the cell.

        :param x: coordinates with one entry per dimension. Non-finite
            entries (NaN or ``None``) mark dimensions which are not
            constrained
        :returns: list of children whose bounds contain ``x`` in all
            constrained dimensions
        '''
        x = num.array(x, dtype=float)
        free = num.logical_not(num.isfinite(x))
        xfixed = num.where(free, 0.0, x)
        return [
            cell for cell in self.children
            if num.all(num.logical_or(
                free,
                num.logical_and(
                    cell.xbounds[:, 0] <= xfixed,
                    xfixed <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw outlines of the leaf cells intersected by a slice.

        :param axes: object providing a matplotlib-like ``plot`` method
        :param x: slice position, see :py:meth:`slice`
        :param dims: pair of dimension indices; ``dims[0]`` is drawn on the
            vertical axis and ``dims[1]`` on the horizontal axis
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            # Closed polygon around the cell: first corner is repeated.
            corners = ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))
            points = num.transpose([
                (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix])
                for (iy, ix) in corners])

            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any(child.check_holes() for child in self.children)

        return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot interpolated values on a 2D slice through the cell.

        The slice is sampled on a regular grid whose spacing is derived from
        the tree's ``xtols``. Nothing is drawn if no sample is finite.

        :param axes: object providing matplotlib-like ``plot`` and ``imshow``
            methods
        :param x: coordinates used for the dimensions not listed in ``dims``
        :param dims: sequence of the two dimension indices to plot
        '''
        dims = list(dims)
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim in dims:
                points[:, idim] = g[dims.index(idim)].ravel()
            else:
                points[:, idim] = x[idim]

        fi = num.empty((coords[0].size, coords[1].size), dtype=float)
        fi_r = fi.ravel()
        fi_r[...] = self.interpolate_many(points)

        if num.any(num.isnan(fi)):
            logger.warning('Interpolated values contain NaNs.')

        if num.any(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot interpolated values along one dimension.

        Nothing is drawn if no sample is finite.

        :param axes: object providing a matplotlib-like ``plot`` method
        :param x: coordinates used for all dimensions other than ``dim``
        :param dim: index of the dimension to plot along
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        points = num.empty((coords.size, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            points[:, idim] = coords if idim == dim else x[idim]

        fi = self.interpolate_many(points)
        if num.any(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''
        Iterate depth-first over this cell and all of its descendants.
        '''
        yield self
        for child in self.children:
            for cell in child:
                yield cell

    def dump(self, file):
        '''
        Write this cell and all of its descendants to a binary file.

        Per cell, the index is written as little-endian 32-bit integers,
        followed by the corner values as little-endian 64-bit floats.

        :param file: file object opened for binary writing
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for child in self.children:
            child.dump(file)

227 

228 

def bread(f, fmt):
    '''
    Read and unpack one binary record from a file.

    :param f: file object opened for binary reading
    :param fmt: :py:mod:`struct` format string describing the record
    :returns: tuple with the unpacked values
    '''
    record = struct.Struct(fmt)
    return record.unpack(f.read(record.size))

232 

233 

class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    The tree is either built by adaptively sampling ``f`` or, if ``filename``
    is given, loaded from a file written with :py:meth:`dump`.

    :param f: callable function ``f(x)`` where ``x`` is a vector of size ``n``
        (passed as a tuple of floats)
    :param ftol: target accuracy ``|f_interp(x) - f(x)| <= ftol``
    :param xbounds: bounds of ``x``, shape ``(n, 2)``
    :param xtols: target coarsenesses in ``x``, vector of size ``n``
    :param filename: if given, load the tree from this file and ignore ``f``,
        ``ftol``, ``xbounds`` and ``xtols``
    :param addargs: additional arguments to pass to ``f``
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Number of bisections needed per dimension to get cells not
            # wider than xtols.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # pointmaker[..., i, :] holds the weights of (lower, upper) bound
            # placing a point at the lower bound, the centre or the upper
            # bound of dimension i, for all 3**ndim combinations.
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w

            # Test points are those with at least one centred coordinate,
            # i.e. everything except the cell corners.
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []

            # Refine level by level: cells which would have to be split
            # beyond the current clipdepth are postponed to the next round.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

        self.ones_cell = num.ones((2,) * self.ndim, dtype=float)

    def status(self):
        '''
        Log the covered fraction and the number of cells at the current level.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Do not claim full coverage if it is only due to rounding.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info(
            'at level %2i: %s covered, %6i cell%s',
            self.clipdepth, s, self.ncells, '' if self.ncells == 1 else 's')

    def __iter__(self):
        return iter(self.root)

    def __len__(self):
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to a binary file.

        The file starts with the marker ``b'SPITREE '``, followed by version,
        number of dimensions, number of cells and ``ftol``, then ``xbounds``,
        ``xtols`` and the cells in depth-first order.

        :param filename: path of the file to write
        '''
        version = 1
        with open(filename, 'wb') as out:
            out.write(b'SPITREE ')
            out.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(out)
            self.xtols.astype('<f8').tofile(out)
            self.root.dump(out)

    def _load(self, filename):
        '''
        Read a tree from a file written with :py:meth:`dump`.

        Only the attributes stored in the file and the cell hierarchy are
        restored; the function ``f`` is not available on a loaded tree.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            # Cells are stored depth-first. ``path`` holds the chain of
            # ancestors of the most recently read cell.
            path = []
            for icell in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                        [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    # Walk up until a cell is found whose index equals the
                    # halved index of the new cell in some dimension.
                    while not num.any(
                            path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''
        Evaluate ``f`` at ``x``, reusing results of earlier evaluations.
        '''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: coordinates, vector of size ``n``
        :returns: interpolated value, or ``None`` where the cell containing
            ``x`` has non-finite corner values
        :raises: :py:exc:`OutOfBounds` if ``x`` is outside of ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not num.all(num.logical_and(
                self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points.

        In contrast to :py:meth:`interpolate`, no bounds check is done here.

        :param x: coordinates, array-like of shape ``(npoints, n)``
        :returns: array of size ``npoints``, NaN where undefined
        '''
        return self.root.interpolate_many(num.asarray(x, dtype=float))

    def _continue_fill(self):
        '''
        Refine the cells which were postponed at the previous level.
        '''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test the accuracy of a cell and refine or postpone it if needed.
        '''
        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point works if exact and interpolated value are either both
        # undefined, or both defined and closer than ftol.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)
        works[iif] &= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(num.logical_not(iffes)) \
            + num.sum(num.logical_not(num.isfinite(cell.f)))

        # True if the cell is partly, but not completely, undefined.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if num.any(works):
            self.nothing_found_yet = False

        if num.all(works) and not some_undef and not self.nothing_found_yet:
            return

        deepen = self.ones_int.copy()
        if not some_undef:
            works_full = num.ones([3]*self.ndim, dtype=bool)
            works_full[self.pointmaker_mask] = works
            for idim in range(self.ndim):
                # Test points centred in dimension idim and located at the
                # bounds in all other dimensions: if all of them work, there
                # is no need to split along idim.
                dimcorners = [slice(None, None, 2)] * self.ndim
                dimcorners[idim] = 1
                if num.all(works_full[tuple(dimcorners)]):
                    deepen[idim] = 0

        if not num.any(deepen):
            deepen = self.ones_int

        # Never refine beyond the resolution requested through xtols.
        deepen = num.where(
            cell.depths + deepen > self.maxdepths, 0, deepen)

        cell.deepen = deepen

        if num.any(deepen) and num.all(
                cell.depths + deepen <= self.clipdepth):

            self._deepen_cell(cell)
        else:
            if num.any(deepen):
                self.cells_to_continue.append(cell)

            cell.bad = True
            self.fraction_bad += num.prod(1.0/2**cell.depths)
            self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Split a cell along the dimensions flagged in ``cell.deepen``.
        '''
        if cell.bad:
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot interpolated values on a 2D slice through the tree.

        :param axes: matplotlib axes to draw on. If ``None``, the current
            axes of :py:mod:`matplotlib.pyplot` are used and the figure is
            shown
        :param x: slice position, vector of size ``n``; non-finite entries
            (NaN or ``None``) mark the dimensions to plot. By default, the
            last two dimensions are plotted at zero in all other dimensions
        :param dims: the two dimension indices to plot; derived from ``x``
            if ``None``
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = num.nan

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot interpolated values along one dimension of the tree.

        :param axes: matplotlib axes to draw on. If ``None``, the current
            axes of :py:mod:`matplotlib.pyplot` are used and the figure is
            shown
        :param x: position, vector of size ``n``; the non-finite entry (NaN
            or ``None``) marks the dimension to plot. By default, the last
            dimension is plotted at zero in all other dimensions
        :param dims: list with the single dimension index to plot; derived
            from ``x`` if ``None``
        '''
        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = num.nan

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt:
            plt.show()

511 

512 

def getset(d, k, f, addargs):
    '''
    Look up ``k`` in ``d``, computing and storing the value if it is missing.

    :param d: mapping used as cache
    :param k: key; also passed as first argument to ``f``
    :param f: callable, invoked as ``f(k, *addargs)`` on a cache miss
    :param addargs: additional positional arguments for ``f``
    :returns: the cached or newly computed value
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value

519 

520 

def nditer_outer(x):
    '''
    Create an iterator over the outer product of several 1D arrays.

    :param x: tuple of 1D arrays. If the last entry is ``None``, it is passed
        on to :py:class:`numpy.nditer` as an operand without axis mapping
        (used by the caller as an output array allocated by the iterator)
    :returns: :py:class:`numpy.nditer` with one iteration axis per input
        array
    '''
    has_output = x[-1] is None
    ninputs = len(x) - 1 if has_output else len(x)

    # Row i is 0 at position i and -1 elsewhere: axis 0 of input i becomes
    # iteration axis i and the input is broadcast along all other axes.
    op_axes = (num.identity(ninputs, dtype=int) - 1).tolist()
    if has_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)

532 

533 

# Demo and self-check: build a tree for a function defined inside a sphere,
# compare single-point against vectorized interpolation, round-trip the tree
# through a file and plot some slices.
if __name__ == '__main__':
    import os
    import tempfile
    import time

    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Defined only inside a sphere of radius r around x0; returning None
        # marks the function as undefined elsewhere.
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:
            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    n = 10000
    xs = num.random.random((n, 3))
    t0 = time.time()
    vs = tree.interpolate_many(xs)
    t1 = time.time()
    for i in range(n):
        v = tree.interpolate(xs[i])
        assert vs[i] == v or (v is None and num.isnan(vs[i]))
    t2 = time.time()

    print(t1 - t0, t2 - t1)

    # mkstemp() returns an open file descriptor: close it right away, the
    # file is reopened by name. Remove the file even if dump or load fails.
    fid, fn = tempfile.mkstemp()
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()