Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/spit.py: 88%

338 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-09-30 08:22 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7N-dimensional space partitioning multi-linear interpolator. 

8''' 

9 

10import struct 

11import logging 

12import numpy as num 

13 

try:
    # Python 2: prefer the lazy ``xrange`` under the name ``range``.
    range = xrange  # noqa: F821
except NameError:
    # Python 3: the builtin ``range`` is already lazy; nothing to do.
    pass

logger = logging.getLogger('pyrocko.spit')

# Short aliases for the numpy logical/reduction helpers used throughout.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any

26 

27 

class OutOfBounds(Exception):
    '''
    Raised by :py:meth:`SPTree.interpolate` when the query point lies
    outside the tree's ``xbounds``.
    '''

30 

31 

class Cell(object):
    '''
    Node of the space partitioning tree.

    A cell covers an axis-aligned box of the tree's domain. Its position is
    encoded in the integer vector ``index``: per dimension, the highest set
    bit gives the subdivision depth and the remaining bits give the offset of
    the cell at that depth (the root has ``index`` all ones). The function
    values at the ``2**ndim`` corners are stored in ``f`` (shape
    ``[2]*ndim``); leaves interpolate multi-linearly between them, inner
    cells delegate to ``children``.
    '''

    def __init__(self, tree, index, f=None):
        '''
        :param tree: owning :py:class:`SPTree` (provides ``xbounds``,
            ``ndim``, ``xtols`` and ``_f_cached``)
        :param index: integer vector of size ``ndim`` locating the cell
        :param f: corner values, shape ``[2]*ndim``; evaluated through
            ``tree._f_cached`` when ``None``
        '''
        self.tree = tree
        self.index = index
        # Subdivision depth per dimension: position of the highest set bit.
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []
        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # Per dimension, the linear weights are (x - a) / b: column 0 gives
        # (xmax - x) / width for the lower corner, column 1 gives
        # (x - xmin) / width for the upper corner.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions: give both corners weight 0.5 instead of
        # dividing by zero.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate the (cached) function at all corners; the outer
            # iterator allocates the result array as its last operand.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, float vector of size ``ndim``
        :returns: interpolated value, or ``None`` if the containing leaf has
            non-finite corner values (or no child of an inner cell
            contains ``x``)
        '''
        if self.children:
            for cell in self.children:
                if all_(and_(cell.xbounds[:, 0] <= x,
                             x <= cell.xbounds[:, 1])):
                    return cell.interpolate(x)

            return None

        if not all_(num.isfinite(self.f)):
            return None

        ws = (x[:, num.newaxis] - self.a)/self.b
        # Outer product of the per-dimension weight pairs, shape [2]*ndim.
        # Multiply the broadcastable ix_ grids directly instead of packing
        # them into a ragged object array.
        grids = num.ix_(*ws)
        wn = grids[0]
        for grid in grids[1:]:
            wn = wn * grid

        return num.sum(self.f * wn)

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        :param x: points, float array of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``; NaN where a point lies in a
            leaf with non-finite corner values or in no child
        '''
        if self.children:
            result = num.empty(x.shape[0], dtype=float)
            result[:] = None
            for cell in self.children:
                indices = num.where(
                    self.tree.ndim == num.sum(and_(
                        cell.xbounds[:, 0] <= x,
                        x <= cell.xbounds[:, 1]), axis=-1))[0]

                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if all_(num.isfinite(self.f)):
            ws = (x[..., num.newaxis] - self.a)/self.b
            npoints = ws.shape[0]
            ndim = self.tree.ndim

            # Per dimension, a (npoints, 1, ..., 2, ..., 1) view of the
            # weight pair, so that the product broadcasts to
            # (npoints, 2, ..., 2).
            ws_pimped = []
            for i in range(ndim):
                s = [npoints] + [1] * ndim
                s[1+i] = 2
                ws_pimped.append(ws[:, i, :].reshape(tuple(s)))

            wn = ws_pimped[0]
            for idim in range(1, ndim):
                wn = wn * ws_pimped[idim]

            result = wn * self.f
            for i in range(ndim):
                result = num.sum(result, axis=-1)

            return result
        else:
            result = num.empty(x.shape[0], dtype=float)
            result[:] = None
            return result

    def slice(self, x):
        '''
        Get the children containing ``x``, where NaN entries of ``x`` act as
        wildcards for their dimension.
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x_ = x.copy()
        x_[x_mask] = 0.0
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x_,
                    x_ <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells hit by the slice through ``x``
        (``dims[1]`` horizontal, ``dims[0]`` vertical).
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            # Generator lets any() stop at the first child with holes.
            return any(child.check_holes() for child in self.children)
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot cell outlines and an image of the interpolated values over a
        2-D slice of this cell.

        :param axes: matplotlib axes to draw into
        :param x: point fixing the coordinates of the non-plotted dimensions
        :param dims: list of the two plotted dimension indices (``dims[1]``
            horizontal, ``dims[0]`` vertical); must support ``.index``
        '''
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)
        # Sample the slice with the tree's target resolution xtols.
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            try:
                idimout = dims.index(idim)
                points[:, idim] = g[idimout].ravel()
            except ValueError:
                # Dimension not plotted: hold it fixed at x.
                points[:, idim] = x[idim]

        fi = num.empty((coords[0].size, coords[1].size), dtype=float)
        fi_r = fi.ravel()
        fi_r[...] = self.interpolate_many(points)

        if num.any(num.isnan(fi)):
            logger.warning(
                'Interpolated slice contains undefined (NaN) values.')
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot the interpolated values along dimension ``dim``, with the other
        coordinates fixed at ``x``.
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''Iterate over this cell and all descendants, depth-first.'''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write this cell and its descendants depth-first to ``file``: per
        cell, ``index`` as little-endian int32 followed by ``f`` as
        little-endian float64.
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)

218 

219 

def bread(f, fmt):
    '''
    Read one :py:mod:`struct` record of format ``fmt`` from the binary
    file-like ``f`` and return the unpacked tuple.
    '''
    record = struct.Struct(fmt)
    return record.unpack(f.read(record.size))

223 

224 

class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    :param f: callable function ``f(x)`` where ``x`` is a vector of size ``n``
    :param ftol: target accuracy ``|f_interp(x) - f(x)| <= ftol``
    :param xbounds: bounds of ``x``, shape ``(n, 2)``
    :param xtols: target coarsenesses in ``x``, vector of size ``n``
    :param filename: if given, load a tree previously written with
        :py:meth:`dump` instead of building one; the other construction
        arguments are then ignored
    :param addargs: additional arguments to pass to ``f``
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Maximum subdivision depth per dimension: deep enough that the
            # finest cells are no wider than xtols.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # Weights mapping a cell's xbounds onto the 3**ndim points of a
            # regular grid (per dimension: lower bound, midpoint, upper
            # bound). The grid points that are not cell corners (at least
            # one midpoint coordinate) are the test points used by _fill.
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []
            # Refine level by level: at each clip depth cells may only be
            # split down to that depth; deeper splits are deferred to the
            # next level through cells_to_continue.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

    def status(self):
        '''
        Log the current clip depth, covered fraction and number of cells.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Don't report exact full coverage while bad cells remain.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        # Lazy %-args: formatting only happens if INFO is enabled.
        logger.info(
            'at level %2i: %s covered, %6i cell%s',
            self.clipdepth, s, self.ncells, ['s', ''][self.ncells == 1])

    def __iter__(self):
        '''Iterate over all cells, depth-first starting at the root.'''
        return iter(self.root)

    def __len__(self):
        '''Number of cells in the tree.'''
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to ``filename``.

        Format: the marker ``b'SPITREE '``, then little-endian uint64
        ``version``, ``ndim``, ``ncells`` and float64 ``ftol``, then
        ``xbounds`` and ``xtols`` as float64, then all cells depth-first
        (see :py:meth:`Cell.dump`).
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Read a tree written by :py:meth:`dump`.

        Only the attributes needed for interpolation and plotting are
        restored (``ndim``, ``ncells``, ``ftol``, ``xbounds``, ``xtols``,
        ``root``); the function ``f`` and its value cache are not.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            # Cells are stored depth-first; ``path`` holds the chain of
            # ancestors of the most recently read cell.
            path = []
            for icell in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                        [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    # The parent is the nearest cell on the path whose
                    # index equals the child's index shifted right by one
                    # in at least one (i.e. a split) dimension.
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''Evaluate ``f`` at ``x``, memoized on the float tuple of ``x``.'''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: point, vector of size ``ndim``
        :returns: interpolated value, or ``None`` where undefined
        :raises OutOfBounds: if ``x`` lies outside ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        '''Shortcut for :py:meth:`interpolate`.'''
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        :param x: points, array-like of shape ``(npoints, ndim)``
        :returns: array of shape ``(npoints,)``, NaN where undefined

        Unlike :py:meth:`interpolate`, no bounds check is performed.
        '''
        # Accept lists/tuples as well as arrays; Cell needs x.shape.
        x = num.asarray(x, dtype=float)
        return self.root.interpolate_many(x)

    def _continue_fill(self):
        '''Resume refinement of the cells deferred at the previous level.'''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Check the interpolation quality of ``cell`` and refine it if needed.

        ``f`` is evaluated at the cell's non-corner grid points and compared
        with the multi-linear interpolation from the corners. The cell is
        subdivided if a test point fails, if only some of the sampled values
        are undefined, or while ``nothing_found_yet`` is set. Splits beyond
        the current ``clipdepth`` are deferred via ``cells_to_continue`` and
        the cell is counted as bad meanwhile.
        '''
        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point "works" if function and interpolation agree on being
        # defined and, where both are defined, differ by less than ftol.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        # Some, but not all, of test-point and corner values are undefined.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                # Skip dimensions whose edge midpoints all work.
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

            if not any_(deepen):
                deepen = self.ones_int

            # Never split a dimension beyond its maximum depth.
            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    self.cells_to_continue.append(cell)

                cell.bad = True
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Split ``cell`` along the dimensions flagged in ``cell.deepen`` and
        fill the new children.
        '''
        if cell.bad:
            # Being refined now, so it no longer counts as bad.
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            # Append one bit per split dimension to select the half.
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a 2-D slice of the tree.

        :param axes: matplotlib axes; if ``None``, the current axes are used
            and the figure is shown
        :param x: point fixing the non-plotted dimensions; NaN/``None``
            entries select the plotted dimensions (default: last two)
        :param dims: the two dimensions to plot (default: non-finite
            entries of ``x``)
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot a 1-D slice of the tree.

        :param axes: matplotlib axes; if ``None``, the current axes are used
            and the figure is shown
        :param x: point fixing the non-plotted dimensions; the NaN/``None``
            entry selects the plotted dimension (default: last one)
        :param dims: one-element list with the dimension to plot (default:
            non-finite entries of ``x``)
        '''

        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt:
            plt.show()

500 

501 

def getset(d, k, f, addargs):
    '''
    Memoized call: return ``d[k]``, computing and storing
    ``f(k, *addargs)`` first if ``k`` is not yet cached (``None`` results
    are cached too).
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value

508 

509 

def nditer_outer(x):
    '''
    Create an :py:class:`numpy.nditer` over the outer product of the 1-D
    operands in ``x``: operand ``i`` is mapped to iterator axis ``i``.

    If the last element of ``x`` is ``None``, it is passed on as an operand
    to be allocated by the iterator with the full outer-product shape.
    '''
    if x[-1] is None:
        ninputs = len(x) - 1
        tail = [None]
    else:
        ninputs = len(x)
        tail = []

    # Row i: operand i's only axis on iterator axis i, newaxis (-1) elsewhere.
    op_axes = [
        [0 if iaxis == iop else -1 for iaxis in range(ninputs)]
        for iop in range(ninputs)]

    return num.nditer(x, op_axes=op_axes + tail)

521 

522 

if __name__ == '__main__':
    # Demo: build a 3-D tree, round-trip it through a file and plot slices.
    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Defined only inside the ball of radius 0.5 around the centre of the
        # unit cube; None marks undefined values outside.
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:

            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    import tempfile
    import os
    fid, fn = tempfile.mkstemp()
    # dump() reopens the file by name; close the descriptor from mkstemp so
    # it does not leak, and remove the file even if dump/load fails.
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()