Coverage for /usr/local/lib/python3.11/dist-packages/grond/problems/base.py: 79%

757 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-25 10:12 +0000

1''' 

2Base classes for Grond's problem definition and the model history container. 

3 

4Common behaviour of all source models offered by Grond is implemented here. 

5Source model specific details are implemented in the respective submodules. 

6''' 

7 

8import numpy as num 

9import math 

10import copy 

11import logging 

12import os.path as op 

13import os 

14import time 

15import struct 

16import threading 

17 

18from pyrocko import gf, util, guts, orthodrome as pod 

19from pyrocko.guts import Object, String, List, Dict, Int 

20 

21from grond.meta import ADict, Parameter, GrondError, xjoin, Forbidden, \ 

22 StringID, has_get_plot_classes 

23from ..targets import MisfitResult, MisfitTarget, TargetGroup, \ 

24 WaveformMisfitTarget, SatelliteMisfitTarget, GNSSCampaignMisfitTarget 

25 

26from grond import stats 

27 

28from grond.version import __version__ 

29 

guts_prefix = 'grond'
logger = logging.getLogger('grond.problems.base')

# One kilometre in the base length unit (presumably metres).
km = 1e3
# Display scaling to kilometres; presumably passed as keyword arguments to
# parameter definitions in the problem submodules.
as_km = {'scale_factor': km, 'scale_unit': 'km'}

# Module-wide random state, used by Problem.get_random_model.
g_rstate = num.random.RandomState()

36 

37 

def nextpow2(i):
    '''
    Smallest power of two that is greater than or equal to *i*.

    Uses exact integer arithmetic. A float-logarithm approach is inexact for
    some exact powers of two, e.g. ``log(2**29) / log(2)`` evaluates slightly
    above 29. Inputs ``<= 1`` (including 0) yield 1.

    :param i: requested minimum size (non-integers are rounded up first)
    :returns: ``int`` power of two
    '''
    n = int(math.ceil(i))
    if n <= 1:
        return 1

    return 1 << (n - 1).bit_length()

40 

41 

def nextcapacity(i):
    '''
    Round *i* up to the next multiple of 1024, returned as ``int``.
    '''
    nblocks = math.ceil(i / 1024)
    return int(nblocks * 1024)

44 

45 

def correlated_weights(values, weight_matrix):
    '''
    Apply correlated weights to values.

    The resulting weighted values have to be squared! Check out
    :meth:`Problem.combine_misfits` for more information.

    :param values: misfits or norms as :class:`numpy.ndarray`; the last axis
        is contracted with the first axis of *weight_matrix*
    :param weight_matrix: weight matrix, commonly the inverse of the
        covariance matrix
    :returns: :class:`numpy.ndarray` of weighted values, computed as
        ``numpy.matmul(values, weight_matrix)``
    '''
    return num.matmul(values, weight_matrix)

59 

60 

class ProblemConfig(Object):
    '''
    Base class for config section defining the objective function setup.

    Acts as a factory for :py:class:`Problem` objects; concrete
    configurations override :py:meth:`get_problem`.
    '''
    name_template = String.T()
    norm_exponent = Int.T(default=2)
    nthreads = Int.T(default=1)

    def get_problem(self, event, target_groups, targets):
        '''
        Instantiate the problem with a given event and targets.

        The base implementation is abstract and always raises.

        :returns: :py:class:`Problem` object
        :raises: :py:exc:`NotImplementedError`
        '''
        raise NotImplementedError

78 

79 

@has_get_plot_classes
class Problem(Object):
    '''
    Base class for objective function setup.

    Defines the *problem* to be solved by the optimiser: the parameter
    ranges, the misfit targets and the rules for weighting and combining
    the individual misfit contributions into global misfits.

    Concrete problems provide the class attribute ``problem_parameters``
    (optionally also ``problem_waveform_parameters``) and implement
    :py:meth:`get_source` and :py:meth:`pack`.
    '''
    name = String.T()
    ranges = Dict.T(String.T(), gf.Range.T())
    dependants = List.T(Parameter.T())
    norm_exponent = Int.T(default=2)
    base_source = gf.Source.T(optional=True)
    targets = List.T(MisfitTarget.T())
    target_groups = List.T(TargetGroup.T())
    grond_version = String.T(optional=True)
    nthreads = Int.T(default=1)

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)

        if self.grond_version is None:
            self.grond_version = __version__

        # Lazily computed caches / runtime state.
        self._target_weights = None
        self._engine = None
        self._family_mask = None
        self._rstate_manager = None

        parameters = list(self.problem_parameters)
        if hasattr(self, 'problem_waveform_parameters') \
                and self.has_waveforms:
            parameters.extend(self.problem_waveform_parameters)

        # Drop optional parameters that have no configured range. A new list
        # is bound to the instance so that a class-level
        # ``problem_parameters`` list shared by all instances is never
        # mutated.
        self.problem_parameters = [
            p for p in parameters
            if not (p.optional and p._name not in self.ranges)]

        self.check()

    @classmethod
    def get_plot_classes(cls):
        from . import plot
        return plot.get_plot_classes()

    def check(self):
        '''
        Verify that no target group path (other than ``'all'``) is defined
        more than once.

        :raises: :py:exc:`ValueError` on a duplicate path
        '''
        paths = set()
        for grp in self.target_groups:
            if grp.path == 'all':
                continue
            if grp.path in paths:
                raise ValueError('Path %s defined more than once! In %s'
                                 % (grp.path, grp.__class__.__name__))
            paths.add(grp.path)
        logger.debug('TargetGroup check OK.')

    def copy(self):
        ''' Shallow copy with the cached target weights reset. '''
        o = copy.copy(self)
        o._target_weights = None
        return o

    def set_target_parameter_values(self, x):
        '''
        Hand every target its slice of the model vector *x*.

        Target parameters follow the problem parameters in *x*, in the order
        of :py:attr:`targets`.
        '''
        nprob = len(self.problem_parameters)
        for target in self.targets:
            target.set_parameter_values(x[nprob:nprob + target.nparameters])
            nprob += target.nparameters

    def get_parameter_dict(self, model, group=None):
        '''
        Map parameter names to values of *model*, optionally restricted to
        parameters belonging to *group*.
        '''
        params = [(p.name, model[ip])
                  for ip, p in enumerate(self.parameters)
                  if group in p.groups or group is None]
        return ADict(params)

    def get_parameter_array(self, d):
        '''
        Build a model vector from the name/value mapping *d*; parameters
        missing from *d* are set to zero.
        '''
        arr = num.zeros(self.nparameters, dtype=float)
        for ip, p in enumerate(self.parameters):
            if p.name in d.keys():
                arr[ip] = d[p.name]
        return arr

    def get_parameter_index(self, param_name):
        ''' Index of parameter *param_name*; raises KeyError if unknown. '''
        return {k.name: ik for ik, k in enumerate(self.parameters)}[param_name]

    def get_rstate_manager(self):
        ''' Lazily created :py:class:`RandomStateManager` of this problem. '''
        if self._rstate_manager is None:
            self._rstate_manager = RandomStateManager()
        return self._rstate_manager

    def dump_problem_info(self, dirname):
        ''' Write this problem as YAML to ``<dirname>/problem.yaml``. '''
        fn = op.join(dirname, 'problem.yaml')
        util.ensuredirs(fn)
        guts.dump(self, filename=fn)

    def dump_problem_data(
            self, dirname, x, misfits, chains=None,
            sampler_context=None):
        '''
        Append one model and its misfits to the binary files in *dirname*.

        *x* is appended to ``models`` and *misfits* to ``misfits``
        (little-endian float64). If given, *chains* is appended to ``chains``
        (float64) and *sampler_context* to ``choices`` (little-endian int64).
        Finally the random state manager's state is saved to ``rstate``.
        '''
        def append_to(name, values, dtype):
            with open(op.join(dirname, name), 'ab') as f:
                num.asarray(values).astype(dtype).tofile(f)

        append_to('models', x, '<f8')
        append_to('misfits', misfits, '<f8')

        if chains is not None:
            append_to('chains', chains, '<f8')

        if sampler_context is not None:
            append_to('choices', sampler_context, '<i8')

        self.get_rstate_manager().save_state(op.join(dirname, 'rstate'))

    def name_to_index(self, name):
        ''' Index of *name* among parameters followed by dependants. '''
        pnames = [p.name for p in self.combined]
        return pnames.index(name)

    @property
    def parameters(self):
        ''' Problem parameters followed by all targets' parameters. '''
        target_parameters = []
        for target in self.targets:
            target_parameters.extend(target.target_parameters)
        return self.problem_parameters + target_parameters

    @property
    def parameter_names(self):
        return [p.name for p in self.combined]

    @property
    def dependant_names(self):
        return [p.name for p in self.dependants]

    @property
    def nparameters(self):
        return len(self.parameters)

    @property
    def ntargets(self):
        return len(self.targets)

    @property
    def nwaveform_targets(self):
        return len(self.waveform_targets)

    @property
    def nsatellite_targets(self):
        return len(self.satellite_targets)

    @property
    def ngnss_targets(self):
        return len(self.gnss_targets)

    @property
    def nmisfits(self):
        ''' Total number of misfit contributions over all targets. '''
        return sum(target.nmisfits for target in self.targets)

    @property
    def ndependants(self):
        return len(self.dependants)

    @property
    def ncombined(self):
        return len(self.parameters) + len(self.dependants)

    @property
    def combined(self):
        return self.parameters + self.dependants

    @property
    def satellite_targets(self):
        return [t for t in self.targets
                if isinstance(t, SatelliteMisfitTarget)]

    @property
    def gnss_targets(self):
        return [t for t in self.targets
                if isinstance(t, GNSSCampaignMisfitTarget)]

    @property
    def waveform_targets(self):
        return [t for t in self.targets
                if isinstance(t, WaveformMisfitTarget)]

    @property
    def has_satellite(self):
        return bool(self.satellite_targets)

    @property
    def has_waveforms(self):
        return bool(self.waveform_targets)

    def set_engine(self, engine):
        self._engine = engine

    def get_engine(self):
        return self._engine

    def get_source(self, x):
        raise NotImplementedError

    def pack(self, source):
        raise NotImplementedError

    def source_to_x(self, source):
        '''
        Convert *source* into a model vector relative to the base source.

        Note: *source* is modified in place. Its lat/lon are reset to the
        base source's and the offset is stored in north/east shift. If its
        time falls outside the ``'time'`` range, it is replaced by a random
        time within that range.
        '''
        bs = self.base_source

        n, e = pod.latlon_to_ne(
            bs.lat, bs.lon,
            source.effective_lat, source.effective_lon)

        source.lat, source.lon = bs.lat, bs.lon
        source.north_shift = n
        source.east_shift = e

        tmin, tmax = self.ranges['time'].start, self.ranges['time'].stop

        if (source.time - bs.time < tmin) or (source.time - bs.time > tmax):
            rstatem = self.get_rstate_manager()
            rstate = rstatem.get_rstate(name='source2x')

            # Draw a scalar (not a size-1 array) so source.time stays a float.
            source.time = bs.time
            source.time += rstate.uniform(low=tmin, high=tmax)

        p = {}
        for k in self.base_source.keys():
            val = source[k]
            if k == 'time':
                p[k] = float(val)

            elif k in self.ranges:
                if self.ranges[k].relative == 'add':
                    val = self.ranges[k].make_relative(
                        -self.base_source[k], source[k])
                elif self.ranges[k].relative == 'mult':
                    val = self.ranges[k].make_relative(
                        1./self.base_source[k], source[k])

                p[k] = float(val)

        source = source.clone(**p)

        return self.pack(source)

    def get_gf_store_ids(self):
        return tuple(set([t.store_id for t in self.targets]))

    def get_gf_store(self, target):
        '''
        GF store for *target* (a target or a store id string).

        :raises: :py:class:`GrondError` if no engine has been set
        '''
        if not isinstance(target, str):
            target = target.store_id
        engine = self.get_engine()
        if engine is None:
            raise GrondError('Cannot get GF Store, modelling is not set up.')
        return engine.get_store(target)

    @staticmethod
    def _check_no_fixed_magnitude(fixed_magnitude):
        ''' Reject *fixed_magnitude*; the generic generators ignore it. '''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

    def random_uniform(self, xbounds, rstate, fixed_magnitude=None):
        ''' Model drawn uniformly within ``xbounds[:, 0]..xbounds[:, 1]``. '''
        self._check_no_fixed_magnitude(fixed_magnitude)

        x = rstate.uniform(0., 1., self.nparameters)
        x *= (xbounds[:, 1] - xbounds[:, 0])
        x += xbounds[:, 0]
        return x

    def minimum(self, xbounds, fixed_magnitude=None):
        ''' Model at the lower parameter bounds. '''
        self._check_no_fixed_magnitude(fixed_magnitude)
        return xbounds[:, 0]

    def maximum(self, xbounds, fixed_magnitude=None):
        ''' Model at the upper parameter bounds. '''
        self._check_no_fixed_magnitude(fixed_magnitude)
        return xbounds[:, 1]

    def preconstrain(self, x, optimizer=None):
        return x

    def extract(self, xs, i):
        '''
        Column *i* of the models *xs* (a single model if *xs* is 1D).

        Indices beyond :py:attr:`nparameters` address dependants, which are
        computed with ``make_dependant``.
        '''
        if xs.ndim == 1:
            return self.extract(xs[num.newaxis, :], i)[0]

        if i < self.nparameters:
            return xs[:, i]
        else:
            return self.make_dependant(
                xs, self.dependants[i - self.nparameters].name)

    def get_target_weights(self):
        ''' Concatenated combined weights of all targets (cached). '''
        if self._target_weights is None:
            self._target_weights = num.concatenate(
                [target.get_combined_weight() for target in self.targets])

        return self._target_weights

    def get_target_residuals(self):
        pass

    def inter_family_weights(self, ns):
        '''
        :param ns: 1D array with one normalisation value per misfit
        :returns: 1D array of weights; every misfit of a family gets
            ``1 / root(nansum(exp(ns[family])))``
        '''
        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()

        ws = num.zeros(self.nmisfits)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[mask] = 1.0 / root(num.nansum(exp(ns[mask])))

        return ws

    def inter_family_weights2(self, ns):
        '''
        :param ns: 2D array with normalization factors ``ns[imodel, itarget]``
        :returns: 2D array ``weights[imodel, itarget]``
        '''
        exp, root = self.get_norm_functions()
        family, nfamilies = self.get_family_mask()

        ws = num.zeros(ns.shape)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[:, mask] = (1.0 / root(
                num.nansum(exp(ns[:, mask]), axis=1)))[:, num.newaxis]

        return ws

    def get_reference_model(self):
        ''' Model vector of the base source; target parameters are zero. '''
        model = num.zeros(self.nparameters)
        model_source_params = self.pack(self.base_source)
        model[:model_source_params.size] = model_source_params
        return model

    def get_parameter_bounds(self):
        '''
        ``(nparameters, 2)`` array of ``(start, stop)`` per parameter: problem
        parameters from :py:attr:`ranges`, then each target's parameters from
        its ``target_ranges``.
        '''
        out = []
        for p in self.problem_parameters:
            r = self.ranges[p.name]
            out.append((r.start, r.stop))

        for target in self.targets:
            for p in target.target_parameters:
                r = target.target_ranges[p.name_nogroups]
                out.append((r.start, r.stop))

        return num.array(out, dtype=float)

    def get_dependant_bounds(self):
        return num.zeros((0, 2))

    def get_combined_bounds(self):
        return num.vstack((
            self.get_parameter_bounds(),
            self.get_dependant_bounds()))

    def raise_invalid_norm_exponent(self):
        raise GrondError('Invalid norm exponent: %f' % self.norm_exponent)

    def get_norm_functions(self):
        '''
        ``(exp, root)`` pair for :py:attr:`norm_exponent`: ``(x**2, sqrt)``
        for 2, ``(identity, abs)`` for 1.

        :raises: :py:class:`GrondError` for any other exponent
        '''
        if self.norm_exponent == 2:
            def sqr(x):
                return x**2

            return sqr, num.sqrt

        elif self.norm_exponent == 1:
            def noop(x):
                return x

            return noop, num.abs

        else:
            self.raise_invalid_norm_exponent()

    def combine_misfits(
            self, misfits,
            extra_weights=None,
            extra_residuals=None,
            extra_correlated_weights=None,
            get_contributions=False):
        '''
        Combine misfit contributions (residuals) to global or bootstrap misfits

        :param misfits: 3D array ``misfits[imodel, iresidual, 0]`` are the
            misfit contributions (residuals) ``misfits[imodel, iresidual, 1]``
            are the normalisation contributions. It is also possible to give
            the misfit and normalisation contributions for a single model as
            ``misfits[iresidual, 0]`` and ``misfits[iresidual, 1]`` in which
            case, the first dimension (imodel) of the result will be stripped
            off.

        :param extra_weights: if given, 2D array of extra weights to be applied
            to the contributions, indexed as
            ``extra_weights[ibootstrap, iresidual]``.

        :param extra_residuals: if given, 2D array of perturbations to be added
            to the residuals, indexed as
            ``extra_residuals[ibootstrap, iresidual]``.

        :param extra_correlated_weights: if a dictionary of
            ``imisfit: correlated weight matrix`` is passed a correlated
            weight matrix is applied to the misfit and normalisation values.
            ``imisfit`` is the starting index in the misfits vector the
            correlated weight matrix applies to. Only allowed with
            ``norm_exponent=2``; defaults to no correlated weights.

        :param get_contributions: get the weighted and perturbed contributions
            (don't do the sum).

        :returns: if no *extra_weights* or *extra_residuals* are given, a 1D
            array indexed as ``misfits[imodel]`` containing the global misfit
            for each model is returned, otherwise a 2D array
            ``misfits[imodel, ibootstrap]`` with the misfit for every model and
            weighting/residual set is returned.

        :raises: :py:class:`GrondError` if correlated weights are combined
            with a norm exponent other than 2
        '''
        if extra_correlated_weights is None:
            extra_correlated_weights = {}

        if misfits.ndim == 2:
            misfits = misfits[num.newaxis, :, :]
            return self.combine_misfits(
                misfits, extra_weights, extra_residuals,
                extra_correlated_weights, get_contributions)[0, ...]

        if extra_weights is None and extra_residuals is None:
            # Recurse with False placeholders: the result then carries a
            # bootstrap axis of length one, which is dropped here.
            return self.combine_misfits(
                misfits, False, False,
                extra_correlated_weights, get_contributions)[:, 0]

        assert misfits.ndim == 3
        assert not num.any(extra_weights) or extra_weights.ndim == 2
        assert not num.any(extra_residuals) or extra_residuals.ndim == 2

        if self.norm_exponent != 2 and extra_correlated_weights:
            raise GrondError('Correlated weights can only be used '
                             ' with norm_exponent=2')

        exp, root = self.get_norm_functions()

        nmodels = misfits.shape[0]

        mf = misfits[:, num.newaxis, :, :].copy()

        if num.any(extra_residuals):
            mf = mf + extra_residuals[num.newaxis, :, :, num.newaxis]

        # Views into mf: [imodel, ibootstrap, imisfit].
        res = mf[..., 0]
        norms = mf[..., 1]

        for imisfit, corr_weight_mat in extra_correlated_weights.items():

            jmisfit = imisfit + corr_weight_mat.shape[0]

            for imodel in range(nmodels):
                corr_res = res[imodel, :, imisfit:jmisfit]
                corr_norms = norms[imodel, :, imisfit:jmisfit]

                res[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_res, corr_weight_mat)

                norms[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_norms, corr_weight_mat)

        # Apply normalization family weights (these weights depend on
        # the just calculated correlated norms!)
        weights_fam = \
            self.inter_family_weights2(norms[:, 0, :])[:, num.newaxis, :]

        weights_fam = exp(weights_fam)

        res = exp(res)
        norms = exp(norms)

        res *= weights_fam
        norms *= weights_fam

        weights_tar = self.get_target_weights()[num.newaxis, num.newaxis, :]
        if num.any(extra_weights):
            weights_tar = weights_tar * extra_weights[num.newaxis, :, :]

        weights_tar = exp(weights_tar)

        res = res * weights_tar
        norms = norms * weights_tar

        if get_contributions:
            return res / num.nansum(norms, axis=2)[:, :, num.newaxis]

        result = root(
            num.nansum(res, axis=2) /
            num.nansum(norms, axis=2))

        assert result[result < 0].size == 0
        return result

    def make_family_mask(self):
        '''
        Map every misfit to the index of its target's normalisation family.

        Families are numbered in order of first appearance among
        :py:attr:`targets`. All misfits of targets sharing a
        ``normalisation_family`` get the same index, even when those targets
        are not contiguous.

        :returns: ``(families, nfamilies)``; ``families`` is an integer array
            of length :py:attr:`nmisfits`
        '''
        family_index = {}
        families = num.zeros(self.nmisfits, dtype=int)

        idx = 0
        for target in self.targets:
            ifamily = family_index.setdefault(
                target.normalisation_family, len(family_index))
            families[idx:idx + target.nmisfits] = ifamily
            idx += target.nmisfits

        return families, len(family_index)

    def get_family_mask(self):
        ''' Cached result of :py:meth:`make_family_mask`. '''
        if self._family_mask is None:
            self._family_mask = self.make_family_mask()

        return self._family_mask

    def evaluate(self, x, mask=None, result_mode='full', targets=None,
                 nthreads=1):
        '''
        Forward model *x* for the given targets.

        :param x: model vector
        :param mask: optional per-target boolean sequence; masked-out targets
            are not modelled. Cannot be combined with *targets*.
        :param result_mode: passed to every target's ``set_result_mode``
        :param targets: targets to evaluate, defaults to all targets
        :param nthreads: passed on to ``engine.process``
        :returns: list with one entry per target: the result of the target's
            ``finalize_modelling``, or a :py:class:`pyrocko.gf.SeismosizerError`
            for masked-out targets
        :raises: :py:exc:`ValueError` if both *mask* and *targets* are given
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        self.set_target_parameter_values(x)

        if mask is not None and targets is not None:
            raise ValueError('Mask cannot be defined with targets set.')
        targets = targets if targets is not None else self.targets

        for target in targets:
            target.set_result_mode(result_mode)

        # Ask every target which modelling targets it needs.
        modelling_targets = []
        t2m_map = {}
        for itarget, target in enumerate(targets):
            t2m_map[target] = target.prepare_modelling(engine, source, targets)
            if mask is None or mask[itarget]:
                modelling_targets.extend(t2m_map[target])

        # Deduplicate modelling targets, remembering every position at which
        # each unique one was requested.
        u2m_map = {}
        for imtarget, mtarget in enumerate(modelling_targets):
            u2m_map.setdefault(mtarget, []).append(imtarget)

        modelling_targets_unique = list(u2m_map.keys())
        resp = engine.process(source, modelling_targets_unique,
                              nthreads=nthreads)
        modelling_results_unique = list(resp.results_list[0])

        # Scatter the unique results back to all requesting positions.
        modelling_results = [None] * len(modelling_targets)
        for mtarget, mresult in zip(
                modelling_targets_unique, modelling_results_unique):
            for imtarget in u2m_map[mtarget]:
                modelling_results[imtarget] = mresult

        imt = 0
        results = []
        for itarget, target in enumerate(targets):
            nmt_this = len(t2m_map[target])
            if mask is None or mask[itarget]:
                result = target.finalize_modelling(
                    engine, source,
                    t2m_map[target],
                    modelling_results[imt:imt + nmt_this])

                imt += nmt_this
            else:
                result = gf.SeismosizerError(
                    'target was excluded from modelling')

            results.append(result)

        return results

    def misfits(self, x, mask=None, nthreads=1, raise_bad=False):
        '''
        ``(nmisfits, 2)`` array of misfit/normalisation contributions for *x*.

        Rows of targets whose result is not a :py:class:`MisfitResult` stay
        NaN. With *raise_bad*, such results are logged as errors (they are
        not raised).
        '''
        results = self.evaluate(
            x, mask=mask, result_mode='sparse', nthreads=nthreads)
        misfits = num.full((self.nmisfits, 2), num.nan)

        imisfit = 0
        for target, result in zip(self.targets, results):
            if isinstance(result, MisfitResult):
                misfits[imisfit:imisfit + target.nmisfits, :] = result.misfits
            elif raise_bad:
                logger.error(result)

            imisfit += target.nmisfits

        return misfits

    def forward(self, x):
        '''
        Plain forward-modelling results for all targets' plain targets.

        Failed targets (:py:class:`pyrocko.gf.SeismosizerError`) are logged at
        debug level and left out of the returned list.
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        plain_targets = []
        for target in self.targets:
            plain_targets.extend(target.get_plain_targets(engine, source))

        resp = engine.process(source, plain_targets)

        results = []
        for target, result in zip(plain_targets, resp.results_list[0]):
            if isinstance(result, gf.SeismosizerError):
                logger.debug(
                    '%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
            else:
                results.append(result)

        return results

    def _find_allowed_model(self, make_candidate, ntries_limit):
        '''
        First candidate model accepted by :py:meth:`preconstrain`.

        :param make_candidate: callable taking the parameter bounds array and
            returning a candidate model vector
        :raises: :py:class:`GrondError` if all *ntries_limit* candidates are
            :py:class:`Forbidden`
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = make_candidate(xbounds)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))

    def get_random_model(self, ntries_limit=100):
        ''' Random admissible model, drawn with the module random state. '''
        return self._find_allowed_model(
            lambda xbounds: self.random_uniform(xbounds, rstate=g_rstate),
            ntries_limit)

    def get_min_model(self, ntries_limit=100):
        ''' Admissible model at the lower parameter bounds. '''
        return self._find_allowed_model(self.minimum, ntries_limit)

    def get_max_model(self, ntries_limit=100):
        ''' Admissible model at the upper parameter bounds. '''
        return self._find_allowed_model(self.maximum, ntries_limit)

753 

754 

class ProblemInfoNotAvailable(GrondError):
    '''
    Problem description of a rundir cannot be loaded (presumably raised
    when loading the problem info — raiser not visible here).
    '''

757 

758 

class ProblemDataNotAvailable(GrondError):
    '''
    A rundir or one of its model/misfit data files is missing.
    '''

761 

762 

class NoSuchAttribute(GrondError):
    '''
    Requested per-model attribute is not stored in the rundir.
    '''

765 

766 

class InvalidAttributeName(GrondError):
    '''
    Attribute name does not match the allowed identifier pattern.
    '''

769 

770 

771class ModelHistory(object): 

772 ''' 

773 Write, read and follow sequences of models produced in an optimisation run. 

774 

775 :param problem: :class:`grond.Problem` instance 

776 :param path: path to rundir, defaults to None 

777 :type path: str, optional 

778 :param mode: open mode, 'r': read, 'w': write 

779 :type mode: str, optional 

780 ''' 

781 

782 nmodels_capacity_min = 1024 

783 

784 def __init__(self, problem, nchains=None, path=None, mode='r'): 

785 self.mode = mode 

786 

787 self.problem = problem 

788 self.path = path 

789 self.nchains = nchains 

790 

791 self._models_buffer = None 

792 self._misfits_buffer = None 

793 self._bootstraps_buffer = None 

794 self._sample_contexts_buffer = None 

795 

796 self._sorted_misfit_idx = {} 

797 

798 self.models = None 

799 self.misfits = None 

800 self.bootstrap_misfits = None 

801 

802 self.sampler_contexts = None 

803 

804 self.nmodels_capacity = self.nmodels_capacity_min 

805 self.listeners = [] 

806 

807 self._attributes = {} 

808 

809 if mode == 'r': 

810 self.load() 

811 

812 @staticmethod 

813 def verify_rundir(rundir): 

814 _rundir_files = ('misfits', 'models') 

815 

816 if not op.exists(rundir): 

817 raise ProblemDataNotAvailable( 

818 'Directory does not exist: %s' % rundir) 

819 for f in _rundir_files: 

820 if not op.exists(op.join(rundir, f)): 

821 raise ProblemDataNotAvailable('File not found: %s' % f) 

822 

823 @classmethod 

824 def follow(cls, path, nchains=None, wait=20.): 

825 ''' 

826 Start following a rundir (constructor). 

827 

828 :param path: the path to follow, a grond rundir 

829 :type path: str, optional 

830 :param wait: wait time until the folder become alive 

831 :type wait: number in seconds, optional 

832 :returns: A :py:class:`ModelHistory` instance 

833 ''' 

834 start_watch = time.time() 

835 while (time.time() - start_watch) < wait: 

836 try: 

837 cls.verify_rundir(path) 

838 problem = load_problem_info(path) 

839 return cls(problem, nchains=nchains, path=path, mode='r') 

840 except (ProblemDataNotAvailable, OSError): 

841 time.sleep(.25) 

842 

843 @property 

844 def nmodels(self): 

845 if self.models is None: 

846 return 0 

847 else: 

848 return self.models.shape[0] 

849 

850 @nmodels.setter 

851 def nmodels(self, nmodels_new): 

852 assert 0 <= nmodels_new <= self.nmodels 

853 self.models = self._models_buffer[:nmodels_new, :] 

854 self.misfits = self._misfits_buffer[:nmodels_new, :, :] 

855 if self.nchains is not None: 

856 self.bootstrap_misfits = self._bootstraps_buffer[:nmodels_new, :, :] # noqa 

857 if self._sample_contexts_buffer is not None: 

858 self.sampler_contexts = self._sample_contexts_buffer[:nmodels_new, :] # noqa 

859 

860 @property 

861 def nmodels_capacity(self): 

862 if self._models_buffer is None: 

863 return 0 

864 else: 

865 return self._models_buffer.shape[0] 

866 

867 @nmodels_capacity.setter 

868 def nmodels_capacity(self, nmodels_capacity_new): 

869 if self.nmodels_capacity != nmodels_capacity_new: 

870 

871 models_buffer = num.zeros( 

872 (nmodels_capacity_new, self.problem.nparameters), 

873 dtype=float) 

874 misfits_buffer = num.zeros( 

875 (nmodels_capacity_new, self.problem.nmisfits, 2), 

876 dtype=float) 

877 sample_contexts_buffer = num.zeros( 

878 (nmodels_capacity_new, 4), 

879 dtype=int) 

880 sample_contexts_buffer.fill(-1) 

881 

882 if self.nchains is not None: 

883 bootstraps_buffer = num.zeros( 

884 (nmodels_capacity_new, self.nchains), 

885 dtype=float) 

886 

887 ncopy = min(self.nmodels, nmodels_capacity_new) 

888 

889 if self._models_buffer is not None: 

890 models_buffer[:ncopy, :] = \ 

891 self._models_buffer[:ncopy, :] 

892 misfits_buffer[:ncopy, :, :] = \ 

893 self._misfits_buffer[:ncopy, :, :] 

894 sample_contexts_buffer[:ncopy, :] = \ 

895 self._sample_contexts_buffer[:ncopy, :] 

896 

897 self._models_buffer = models_buffer 

898 self._misfits_buffer = misfits_buffer 

899 self._sample_contexts_buffer = sample_contexts_buffer 

900 

901 if self.nchains is not None: 

902 if self._bootstraps_buffer is not None: 

903 bootstraps_buffer[:ncopy, :] = \ 

904 self._bootstraps_buffer[:ncopy, :] 

905 self._bootstraps_buffer = bootstraps_buffer 

906 

907 def clear(self): 

908 assert self.mode != 'r', 'History is read-only, cannot clear.' 

909 self.nmodels = 0 

910 self.nmodels_capacity = self.nmodels_capacity_min 

911 

912 def extend( 

913 self, models, misfits, 

914 bootstrap_misfits=None, 

915 sampler_contexts=None): 

916 

917 nmodels = self.nmodels 

918 n = models.shape[0] 

919 

920 nmodels_capacity_want = max( 

921 self.nmodels_capacity_min, nextpow2(nmodels + n)) 

922 

923 if nmodels_capacity_want != self.nmodels_capacity: 

924 self.nmodels_capacity = nmodels_capacity_want 

925 

926 self._models_buffer[nmodels:nmodels + n, :] = models 

927 self._misfits_buffer[nmodels:nmodels + n, :, :] = misfits 

928 

929 self.models = self._models_buffer[:nmodels + n, :] 

930 self.misfits = self._misfits_buffer[:nmodels + n, :, :] 

931 

932 if bootstrap_misfits is not None: 

933 self._bootstraps_buffer[nmodels:nmodels + n, :] = bootstrap_misfits 

934 self.bootstrap_misfits = self._bootstraps_buffer[:nmodels + n, :] 

935 

936 if sampler_contexts is not None: 

937 self._sample_contexts_buffer[nmodels:nmodels + n, :] \ 

938 = sampler_contexts 

939 self.sampler_contexts = self._sample_contexts_buffer[ 

940 :nmodels + n, :] 

941 

942 if self.path and self.mode == 'w': 

943 for i in range(n): 

944 self.problem.dump_problem_data( 

945 self.path, models[i, :], misfits[i, :, :], 

946 bootstrap_misfits[i, :] 

947 if bootstrap_misfits is not None else None, 

948 sampler_contexts[i, :] 

949 if sampler_contexts is not None else None) 

950 

951 self._sorted_misfit_idx.clear() 

952 self.emit('extend', nmodels, n, models, misfits, sampler_contexts) 

953 

954 def append( 

955 self, model, misfits, 

956 bootstrap_misfits=None, 

957 sampler_context=None): 

958 

959 if bootstrap_misfits is not None: 

960 bootstrap_misfits = bootstrap_misfits[num.newaxis, :] 

961 

962 if sampler_context is not None: 

963 sampler_context = sampler_context[num.newaxis, :] 

964 

965 return self.extend( 

966 model[num.newaxis, :], misfits[num.newaxis, :, :], 

967 bootstrap_misfits, sampler_context) 

968 

969 def load(self): 

970 self.mode = 'r' 

971 self.verify_rundir(self.path) 

972 models, misfits, bootstraps, sampler_contexts = load_problem_data( 

973 self.path, self.problem, nchains=self.nchains) 

974 self.extend(models, misfits, bootstraps, sampler_contexts) 

975 

976 def update(self): 

977 ''' Update history from path ''' 

978 nmodels_available = get_nmodels(self.path, self.problem) 

979 if self.nmodels == nmodels_available: 

980 return 

981 

982 try: 

983 new_models, new_misfits, new_bootstraps, new_sampler_contexts = \ 

984 load_problem_data( 

985 self.path, 

986 self.problem, 

987 nmodels_skip=self.nmodels, 

988 nchains=self.nchains) 

989 

990 except ValueError: 

991 return 

992 

993 self.extend( 

994 new_models, 

995 new_misfits, 

996 new_bootstraps, 

997 new_sampler_contexts) 

998 

999 def add_listener(self, listener): 

1000 ''' Add a listener to the history 

1001 

1002 The listening class can implement the following methods: 

1003 * ``extend`` 

1004 ''' 

1005 self.listeners.append(listener) 

1006 

1007 def emit(self, event_name, *args, **kwargs): 

1008 for listener in self.listeners: 

1009 slot = getattr(listener, event_name, None) 

1010 if callable(slot): 

1011 slot(*args, **kwargs) 

1012 

1013 @property 

1014 def attribute_names(self): 

1015 apath = op.join(self.path, 'attributes') 

1016 if not os.path.exists(apath): 

1017 return [] 

1018 

1019 return [fn for fn in os.listdir(apath) 

1020 if StringID.regex.match(fn)] 

1021 

1022 def get_attribute(self, name): 

1023 if name not in self._attributes: 

1024 if name not in self.attribute_names: 

1025 raise NoSuchAttribute(name) 

1026 

1027 path = op.join(self.path, 'attributes', name) 

1028 

1029 with open(path, 'rb') as f: 

1030 self._attributes[name] = num.fromfile( 

1031 f, dtype='<i4', 

1032 count=self.nmodels).astype(int) 

1033 

1034 assert self._attributes[name].shape == (self.nmodels,) 

1035 

1036 return self._attributes[name] 

1037 

1038 def set_attribute(self, name, attribute): 

1039 if not StringID.regex.match(name): 

1040 raise InvalidAttributeName(name) 

1041 

1042 attribute = attribute.astype(int) 

1043 assert attribute.shape == (self.nmodels,) 

1044 

1045 apath = op.join(self.path, 'attributes') 

1046 

1047 if not os.path.exists(apath): 

1048 os.mkdir(apath) 

1049 

1050 path = op.join(apath, name) 

1051 

1052 with open(path, 'wb') as f: 

1053 attribute.astype('<i4').tofile(f) 

1054 

1055 self._attributes[name] = attribute 

1056 

def ensure_bootstrap_misfits(self, optimiser):
    '''
    Compute ``bootstrap_misfits`` once, if not already available.

    Combines ``self.misfits`` via ``problem.combine_misfits`` using the
    bootstrap weights and residuals provided by the optimiser.

    :param optimiser: object providing ``get_bootstrap_weights(problem)``
        and ``get_bootstrap_residuals(problem)``
    '''
    if self.bootstrap_misfits is not None:
        return

    problem = self.problem
    weights = optimiser.get_bootstrap_weights(problem)
    residuals = optimiser.get_bootstrap_residuals(problem)
    self.bootstrap_misfits = problem.combine_misfits(
        self.misfits, extra_weights=weights, extra_residuals=residuals)

1064 

def imodels_by_cluster(self, cluster_attribute):
    '''
    Group model indices by the values of a cluster attribute.

    :param cluster_attribute: name of the attribute holding per-model
        cluster labels, or ``None`` to put all models into one group
        labelled ``-1``
    :returns: list of ``(cluster_label, percentage_of_models,
        model_indices)`` tuples; a group labelled ``-1`` is placed last.
        Empty if the attribute is not set (a warning is logged).
    '''
    if cluster_attribute is None:
        return [(-1, 100.0, num.arange(self.nmodels))]

    # Only the attribute lookup can raise NoSuchAttribute.
    try:
        iclusters = self.get_attribute(cluster_attribute)
    except NoSuchAttribute:
        logger.warning(
            'Attribute %s not set in run %s.\n'
            ' Skipping model retrieval by clusters.',
            cluster_attribute, self.problem.name)
        return []

    by_cluster = []
    for icluster in num.unique(iclusters):
        imodels = num.where(iclusters == icluster)[0]
        by_cluster.append(
            (icluster, (100.0 * imodels.size) / self.nmodels, imodels))

    # num.unique returns sorted labels, so -1 (presumably the label for
    # unclustered models) would come first; move it to the end.
    if by_cluster and by_cluster[0][0] == -1:
        by_cluster.append(by_cluster.pop(0))

    return by_cluster

1091 

def models_by_cluster(self, cluster_attribute):
    '''
    Group models by the values of a cluster attribute.

    :param cluster_attribute: attribute name, or ``None`` to return all
        models in one group labelled ``-1``
    :returns: list of ``(cluster_label, percentage_of_models, models)``
        tuples, in the order given by :meth:`imodels_by_cluster`
    '''
    if cluster_attribute is not None:
        groups = self.imodels_by_cluster(cluster_attribute)
        return [
            (label, share, self.models[indices])
            for label, share, indices in groups]

    return [(-1, 100.0, self.models)]

1100 

def mean_sources_by_cluster(self, cluster_attribute):
    '''
    Compute the mean source of every model cluster.

    :param cluster_attribute: attribute name, or ``None`` for a single
        group containing all models
    :returns: list of ``(cluster_label, percentage_of_models,
        mean_source)`` tuples, using ``stats.get_mean_source``
    '''
    clusters = self.models_by_cluster(cluster_attribute)
    problem = self.problem
    return [
        (label, share, stats.get_mean_source(problem, cluster_models))
        for label, share, cluster_models in clusters]

1106 

def get_sorted_misfits_idx(self, chain=0):
    '''
    Indices that sort the models by ascending misfit of one chain.

    The result is cached per chain in ``self._sorted_misfit_idx``.

    :param chain: column index into ``bootstrap_misfits``
    :returns: integer index array from :func:`numpy.argsort`
    '''
    cache = self._sorted_misfit_idx
    isort = cache.get(chain)
    if isort is None:
        isort = num.argsort(self.bootstrap_misfits[:, chain])
        cache[chain] = isort

    return isort

1113 

def get_sorted_misfits(self, chain=0):
    '''
    Misfits of one chain, sorted in ascending order.

    :param chain: column index into ``bootstrap_misfits``
    '''
    chain_misfits = self.bootstrap_misfits[:, chain]
    return chain_misfits[self.get_sorted_misfits_idx(chain)]

1117 

def get_sorted_models(self, chain=0):
    '''
    All models, sorted by ascending misfit of the given chain.

    :param chain: column index into ``bootstrap_misfits`` whose misfits
        define the order
    :returns: rows of ``models`` in that order
    '''
    # Use the requested chain, not always chain 0, so that the order
    # matches get_sorted_misfits(chain) (and get_best_model(chain)
    # matches get_best_misfit(chain)).
    isort = self.get_sorted_misfits_idx(chain=chain)
    return self.models[isort, :]

1121 

def get_sorted_primary_misfits(self):
    '''
    Sorted misfits of the primary chain (chain ``0``).
    '''
    primary = 0
    return self.get_sorted_misfits(chain=primary)

1124 

def get_sorted_primary_models(self):
    '''
    Models sorted by the misfits of the primary chain (chain ``0``).
    '''
    primary = 0
    return self.get_sorted_models(chain=primary)

1127 

def get_best_model(self, chain=0):
    '''
    Model with the lowest misfit in the given chain.

    :param chain: chain index passed to :meth:`get_sorted_models`
    '''
    ranked = self.get_sorted_models(chain)
    return ranked[0, ...]

1130 

def get_best_misfit(self, chain=0):
    '''
    Lowest misfit of the given chain.

    :param chain: chain index passed to :meth:`get_sorted_misfits`
    '''
    ranked = self.get_sorted_misfits(chain)
    return ranked[0]

1133 

def get_mean_model(self):
    '''
    Parameter-wise mean over all models in the history.
    '''
    all_models = self.models
    return num.mean(all_models, axis=0)

1136 

def get_mean_misfit(self, chain=0):
    '''
    Mean misfit of all models in the given chain.

    :param chain: column index into ``bootstrap_misfits``
    '''
    chain_misfits = self.bootstrap_misfits[:, chain]
    return num.mean(chain_misfits)

1139 

def get_best_source(self, chain=0):
    '''
    Source object built from the best model of the given chain.

    :param chain: chain index passed to :meth:`get_best_model`
    '''
    best = self.get_best_model(chain)
    return self.problem.get_source(best)

1142 

def get_mean_source(self, chain=0):
    '''
    Source object built from the mean of all models.

    :param chain: accepted for signature symmetry with
        :meth:`get_best_source`; not used, because the mean model does not
        depend on a chain
    '''
    mean_model = self.get_mean_model()
    return self.problem.get_source(mean_model)

1145 

def get_chain_misfits(self, chain=0):
    '''
    Unsorted misfits of all models in the given chain.

    :param chain: column index into ``bootstrap_misfits``
    '''
    misfits = self.bootstrap_misfits
    return misfits[:, chain]

1148 

def get_primary_chain_misfits(self):
    '''
    Unsorted misfits of the primary chain (chain ``0``).
    '''
    primary = 0
    return self.get_chain_misfits(chain=primary)

1151 

1152 

class RandomStateManager(object):
    '''
    Registry of named :class:`numpy.random.RandomState` instances.

    The states of all registered generators can be written to and read
    back from a binary file, one fixed-size record per generator (see
    ``save_struct``).
    '''

    # Maximum length of a generator name in bytes (UTF-8 encoded).
    MAX_LEN = 64

    # Record layout: name (NUL-padded to MAX_LEN bytes), bit generator
    # name (7 bytes), key (2496 bytes = 624 uint32), position, has-gauss
    # flag, cached gaussian value. No byte-order prefix is used, so native
    # order and alignment apply; kept unchanged so existing rstate files
    # remain readable.
    save_struct = struct.Struct('%ds7s2496sqqd' % MAX_LEN)

    def __init__(self):
        self.rstates = {}
        # Guards self.rstates (creation, saving and loading).
        self.lock = threading.Lock()

    def get_rstate(self, name, seed=None):
        '''
        Get the random state registered under ``name``, creating it if
        needed.

        :param name: identifier of the random state; at most ``MAX_LEN``
            bytes when UTF-8 encoded
        :param seed: seed used only when the state is newly created; it is
            ignored if ``name`` is already registered
        :returns: :class:`numpy.random.RandomState`
        :raises ValueError: if the encoded ``name`` exceeds ``MAX_LEN``
            bytes
        '''
        # Check the encoded length: save_state() packs name.encode() into
        # a fixed-size field, which would silently truncate longer names.
        if len(name.encode()) > self.MAX_LEN:
            raise ValueError(
                'Random state name too long (max %i bytes): %s'
                % (self.MAX_LEN, name))

        # Take the lock so a concurrent save_state() never iterates over
        # the dict while it is being extended.
        with self.lock:
            if name not in self.rstates:
                self.rstates[name] = num.random.RandomState(seed)

            return self.rstates[name]

    @property
    def nstates(self):
        '''
        Number of registered random states.
        '''
        return len(self.rstates)

    def save_state(self, fname):
        '''
        Write the states of all registered generators to a file.

        An existing file is overwritten.

        :param fname: output file path
        '''
        with self.lock:
            with open(fname, 'wb') as f:
                for name, rstate in self.rstates.items():
                    s, arr, pos, has_gauss, cached_gauss = rstate.get_state()
                    f.write(
                        self.save_struct.pack(
                            name.encode(), s.encode(), arr.tobytes(),
                            pos, has_gauss, cached_gauss))

    def load_state(self, fname):
        '''
        Restore generator states from a file written by :meth:`save_state`.

        Loaded states replace registered generators of the same name. An
        incomplete trailing record (e.g. from an interrupted write) is
        skipped with a warning instead of raising :class:`struct.error`.

        :param fname: input file path
        '''
        record_size = self.save_struct.size
        with self.lock:
            with open(fname, 'rb') as f:
                while True:
                    buff = f.read(record_size)
                    if not buff:
                        break

                    if len(buff) < record_size:
                        logger.warning(
                            'Ignoring truncated record at end of random '
                            'state file: %s', fname)
                        break

                    name, s, arr, pos, has_gauss, cached_gauss = \
                        self.save_struct.unpack(buff)

                    # Strip the NUL padding of the fixed-size string fields.
                    name = name.replace(b'\x00', b'').decode()
                    s = s.replace(b'\x00', b'').decode()
                    arr = num.frombuffer(arr, dtype=num.uint32)

                    rstate = num.random.RandomState()
                    rstate.set_state((s, arr, pos, has_gauss, cached_gauss))
                    self.rstates[name] = rstate

1202 

1203 

def get_nmodels(dirname, problem):
    '''
    Number of complete models stored in a run directory.

    The count is derived from the sizes of the ``models`` file
    (``nparameters`` float64 values per model) and the ``misfits`` file
    (``nmisfits`` x 2 float64 values per model); the smaller of the two
    counts is returned.

    :param dirname: run directory
    :param problem: object providing ``nparameters`` and ``nmisfits``
    :raises OSError: if one of the files cannot be opened
    '''
    record_sizes = (
        ('models', problem.nparameters * 8),
        ('misfits', problem.nmisfits * 2 * 8))

    counts = []
    for basename, nbytes in record_sizes:
        with open(op.join(dirname, basename), 'rb') as f:
            counts.append(os.fstat(f.fileno()).st_size // nbytes)

    return min(counts)

1214 

1215 

def load_problem_info_and_data(dirname, subset=None, nchains=None):
    '''
    Load the problem definition and its model history from a run directory.

    :param dirname: run directory containing ``problem.yaml``
    :param subset: optional sub-directory (joined with ``xjoin``) from
        which the model data is read
    :param nchains: number of bootstrap chains to load, passed on to
        :func:`load_problem_data`
    :returns: tuple ``(problem, models, misfits, bootstraps,
        sampler_contexts)``
    '''
    problem = load_problem_info(dirname)
    datadir = xjoin(dirname, subset)
    models, misfits, bootstraps, sampler_contexts = load_problem_data(
        datadir, problem, nchains=nchains)

    return problem, models, misfits, bootstraps, sampler_contexts

1221 

1222 

def load_optimiser_info(dirname):
    '''
    Load the optimiser configuration stored in a run directory.

    :param dirname: run directory containing ``optimiser.yaml``
    :returns: the object deserialized by ``guts.load``
    '''
    return guts.load(filename=op.join(dirname, 'optimiser.yaml'))

1226 

1227 

def load_problem_info(dirname):
    '''
    Load the problem definition stored in a run directory.

    :param dirname: run directory containing ``problem.yaml``
    :returns: the object deserialized by ``guts.load``
    :raises ProblemInfoNotAvailable: if the file cannot be read (the
        original :class:`OSError` is attached as ``__cause__``)
    '''
    fn = op.join(dirname, 'problem.yaml')
    try:
        return guts.load(filename=fn)
    except OSError as e:
        logger.debug(e)
        raise ProblemInfoNotAvailable(
            'No problem info available (%s).' % dirname) from e

1236 

1237 

def load_problem_data(dirname, problem, nmodels_skip=0, nchains=None):
    '''
    Load the model history of a run directory as memory-mapped arrays.

    Files read from ``dirname``: ``models``, ``misfits``, optionally
    ``bootstraps`` (or ``chains``), ``choices`` and ``rstate``. If an
    ``rstate`` file exists, the states of
    ``problem.get_rstate_manager()`` are restored from it as a side effect.

    :param dirname: run directory
    :param problem: problem definition; ``nparameters`` and ``nmisfits``
        define the record sizes
    :param nmodels_skip: number of leading models to skip
    :param nchains: number of bootstrap chains; the chains file is only
        loaded when this is given
    :returns: tuple ``(models, misfits, chains, sampler_contexts)``;
        ``chains`` and ``sampler_contexts`` are ``None`` if not available
    :raises ProblemDataNotAvailable: on :class:`OSError` while reading
        (the original error is attached as ``__cause__``). Other errors,
        e.g. :class:`ValueError` from :class:`numpy.memmap`, are not
        caught here and propagate to the caller.
    '''

    def get_chains_fn():
        # The chain misfits may be stored under either file name; the
        # first existing one wins.
        for fn in (op.join(dirname, 'bootstraps'),
                   op.join(dirname, 'chains')):
            if op.exists(fn):
                return fn
        return False

    def map_rows(fn, dtype, row_shape):
        # Memory-map ``nmodels`` records of 8-byte elements with the given
        # per-model shape, starting after the first ``nmodels_skip``
        # records.
        # NOTE(review): numpy.memmap defaults to mode='r+' (read/write), so
        # the run files are opened writable; confirm this is intended.
        row_nbytes = 8 * int(num.prod(row_shape))
        return num.memmap(
            fn, dtype=dtype,
            offset=nmodels_skip * row_nbytes,
            shape=(nmodels,) + tuple(row_shape))

    try:
        nmodels = get_nmodels(dirname, problem) - nmodels_skip

        # astype(..., copy=False) only copies if the on-disk dtype differs
        # from the native one; otherwise the memmap itself is returned.
        models = map_rows(
            op.join(dirname, 'models'), '<f8',
            (problem.nparameters,)).astype(float, copy=False)

        misfits = map_rows(
            op.join(dirname, 'misfits'), '<f8',
            (problem.nmisfits, 2)).astype(float, copy=False)

        chains = None
        fn = get_chains_fn()
        if fn and nchains is not None:
            chains = map_rows(
                fn, '<f8', (nchains,)).astype(float, copy=False)

        sampler_contexts = None
        fn = op.join(dirname, 'choices')
        if op.exists(fn):
            sampler_contexts = map_rows(
                fn, '<i8', (4,)).astype(int, copy=False)

        fn = op.join(dirname, 'rstate')
        if op.exists(fn):
            problem.get_rstate_manager().load_state(fn)

    except OSError as e:
        logger.debug(str(e))
        raise ProblemDataNotAvailable(
            'No problem data available (%s).' % dirname) from e

    return models, misfits, chains, sampler_contexts

1296 

1297 

# Public names exported by ``from grond.problems.base import *``.
__all__ = [
    'ProblemConfig',
    'Problem',
    'ModelHistory',
    'ProblemInfoNotAvailable',
    'ProblemDataNotAvailable',
    'load_problem_info',
    'load_problem_info_and_data',
    'InvalidAttributeName',
    'NoSuchAttribute',
]