Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/spit.py: 88%

337 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-23 12:34 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7N-dimensional space partitioning multi-linear interpolator. 

8''' 

9 

10import struct 

11import logging 

12import numpy as num 

13 

try:
    # Python 2: make ``range`` lazy by rebinding it to ``xrange``.
    range = xrange  # noqa
except NameError:
    # Python 3: ``xrange`` does not exist, the built-in ``range`` is kept.
    pass

# Module level logger used by the tree construction and plotting code.
logger = logging.getLogger('pyrocko.spit')

# Short aliases for the NumPy logical helpers used throughout this module.
or_, and_, not_ = num.logical_or, num.logical_and, num.logical_not
all_, any_ = num.all, num.any

26 

27 

class OutOfBounds(Exception):
    '''
    Raised when a requested point lies outside of the interpolator's bounds.
    '''

30 

31 

class Cell(object):
    '''
    Single rectangular cell of an :py:class:`SPTree`.

    A cell stores the function values at its ``2**ndim`` corners and,
    optionally, a list of child cells which subdivide it.

    :param tree: owning tree; its ``xbounds`` are read here, ``ndim`` and
        ``xtols`` are read by the interpolation and plotting methods
    :param index: integer array with one entry per dimension. The refinement
        depth in a dimension is ``floor(log2(index))`` and
        ``index - 2**depth`` is the position of the cell among the
        ``2**depth`` equally sized subdivisions of the tree's bounds.
    :param f: array of corner values, or ``None`` to evaluate the tree's
        (cached) function at all corners of the cell
    '''

    __slots__ = (
        "tree", "index", "depths", "bad", "children", "xbounds", "deepen",
        "a", "b", "f")

    def __init__(self, tree, index, f=None):
        self.tree = tree
        self.index = index
        self.depths = num.log2(index).astype(int)
        self.bad = False
        self.children = []

        # Bounds of this cell: subdivision ``i`` out of ``n`` per dimension.
        n = 2**self.depths
        i = self.index - n
        delta = (self.tree.xbounds[:, 1] - self.tree.xbounds[:, 0])/n
        xmin = self.tree.xbounds[:, 0]
        self.xbounds = self.tree.xbounds.copy()
        self.xbounds[:, 0] = xmin + i * delta
        self.xbounds[:, 1] = xmin + (i+1) * delta

        # ``(x - a) / b`` yields the linear weights of the lower and the
        # upper corner: ``(x - xmax) / -(xmax - xmin)`` and
        # ``(x - xmin) / (xmax - xmin)``.
        self.a = self.xbounds[:, ::-1].copy()
        self.b = self.a.copy()
        self.b[:, 1] = self.xbounds[:, 1] - self.xbounds[:, 0]
        self.b[:, 0] = - self.b[:, 1]

        # Zero-width dimensions would divide by zero. Shift ``a`` by half a
        # unit and set ``b`` to -1/+1, so both corners get weight 0.5 when
        # ``x`` sits on the collapsed bound. The ``a`` updates must run
        # before the ``b`` updates, because they test ``b`` for zero.
        self.a[:, 0] += (self.b[:, 0] == 0.0)*0.5
        self.a[:, 1] -= (self.b[:, 1] == 0.0)*0.5
        self.b[:, 0] -= (self.b[:, 0] == 0.0)
        self.b[:, 1] += (self.b[:, 1] == 0.0)

        if f is None:
            # Evaluate the function on the outer product of the per-dimension
            # bounds; the iterator allocates the output array.
            it = nditer_outer(tuple(self.xbounds) + (None,))
            for vvv in it:
                vvv[-1][...] = self.tree._f_cached(vvv[:-1])

            self.f = it.operands[-1]
        else:
            self.f = f

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        Descends into the first child whose bounds contain ``x``. If there
        are no children (or none contains ``x``), the corner values of this
        cell are used.

        :param x: array with one coordinate per dimension
        :returns: interpolated value, or ``None`` if any corner value of the
            cell used is not finite
        '''
        if self.children:
            for cell in self.children:
                if all_(and_(cell.xbounds[:, 0] <= x,
                             x <= cell.xbounds[:, 1])):
                    return cell.interpolate(x)

        if all_(num.isfinite(self.f)):
            ws = (x[:, num.newaxis] - self.a)/self.b
            # Outer product of the per-dimension weight pairs.
            wn = num.multiply.reduce(
                num.array(num.ix_(*ws), dtype=object))
            return num.sum(self.f * wn)
        else:
            return None

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        :param x: array of shape ``(npoints, ndim)``
        :returns: array of size ``npoints``; entries are NaN where no child
            contains the point or where the responsible cell has non-finite
            corner values
        '''
        ndim = self.tree.ndim
        ndim_range = tuple(range(ndim))
        if self.children:
            result = num.full(x.shape[0], fill_value=num.nan)
            for cell in self.children:
                # Points lying inside the child in every dimension.
                indices = num.where(
                    ndim == num.sum(and_(
                        cell.xbounds[:, 0] <= x,
                        x <= cell.xbounds[:, 1]), axis=-1))[0]

                if indices.size != 0:
                    result[indices] = cell.interpolate_many(x[indices])

            return result

        if all_(num.isfinite(self.f)):
            ws = (x[..., num.newaxis] - self.a)/self.b
            npoints = ws.shape[0]

            # Bring the weight pair of dimension ``i`` onto axis ``1+i`` so
            # that the per-dimension weights broadcast to an outer product.
            # ``reshape`` is used instead of assigning to ``.shape``.
            ws_pimped = []
            for i in ndim_range:
                s = [npoints] + [1] * ndim
                s[1+i] = 2
                ws_pimped.append(ws[:, i, :].reshape(s))

            wn = ws_pimped[0]
            for idim in ndim_range[1:]:
                wn = wn * ws_pimped[idim]

            # Weighted sum over all corner axes, one axis at a time.
            result = wn * self.f
            for i in ndim_range:
                result = num.sum(result, axis=-1)

            return result
        else:
            return num.full(x.shape[0], fill_value=num.nan)

    def slice(self, x):
        '''
        Get the children intersecting a (partially specified) point.

        :param x: sequence with one entry per dimension; non-finite entries
            (e.g. ``None`` or NaN) are treated as unconstrained
        :returns: list of children whose bounds contain ``x`` in every
            constrained dimension
        '''
        x = num.array(x, dtype=float)
        x_mask = not_(num.isfinite(x))
        x_ = x.copy()
        # Dummy value; masked dimensions are accepted through ``x_mask``.
        x_[x_mask] = 0.0
        return [
            cell for cell in self.children if all_(or_(
                x_mask,
                and_(
                    cell.xbounds[:, 0] <= x_,
                    x_ <= cell.xbounds[:, 1])))]

    def plot_rects(self, axes, x, dims):
        '''
        Draw the outlines of the leaf cells intersecting ``x``.

        :param axes: object providing a ``plot`` method
        :param x: point selecting the cells, see :py:meth:`slice`
        :param dims: pair of dimension indices; ``dims[1]`` is drawn along
            the first argument of ``axes.plot``, ``dims[0]`` along the second
        '''
        if self.children:
            for cell in self.slice(x):
                cell.plot_rects(axes, x, dims)

        else:
            points = []
            # Closed polygon through the four corners of the rectangle.
            for iy, ix in ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)):
                points.append(
                    (self.xbounds[dims[0], iy], self.xbounds[dims[1], ix]))

            points = num.transpose(points)
            axes.plot(points[1], points[0], color=(0.1, 0.1, 0.0, 0.1))

    def check_holes(self):
        '''
        Check if :py:class:`Cell` or its children contain NaNs.
        '''
        if self.children:
            return any(child.check_holes() for child in self.children)
        else:
            return num.any(num.isnan(self.f))

    def plot_2d(self, axes, x, dims):
        '''
        Plot a two-dimensional section through the interpolated function.

        The section is sampled on a regular grid spanning the cell's bounds
        in ``dims``, with a spacing derived from the tree's ``xtols``.

        :param axes: object providing ``plot`` and ``imshow`` methods
        :param x: point fixing the coordinates of all dimensions not in
            ``dims``
        :param dims: sequence of two dimension indices supporting ``index()``
        '''
        idims = num.array(dims)
        self.plot_rects(axes, x, dims)
        coords = [
            num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))
            for (xb, d) in zip(self.xbounds[idims, :], self.tree.xtols[idims])]

        npoints = coords[0].size * coords[1].size
        g = num.meshgrid(*coords[::-1])[::-1]
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            try:
                idimout = dims.index(idim)
                points[:, idim] = g[idimout].ravel()
            except ValueError:
                # Dimension is not plotted: keep it fixed at ``x``.
                points[:, idim] = x[idim]

        fi = num.empty((coords[0].size, coords[1].size), dtype=float)
        fi_r = fi.ravel()
        fi_r[...] = self.interpolate_many(points)

        if num.any(num.isnan(fi)):
            logger.warning(
                'NaN values encountered in interpolated 2D section.')
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.imshow(
                fi, origin='lower',
                extent=[coords[1].min(), coords[1].max(),
                        coords[0].min(), coords[0].max()],
                interpolation='nearest',
                aspect='auto',
                cmap='RdYlBu')

    def plot_1d(self, axes, x, dim):
        '''
        Plot a one-dimensional profile through the interpolated function.

        :param axes: object providing a ``plot`` method
        :param x: point fixing the coordinates of all dimensions except
            ``dim``
        :param dim: index of the dimension to plot along
        '''
        xb = self.xbounds[dim]
        d = self.tree.xtols[dim]
        coords = num.linspace(xb[0], xb[1], 1+int((xb[1]-xb[0])/d))

        npoints = coords.size
        points = num.empty((npoints, self.tree.ndim), dtype=float)
        for idim in range(self.tree.ndim):
            if idim == dim:
                points[:, idim] = coords
            else:
                points[:, idim] = x[idim]

        fi = self.interpolate_many(points)
        if any_(num.isfinite(fi)):
            fi = num.ma.masked_invalid(fi)
            axes.plot(coords, fi)

    def __iter__(self):
        '''
        Iterate over this cell and all of its descendants, depth-first.
        '''
        yield self
        for c in self.children:
            for x in c:
                yield x

    def dump(self, file):
        '''
        Write this cell and all of its descendants to an open binary file.

        For every cell, ``index`` is written as little-endian 32 bit
        integers, followed by ``f`` as little-endian 64 bit floats.
        '''
        self.index.astype('<i4').tofile(file)
        self.f.astype('<f8').tofile(file)
        for c in self.children:
            c.dump(file)

217 

218 

def bread(f, fmt):
    '''
    Read and unpack one binary record from a file.

    :param f: file-like object opened for binary reading
    :param fmt: :py:mod:`struct` format string describing the record
    :returns: tuple with the unpacked values
    '''
    nbytes = struct.calcsize(fmt)
    raw = f.read(nbytes)
    return struct.unpack(fmt, raw)

222 

223 

class SPTree(object):
    '''
    N-dimensional space partitioning interpolator.

    The tree is either built by adaptively refining cells until the
    multi-linear interpolation reproduces ``f`` to within ``ftol`` (or the
    cell size limit given by ``xtols`` is reached), or loaded from a file
    previously written with :py:meth:`dump`.

    :param f: callable function ``f(x)`` where ``x`` is a vector of size ``n``
    :param ftol: target accuracy ``|f_interp(x) - f(x)| <= ftol``
    :param xbounds: bounds of ``x``, shape ``(n, 2)``
    :param xtols: target coarsenesses in ``x``, vector of size ``n``
    :param filename: if given, load the tree from this file instead of
        building it; ``f``, ``ftol``, ``xbounds`` and ``xtols`` are then not
        used
    :param addargs: additional arguments to pass to ``f``
    '''

    def __init__(self, f=None, ftol=None, xbounds=None, xtols=None,
                 filename=None, addargs=()):

        if filename is None:
            assert all(v is not None for v in (f, ftol, xbounds, xtols))

            self.f = f
            self.ftol = float(ftol)
            self.f_values = {}  # cache of evaluated function values
            self.ncells = 0
            self.addargs = addargs

            self.xbounds = num.asarray(xbounds, dtype=float)
            assert self.xbounds.ndim == 2
            assert self.xbounds.shape[1] == 2
            self.ndim = self.xbounds.shape[0]

            self.xtols = num.asarray(xtols, dtype=float)
            assert self.xtols.ndim == 1 and self.xtols.size == self.ndim

            # Maximum number of halvings per dimension; cells are not
            # refined beyond this depth.
            self.maxdepths = num.ceil(num.log2(
                num.maximum(
                    1.0,
                    (self.xbounds[:, 1] - self.xbounds[:, 0]) / self.xtols)
            )).astype(int)

            self.root = None
            self.ones_int = num.ones(self.ndim, dtype=int)

            # Weights turning cell bounds into a 3**ndim grid of points
            # (lower bound, midpoint, upper bound in every dimension).
            cc = num.ix_(*[num.arange(3)]*self.ndim)
            w = num.zeros([3]*self.ndim + [self.ndim, 2])
            for i, c in enumerate(cc):
                w[..., i, 0] = (2-c)*0.5
                w[..., i, 1] = c*0.5

            self.pointmaker = w
            # Select grid points with at least one midpoint coordinate, i.e.
            # everything but the corners of the cell.
            self.pointmaker_mask = num.sum(w[..., 0] == 0.5, axis=-1) != 0
            self.pointmaker_masked = w[self.pointmaker_mask]

            self.nothing_found_yet = True

            self.root = Cell(self, self.ones_int)
            self.ncells += 1

            self.fraction_bad = 0.0
            self.nbad = 0
            self.cells_to_continue = []
            # Refine level by level: cells needing refinement beyond the
            # current ``clipdepth`` are deferred to the next pass.
            for clipdepth in range(0, num.max(self.maxdepths)+1):
                self.clipdepth = clipdepth
                self.tested = 0
                if self.clipdepth == 0:
                    self._fill(self.root)
                else:
                    self._continue_fill()

                self.status()

                if not self.cells_to_continue:
                    break

        else:
            self._load(filename)

    def status(self):
        '''
        Log the covered fraction and the cell count at the current level.
        '''
        perc = (1.0-self.fraction_bad)*100
        s = '%6.1f%%' % perc

        # Do not claim full coverage if it is only due to rounding.
        if self.fraction_bad != 0.0 and s == ' 100.0%':
            s = '~100.0%'

        logger.info('at level %2i: %s covered, %6i cell%s' % (
            self.clipdepth, s, self.ncells, ['s', ''][self.ncells == 1]))

    def __iter__(self):
        '''
        Iterate over all cells of the tree.
        '''
        return iter(self.root)

    def __len__(self):
        '''
        Get the number of cells in the tree.
        '''
        return self.ncells

    def dump(self, filename):
        '''
        Write the tree to a binary file.

        The file consists of the marker ``b'SPITREE '``, a little-endian
        header (version, number of dimensions, number of cells, ``ftol``),
        ``xbounds`` and ``xtols`` as 64 bit floats and the cells as written
        by :py:meth:`Cell.dump`.

        :param filename: path of the file to write
        '''
        with open(filename, 'wb') as file:
            version = 1
            file.write(b'SPITREE ')
            file.write(struct.pack(
                '<QQQd', version, self.ndim, self.ncells, self.ftol))
            self.xbounds.astype('<f8').tofile(file)
            self.xtols.astype('<f8').tofile(file)
            self.root.dump(file)

    def _load(self, filename):
        '''
        Read a tree from a file written with :py:meth:`dump`.

        Sets ``ndim``, ``ncells``, ``ftol``, ``xbounds``, ``xtols`` and
        ``root``. Attributes used only while building a tree are not set.
        '''
        with open(filename, 'rb') as file:
            marker, version, self.ndim, self.ncells, self.ftol = bread(
                file, '<8sQQQd')
            assert marker == b'SPITREE '
            assert version == 1
            self.xbounds = num.fromfile(
                file, dtype='<f8', count=self.ndim*2).reshape(self.ndim, 2)
            self.xtols = num.fromfile(
                file, dtype='<f8', count=self.ndim)

            # Stack of cells from the root down to the most recently read
            # cell; used to re-attach cells to their parents.
            path = []
            for icell in range(self.ncells):
                index = num.fromfile(
                    file, dtype='<i4', count=self.ndim)
                f = num.fromfile(
                    file, dtype='<f8', count=2**self.ndim).reshape(
                        [2]*self.ndim)

                cell = Cell(self, index, f)
                if not path:
                    self.root = cell
                    path.append(cell)

                else:
                    # Walk up until a cell is found whose index matches the
                    # halved index of the new cell in some dimension.
                    while not any_(path[-1].index == (cell.index >> 1)):
                        path.pop()

                    path[-1].children.append(cell)
                    path.append(cell)

    def _f_cached(self, x):
        '''
        Evaluate ``f`` at ``x``, reusing previously computed values.

        ``f`` is called with ``x`` as a tuple of floats, followed by
        ``addargs``.
        '''
        return getset(
            self.f_values, tuple(float(xx) for xx in x), self.f, self.addargs)

    def interpolate(self, x):
        '''
        Interpolate at a single point.

        :param x: vector of size ``n``
        :returns: interpolated value or ``None`` if the function is not
            defined (finite) in the cell containing ``x``
        :raises: :py:exc:`OutOfBounds` if ``x`` is outside of ``xbounds``
        '''
        x = num.asarray(x, dtype=float)
        assert x.ndim == 1 and x.size == self.ndim
        if not all_(and_(self.xbounds[:, 0] <= x, x <= self.xbounds[:, 1])):
            raise OutOfBounds()

        return self.root.interpolate(x)

    def __call__(self, x):
        '''
        Shortcut for :py:meth:`interpolate`.
        '''
        return self.interpolate(x)

    def interpolate_many(self, x):
        '''
        Interpolate at many points at once.

        No bounds check is done here, in contrast to :py:meth:`interpolate`.

        :param x: points, array-like of shape ``(npoints, n)``
        :returns: array of size ``npoints``, NaN where no value is available
        '''
        # Accept any array-like, consistent with interpolate().
        x = num.asarray(x, dtype=float)
        return self.root.interpolate_many(x)

    def _continue_fill(self):
        '''
        Refine the cells which were deferred during the previous pass.
        '''
        cells_to_continue, self.cells_to_continue = self.cells_to_continue, []
        for cell in cells_to_continue:
            self._deepen_cell(cell)

    def _fill(self, cell):
        '''
        Test a cell and refine it where the interpolation is not sufficient.

        The interpolated values are compared to the function values at the
        non-corner points of a 3-point-per-dimension grid on the cell. If
        needed, the cell is either subdivided right away or, if this would
        exceed the current ``clipdepth``, marked as bad and possibly queued
        for the next pass.
        '''

        self.tested += 1
        xtestpoints = num.sum(cell.xbounds * self.pointmaker_masked, axis=-1)

        fis = cell.interpolate_many(xtestpoints)
        fes = num.array(
            [self._f_cached(x) for x in xtestpoints], dtype=float)

        # A test point "works" if exact and interpolated values are both
        # undefined, or both defined and closer than ``ftol``.
        iffes = num.isfinite(fes)
        iffis = num.isfinite(fis)
        works = iffes == iffis
        iif = num.logical_and(iffes, iffis)

        works[iif] *= num.abs(fes[iif] - fis[iif]) < self.ftol

        nundef = num.sum(not_(num.isfinite(fes))) + \
            num.sum(not_(num.isfinite(cell.f)))

        # True if some but not all values on the cell are undefined.
        some_undef = 0 < nundef < (xtestpoints.shape[0] + cell.f.size)

        if any_(works):
            self.nothing_found_yet = False

        if not all_(works) or some_undef or self.nothing_found_yet:
            deepen = self.ones_int.copy()
            if not some_undef:
                # Skip subdivision along a dimension if all test points on
                # the mid-plane of that dimension (taken at the bounds of the
                # other dimensions) are fine.
                works_full = num.ones([3]*self.ndim, dtype=bool)
                works_full[self.pointmaker_mask] = works
                for idim in range(self.ndim):
                    dimcorners = [slice(None, None, 2)] * self.ndim
                    dimcorners[idim] = 1
                    if all_(works_full[tuple(dimcorners)]):
                        deepen[idim] = 0

            if not any_(deepen):
                deepen = self.ones_int

            # Never refine beyond the depth limits derived from ``xtols``.
            deepen = num.where(
                cell.depths + deepen > self.maxdepths, 0, deepen)

            cell.deepen = deepen

            if any_(deepen) and all_(cell.depths + deepen <= self.clipdepth):
                self._deepen_cell(cell)
            else:
                if any_(deepen):
                    self.cells_to_continue.append(cell)

                cell.bad = True
                self.fraction_bad += num.prod(1.0/2**cell.depths)
                self.nbad += 1

    def _deepen_cell(self, cell):
        '''
        Subdivide a cell along the dimensions flagged in ``cell.deepen``.

        The children are created, attached to the cell and tested in turn.
        '''
        if cell.bad:
            # The cell gets replaced by its children: undo its bookkeeping.
            self.fraction_bad -= num.prod(1.0/2**cell.depths)
            self.nbad -= 1
            cell.bad = False

        for iadd in num.ndindex(*(cell.deepen+1)):
            index_child = (cell.index << cell.deepen) + iadd
            child = Cell(self, index_child)
            self.ncells += 1
            cell.children.append(child)
            self._fill(child)

    def check_holes(self):
        '''
        Check for NaNs in :py:class:`SPTree`
        '''
        return self.root.check_holes()

    def plot_2d(self, axes=None, x=None, dims=None):
        '''
        Plot a two-dimensional section through the tree.

        :param axes: axes to plot into; if ``None``, the current matplotlib
            axes are used and the figure is shown
        :param x: point fixing the section; the two plotted dimensions are
            marked with non-finite entries (e.g. ``None``). Defaults to
            zeros with the last two dimensions free.
        :param dims: the two dimensions to plot; derived from ``x`` if
            ``None``
        '''
        assert self.ndim >= 2

        if x is None:
            x = num.zeros(self.ndim)
            x[-2:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 2

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_2d(axes, x, dims)

        axes.set_xlabel('Dim %i' % dims[1])
        axes.set_ylabel('Dim %i' % dims[0])

        if plt:
            plt.show()

    def plot_1d(self, axes=None, x=None, dims=None):
        '''
        Plot a one-dimensional profile through the tree.

        :param axes: axes to plot into; if ``None``, the current matplotlib
            axes are used and the figure is shown
        :param x: point fixing the profile; the plotted dimension is marked
            with a non-finite entry (e.g. ``None``). Defaults to zeros with
            the last dimension free.
        :param dims: list with the single dimension to plot; derived from
            ``x`` if ``None``
        '''

        if x is None:
            x = num.zeros(self.ndim)
            x[-1:] = None

        x = num.asarray(x, dtype=float)
        if dims is None:
            dims = [i for (i, v) in enumerate(x) if not num.isfinite(v)]

        assert len(dims) == 1

        plt = None
        if axes is None:
            from matplotlib import pyplot as plt
            axes = plt.gca()

        self.root.plot_1d(axes, x, dims[0])

        axes.set_xlabel('Dim %i' % dims[0])

        if plt:
            plt.show()

499 

500 

def getset(d, k, f, addargs):
    '''
    Look up ``k`` in ``d``, computing and storing the value if it is missing.

    :param d: dictionary used as cache
    :param k: key; also passed to ``f`` as its first argument
    :param f: callable invoked as ``f(k, *addargs)`` on a cache miss
    :param addargs: sequence of additional arguments for ``f``
    :returns: cached or newly computed value
    '''
    try:
        value = d[k]
    except KeyError:
        value = f(k, *addargs)
        d[k] = value

    return value

507 

508 

def nditer_outer(x):
    '''
    Create an iterator over the outer product of one-dimensional operands.

    Operand ``i`` is mapped onto iteration axis ``i`` and broadcast along all
    other axes.

    :param x: sequence of operands; if the last entry is ``None``, it is
        passed on to :py:class:`numpy.nditer` without axis mapping, which
        allocates it as an output operand
    :returns: :py:class:`numpy.nditer` object
    '''
    with_output = x[-1] is None
    ninputs = len(x) - 1 if with_output else len(x)

    # Row ``iop``: 0 on the operand's own axis, -1 (broadcast) elsewhere.
    op_axes = [
        [0 if iaxis == iop else -1 for iaxis in range(ninputs)]
        for iop in range(ninputs)]

    if with_output:
        op_axes.append(None)

    return num.nditer(x, op_axes=op_axes)

520 

521 

if __name__ == '__main__':
    # Demo: build a tree for a function defined inside a sphere, round-trip
    # it through a file and plot sections through the result.
    logging.basicConfig(level=logging.INFO)

    def f(x):
        # Defined only inside a sphere of radius 0.5 around the center of
        # the unit cube; ``None`` marks undefined values.
        x0 = num.array([0.5, 0.5, 0.5])
        r = 0.5
        if num.sqrt(num.sum((x-x0)**2)) < r:

            return x[2]**4 + x[1]

        return None

    tree = SPTree(f, 0.01, [[0., 1.], [0., 1.], [0., 1.]], [0.025, 0.05, 0.1])

    import tempfile
    import os
    fid, fn = tempfile.mkstemp()
    # mkstemp() returns an open file descriptor; close it, only the name is
    # needed here.
    os.close(fid)
    try:
        tree.dump(fn)
        tree = SPTree(filename=fn)
    finally:
        # Remove the temporary file even if dumping or loading fails.
        os.unlink(fn)

    from matplotlib import pyplot as plt

    v = 0.5
    axes = plt.subplot(2, 2, 1)
    tree.plot_2d(axes, x=(v, None, None))
    axes = plt.subplot(2, 2, 2)
    tree.plot_2d(axes, x=(None, v, None))
    axes = plt.subplot(2, 2, 3)
    tree.plot_2d(axes, x=(None, None, v))

    axes = plt.subplot(2, 2, 4)
    tree.plot_1d(axes, x=(v, v, None))

    plt.show()