Coverage for /usr/local/lib/python3.11/dist-packages/grond/problems/base.py: 81%
658 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-28 13:13 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-28 13:13 +0000
1'''
2Base classes for Grond's problem definition and the model history container.
4Common behaviour of all source models offered by Grond is implemented here.
5Source model specific details are implemented in the respective submodules.
6'''
8import numpy as num
9import math
10import copy
11import logging
12import os.path as op
13import os
14import time
16from pyrocko import gf, util, guts
17from pyrocko.guts import Object, String, List, Dict, Int
19from grond.meta import ADict, Parameter, GrondError, xjoin, Forbidden, \
20 StringID, has_get_plot_classes
21from ..targets import MisfitResult, MisfitTarget, TargetGroup, \
22 WaveformMisfitTarget, SatelliteMisfitTarget, GNSSCampaignMisfitTarget
24from grond import stats
26from grond.version import __version__
# Module-level ``guts_prefix``; presumably picked up by pyrocko.guts as the
# type-name prefix for the classes declared here -- TODO confirm.
guts_prefix = 'grond'

logger = logging.getLogger('grond.problems.base')

# Scale factor 1e3 and keyword bundle labelling a quantity scaled by it as 'km'.
km = 1e3
as_km = {'scale_factor': km, 'scale_unit': 'km'}

# Module-wide random state, used by ``Problem.get_random_model``.
g_rstate = num.random.RandomState()
def nextpow2(i):
    '''
    Get the smallest power of two which is greater than or equal to ``i``.

    The result is computed with integer arithmetic. A ratio of floating point
    logarithms is not exact and can overshoot for exact powers of two (e.g.
    ``2**29``), and it is undefined for zero.

    :param i: number to round up; non-integer values are rounded up to the
        next integer first
    :returns: power of two as :class:`int`; ``1`` for all ``i <= 1``
    '''
    n = int(math.ceil(i))
    if n <= 1:
        return 1

    # (n - 1).bit_length() is the exponent of the next power of two >= n.
    return 1 << (n - 1).bit_length()
def correlated_weights(values, weight_matrix):
    '''
    Apply a correlated weight matrix to values.

    The weighted values returned here still have to be squared by the
    caller, see :meth:`Problem.combine_misfits`.

    :param values: misfits or norms, array-like
    :param weight_matrix: weight matrix, commonly the inverse of the
        covariance matrix
    :returns: matrix product ``values . weight_matrix``
    '''
    weighted = num.matmul(values, weight_matrix)
    return weighted
class ProblemConfig(Object):
    '''
    Base class for config section defining the objective function setup.

    Factory for :py:class:`Problem` objects.
    '''
    name_template = String.T()
    norm_exponent = Int.T(default=2)
    nthreads = Int.T(
        default=1,
        optional=True,
        help='Deprecated: use command line argument or global config to set '
             'number of allowed threads.')

    def check_deprecations(self):
        '''
        Log a warning if the deprecated ``nthreads`` parameter is set.
        '''
        if self.nthreads != 1:
            # Logger.warn() is a deprecated alias, use Logger.warning().
            logger.warning(
                'The `nthreads` parameter in `ProblemConfig` has been '
                'deprecated and is ignored now. Please use the `--nthreads` '
                'command line option to set it.')

    def get_problem(self, event, target_groups, targets):
        '''
        Instantiate the problem with a given event and targets.

        The base implementation only checks for deprecated settings and then
        raises :py:exc:`NotImplementedError`.

        :returns: :py:class:`Problem` object
        '''
        self.check_deprecations()
        raise NotImplementedError
@has_get_plot_classes
class Problem(Object):
    '''
    Base class for objective function setup.

    Defines the *problem* to be solved by the optimiser.

    This class uses ``self.problem_parameters``, ``self.get_source()``,
    ``self.pack()`` and ``self.make_dependant()`` which are not defined here;
    presumably they are provided by the source specific subclasses.
    '''
    name = String.T()
    ranges = Dict.T(String.T(), gf.Range.T())
    dependants = List.T(Parameter.T())
    norm_exponent = Int.T(default=2)
    base_source = gf.Source.T(optional=True)
    targets = List.T(MisfitTarget.T())
    target_groups = List.T(TargetGroup.T())
    grond_version = String.T(optional=True)
    nthreads = Int.T(
        default=1,
        optional=True,
        help='Deprecated: use command line argument or global config to set '
             'number of allowed threads.')

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)

        if self.grond_version is None:
            self.grond_version = __version__

        self._target_weights = None
        self._engine = None
        self._family_mask = None

        if hasattr(self, 'problem_waveform_parameters') and self.has_waveforms:
            self.problem_parameters =\
                self.problem_parameters + self.problem_waveform_parameters

        # Drop optional parameters for which no range has been configured.
        # The list is rebound on the instance instead of being modified in
        # place: it may be a class-level attribute shared between instances,
        # in which case an in-place removal would leak into other problems.
        used_parameters = [
            p for p in self.problem_parameters
            if not (p.optional and p._name not in self.ranges)]

        if len(used_parameters) != len(self.problem_parameters):
            self.problem_parameters = used_parameters

        self.check()

    @classmethod
    def get_plot_classes(cls):
        '''
        Get the plot classes offered by the sibling ``plot`` module.
        '''
        from . import plot
        return plot.get_plot_classes()

    def check(self):
        '''
        Check that no target group path is defined more than once.

        Groups with path ``'all'`` are exempt from the check.

        :raises: :py:exc:`ValueError` on a duplicate path
        '''
        paths = set()
        for grp in self.target_groups:
            if grp.path == 'all':
                continue
            if grp.path in paths:
                raise ValueError('Path %s defined more than once! In %s'
                                 % (grp.path, grp.__class__.__name__))
            paths.add(grp.path)
        logger.debug('TargetGroup check OK.')

    def copy(self):
        '''
        Get a shallow copy of the problem with the target weight cache reset.
        '''
        o = copy.copy(self)
        o._target_weights = None
        return o

    def set_target_parameter_values(self, x):
        '''
        Hand the target specific part of model vector ``x`` to the targets.

        The target parameters follow the problem parameters in ``x``, in
        target order.
        '''
        ibegin = len(self.problem_parameters)
        for target in self.targets:
            iend = ibegin + target.nparameters
            target.set_parameter_values(x[ibegin:iend])
            ibegin = iend

    def get_parameter_dict(self, model, group=None):
        '''
        Get parameter values of ``model`` as :py:class:`ADict` keyed by name.

        :param group: if given, only parameters belonging to this group are
            included
        '''
        params = []
        for ip, p in enumerate(self.parameters):
            if group in p.groups or group is None:
                params.append((p.name, model[ip]))
        return ADict(params)

    def get_parameter_array(self, d):
        '''
        Build a model vector from a mapping of parameter names to values.

        Parameters missing in ``d`` are set to zero.
        '''
        arr = num.zeros(self.nparameters, dtype=float)
        for ip, p in enumerate(self.parameters):
            if p.name in d.keys():
                arr[ip] = d[p.name]
        return arr

    def dump_problem_info(self, dirname):
        '''
        Dump the problem to ``<dirname>/problem.yaml``.
        '''
        fn = op.join(dirname, 'problem.yaml')
        util.ensuredirs(fn)
        guts.dump(self, filename=fn)

    def dump_problem_data(
            self, dirname, x, misfits, chains=None,
            sampler_context=None):
        '''
        Append model data to the binary files in ``dirname``.

        Models, misfits and chains are appended as little-endian 64 bit
        floats to the files ``models``, ``misfits`` and ``chains``, the
        sampler context as little-endian 64 bit integers to ``choices``.
        '''
        fn = op.join(dirname, 'models')
        if not isinstance(x, num.ndarray):
            x = num.array(x)
        with open(fn, 'ab') as f:
            x.astype('<f8').tofile(f)

        fn = op.join(dirname, 'misfits')
        with open(fn, 'ab') as f:
            misfits.astype('<f8').tofile(f)

        if chains is not None:
            fn = op.join(dirname, 'chains')
            with open(fn, 'ab') as f:
                chains.astype('<f8').tofile(f)

        if sampler_context is not None:
            fn = op.join(dirname, 'choices')
            with open(fn, 'ab') as f:
                num.array(sampler_context, dtype='<i8').tofile(f)

    def name_to_index(self, name):
        '''
        Get the index of a parameter or dependant in the combined list.

        :raises: :py:exc:`ValueError` if there is no such name
        '''
        pnames = [p.name for p in self.combined]
        return pnames.index(name)

    @property
    def parameters(self):
        '''Problem parameters followed by the parameters of all targets.'''
        target_parameters = []
        for target in self.targets:
            target_parameters.extend(target.target_parameters)
        return self.problem_parameters + target_parameters

    @property
    def parameter_names(self):
        '''Names of all parameters and dependants (combined).'''
        return [p.name for p in self.combined]

    @property
    def dependant_names(self):
        return [p.name for p in self.dependants]

    @property
    def nparameters(self):
        return len(self.parameters)

    @property
    def ntargets(self):
        return len(self.targets)

    @property
    def nwaveform_targets(self):
        return len(self.waveform_targets)

    @property
    def nsatellite_targets(self):
        return len(self.satellite_targets)

    @property
    def ngnss_targets(self):
        return len(self.gnss_targets)

    @property
    def nmisfits(self):
        '''Total number of misfit values, summed over all targets.'''
        nmisfits = 0
        for target in self.targets:
            nmisfits += target.nmisfits
        return nmisfits

    @property
    def ndependants(self):
        return len(self.dependants)

    @property
    def ncombined(self):
        return len(self.parameters) + len(self.dependants)

    @property
    def combined(self):
        '''Parameters followed by dependants.'''
        return self.parameters + self.dependants

    @property
    def satellite_targets(self):
        return [t for t in self.targets
                if isinstance(t, SatelliteMisfitTarget)]

    @property
    def gnss_targets(self):
        return [t for t in self.targets
                if isinstance(t, GNSSCampaignMisfitTarget)]

    @property
    def waveform_targets(self):
        return [t for t in self.targets
                if isinstance(t, WaveformMisfitTarget)]

    @property
    def has_satellite(self):
        return bool(self.satellite_targets)

    @property
    def has_waveforms(self):
        return bool(self.waveform_targets)

    def set_engine(self, engine):
        self._engine = engine

    def get_engine(self):
        return self._engine

    def get_gf_store(self, target):
        '''
        Get the GF store for ``target.store_id`` from the engine.

        :raises: :py:exc:`GrondError` if no engine has been set
        '''
        if self.get_engine() is None:
            raise GrondError('Cannot get GF Store, modelling is not set up.')
        return self.get_engine().get_store(target.store_id)

    def random_uniform(self, xbounds, rstate, fixed_magnitude=None):
        '''
        Draw a model uniformly distributed within ``xbounds``.

        :param xbounds: 2D array, ``xbounds[iparameter, 0]`` is the lower and
            ``xbounds[iparameter, 1]`` the upper bound
        :param rstate: random state providing ``uniform()``
        :param fixed_magnitude: not supported by the base class
        :raises: :py:exc:`GrondError` if ``fixed_magnitude`` is given
        '''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

        x = rstate.uniform(0., 1., self.nparameters)
        x *= (xbounds[:, 1] - xbounds[:, 0])
        x += xbounds[:, 0]
        return x

    def preconstrain(self, x):
        '''
        Hook to constrain a model; the base implementation returns ``x``.
        '''
        return x

    def extract(self, xs, i):
        '''
        Extract values of parameter or dependant ``i`` from models ``xs``.

        :param xs: 2D array ``xs[imodel, iparameter]`` or 1D array holding a
            single model, in which case a single value is returned
        :param i: index into the combined list of parameters and dependants
        '''
        if xs.ndim == 1:
            return self.extract(xs[num.newaxis, :], i)[0]

        if i < self.nparameters:
            return xs[:, i]
        else:
            return self.make_dependant(
                xs, self.dependants[i-self.nparameters].name)

    def get_target_weights(self):
        '''
        Get the concatenated combined weights of all targets (cached).
        '''
        if self._target_weights is None:
            self._target_weights = num.concatenate(
                [target.get_combined_weight() for target in self.targets])

        return self._target_weights

    def get_target_residuals(self):
        pass

    def inter_family_weights(self, ns):
        '''
        Get normalisation family weights for a single model.

        :param ns: 1D array with normalisation factors ``ns[imisfit]``
        :returns: 1D array ``weights[imisfit]``
        '''
        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()

        ws = num.zeros(self.nmisfits)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[mask] = 1.0 / root(num.nansum(exp(ns[mask])))

        return ws

    def inter_family_weights2(self, ns):
        '''
        :param ns: 2D array with normalization factors ``ns[imodel, itarget]``
        :returns: 2D array ``weights[imodel, itarget]``
        '''

        exp, root = self.get_norm_functions()
        family, nfamilies = self.get_family_mask()

        ws = num.zeros(ns.shape)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[:, mask] = (1.0 / root(
                num.nansum(exp(ns[:, mask]), axis=1)))[:, num.newaxis]

        return ws

    def get_reference_model(self):
        '''
        Get a model vector holding the packed ``base_source`` parameters.

        Remaining entries are zero.
        '''
        model = num.zeros(self.nparameters)
        model_source_params = self.pack(self.base_source)
        model[:model_source_params.size] = model_source_params
        return model

    def get_parameter_bounds(self):
        '''
        Get bounds of problem and target parameters.

        :returns: 2D float array ``bounds[iparameter, 0:2]`` with start and
            stop of each parameter's range
        '''
        out = []
        for p in self.problem_parameters:
            r = self.ranges[p.name]
            out.append((r.start, r.stop))

        for target in self.targets:
            for p in target.target_parameters:
                r = target.target_ranges[p.name_nogroups]
                out.append((r.start, r.stop))

        return num.array(out, dtype=float)

    def get_dependant_bounds(self):
        '''
        Get bounds of the dependants; the base implementation has none.
        '''
        return num.zeros((0, 2))

    def get_combined_bounds(self):
        return num.vstack((
            self.get_parameter_bounds(),
            self.get_dependant_bounds()))

    def raise_invalid_norm_exponent(self):
        raise GrondError('Invalid norm exponent: %f' % self.norm_exponent)

    def get_norm_functions(self):
        '''
        Get the functions ``(exp, root)`` matching ``norm_exponent``.

        :raises: :py:exc:`GrondError` for exponents other than 1 and 2
        '''
        if self.norm_exponent == 2:
            def sqr(x):
                return x**2

            return sqr, num.sqrt

        elif self.norm_exponent == 1:
            def noop(x):
                return x

            return noop, num.abs

        else:
            self.raise_invalid_norm_exponent()

    def combine_misfits(
            self, misfits,
            extra_weights=None,
            extra_residuals=None,
            extra_correlated_weights=None,
            get_contributions=False):

        '''
        Combine misfit contributions (residuals) to global or bootstrap misfits

        :param misfits: 3D array ``misfits[imodel, iresidual, 0]`` are the
            misfit contributions (residuals) ``misfits[imodel, iresidual, 1]``
            are the normalisation contributions. It is also possible to give
            the misfit and normalisation contributions for a single model as
            ``misfits[iresidual, 0]`` and ``misfits[iresidual, 1]`` in which
            case, the first dimension (imodel) of the result will be stripped
            off.

        :param extra_weights: if given, 2D array of extra weights to be applied
            to the contributions, indexed as
            ``extra_weights[ibootstrap, iresidual]``.

        :param extra_residuals: if given, 2D array of perturbations to be added
            to the residuals, indexed as
            ``extra_residuals[ibootstrap, iresidual]``.

        :param extra_correlated_weights: if a dictionary of
            ``imisfit: correlated weight matrix`` is passed a correlated
            weight matrix is applied to the misfit and normalisation values.
            `imisfit` is the starting index in the misfits vector the
            correlated weight matrix applies to. ``None`` (the default) is
            treated like an empty dictionary.

        :param get_contributions: get the weighted and perturbed contributions
            (don't do the sum).

        :returns: if no *extra_weights* or *extra_residuals* are given, a 1D
            array indexed as ``misfits[imodel]`` containing the global misfit
            for each model is returned, otherwise a 2D array
            ``misfits[imodel, ibootstrap]`` with the misfit for every model and
            weighting/residual set is returned.

        :raises: :py:exc:`GrondError` if correlated weights are given while
            ``norm_exponent`` is not 2
        '''
        # A fresh dict per call, never a mutable default shared between calls.
        if extra_correlated_weights is None:
            extra_correlated_weights = {}

        if misfits.ndim == 2:
            misfits = misfits[num.newaxis, :, :]
            return self.combine_misfits(
                misfits, extra_weights, extra_residuals,
                extra_correlated_weights, get_contributions)[0, ...]

        if extra_weights is None and extra_residuals is None:
            # ``False`` marks "not given" while keeping a bootstrap axis of
            # length one, which is stripped off again right here.
            return self.combine_misfits(
                misfits, False, False,
                extra_correlated_weights, get_contributions)[:, 0]

        assert misfits.ndim == 3
        assert not num.any(extra_weights) or extra_weights.ndim == 2
        assert not num.any(extra_residuals) or extra_residuals.ndim == 2

        if self.norm_exponent != 2 and extra_correlated_weights:
            raise GrondError('Correlated weights can only be used '
                             ' with norm_exponent=2')

        exp, root = self.get_norm_functions()

        nmodels = misfits.shape[0]

        # Insert the bootstrap axis: mf[imodel, ibootstrap, iresidual, 0:2]
        mf = misfits[:, num.newaxis, :, :].copy()

        if num.any(extra_residuals):
            mf = mf + extra_residuals[num.newaxis, :, :, num.newaxis]

        res = mf[..., 0]
        norms = mf[..., 1]

        for imisfit, corr_weight_mat in extra_correlated_weights.items():

            jmisfit = imisfit + corr_weight_mat.shape[0]

            for imodel in range(nmodels):
                corr_res = res[imodel, :, imisfit:jmisfit]
                corr_norms = norms[imodel, :, imisfit:jmisfit]

                res[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_res, corr_weight_mat)

                norms[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_norms, corr_weight_mat)

        # Apply normalization family weights (these weights depend on
        # on just calculated correlated norms!)
        weights_fam = \
            self.inter_family_weights2(norms[:, 0, :])[:, num.newaxis, :]

        weights_fam = exp(weights_fam)

        res = exp(res)
        norms = exp(norms)

        res *= weights_fam
        norms *= weights_fam

        weights_tar = self.get_target_weights()[num.newaxis, num.newaxis, :]
        if num.any(extra_weights):
            weights_tar = weights_tar * extra_weights[num.newaxis, :, :]

        weights_tar = exp(weights_tar)

        res = res * weights_tar
        norms = norms * weights_tar

        if get_contributions:
            return res / num.nansum(norms, axis=2)[:, :, num.newaxis]

        result = root(
            num.nansum(res, axis=2) /
            num.nansum(norms, axis=2))

        assert result[result < 0].size == 0
        return result

    def make_family_mask(self):
        '''
        Map every misfit value to the index of its normalisation family.

        Family indices are assigned in order of first appearance. Targets of
        the same family get the same index, also when they are not adjacent
        in the target list.

        :returns: tuple ``(families, nfamilies)`` where ``families[imisfit]``
            is the family index
        '''
        family_indices = {}
        families = num.zeros(self.nmisfits, dtype=int)

        idx = 0
        for target in self.targets:
            ifamily = family_indices.setdefault(
                target.normalisation_family, len(family_indices))
            families[idx:idx + target.nmisfits] = ifamily
            idx += target.nmisfits

        return families, len(family_indices)

    def get_family_mask(self):
        '''
        Get the (cached) result of :py:meth:`make_family_mask`.
        '''
        if self._family_mask is None:
            self._family_mask = self.make_family_mask()

        return self._family_mask

    def evaluate(self, x, mask=None, result_mode='full', targets=None):
        '''
        Run the forward modelling for model ``x`` and evaluate the targets.

        :param x: model vector
        :param mask: if given, targets with a false entry ``mask[itarget]``
            are excluded from modelling
        :param result_mode: passed to ``set_result_mode()`` of every target
        :param targets: targets to evaluate, defaults to ``self.targets``;
            cannot be combined with ``mask``
        :returns: list with one result per target; targets excluded by the
            mask get a :py:class:`gf.SeismosizerError` instance
        :raises: :py:exc:`ValueError` if both ``mask`` and ``targets`` are set
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        self.set_target_parameter_values(x)

        if mask is not None and targets is not None:
            raise ValueError('Mask cannot be defined with targets set.')
        targets = targets if targets is not None else self.targets

        for target in targets:
            target.set_result_mode(result_mode)

        modelling_targets = []
        t2m_map = {}
        for itarget, target in enumerate(targets):
            t2m_map[target] = target.prepare_modelling(engine, source, targets)
            if mask is None or mask[itarget]:
                modelling_targets.extend(t2m_map[target])

        # Modelling targets shared between targets are processed only once:
        # map each unique modelling target to all positions where it is used.
        u2m_map = {}
        for imtarget, mtarget in enumerate(modelling_targets):
            if mtarget not in u2m_map:
                u2m_map[mtarget] = []

            u2m_map[mtarget].append(imtarget)

        modelling_targets_unique = list(u2m_map.keys())

        resp = engine.process(source, modelling_targets_unique)
        modelling_results_unique = list(resp.results_list[0])

        modelling_results = [None] * len(modelling_targets)

        for mtarget, mresult in zip(
                modelling_targets_unique, modelling_results_unique):

            for itarget in u2m_map[mtarget]:
                modelling_results[itarget] = mresult

        imt = 0
        results = []
        for itarget, target in enumerate(targets):
            nmt_this = len(t2m_map[target])
            if mask is None or mask[itarget]:
                result = target.finalize_modelling(
                    engine, source,
                    t2m_map[target],
                    modelling_results[imt:imt+nmt_this])

                imt += nmt_this
            else:
                result = gf.SeismosizerError(
                    'target was excluded from modelling')

            results.append(result)

        return results

    def misfits(self, x, mask=None):
        '''
        Get misfit and normalisation contributions for model ``x``.

        :returns: 2D array ``misfits[imisfit, 0:2]``; entries of targets which
            did not yield a :py:class:`MisfitResult` are NaN
        '''
        results = self.evaluate(x, mask=mask, result_mode='sparse')
        misfits = num.full((self.nmisfits, 2), num.nan)

        imisfit = 0
        for target, result in zip(self.targets, results):
            if isinstance(result, MisfitResult):
                misfits[imisfit:imisfit+target.nmisfits, :] = result.misfits

            imisfit += target.nmisfits

        return misfits

    def forward(self, x):
        '''
        Process the plain targets of all targets for model ``x``.

        Results which are :py:class:`gf.SeismosizerError` instances are
        logged at debug level and left out.

        :returns: list of modelling results
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        plain_targets = []
        for target in self.targets:
            plain_targets.extend(target.get_plain_targets(engine, source))

        resp = engine.process(source, plain_targets)

        results = []
        for target, result in zip(plain_targets, resp.results_list[0]):
            if isinstance(result, gf.SeismosizerError):
                logger.debug(
                    '%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
            else:
                results.append(result)

        return results

    def get_random_model(self, ntries_limit=100):
        '''
        Draw a random model within the parameter bounds.

        Candidates for which :py:meth:`preconstrain` raises
        :py:exc:`Forbidden` are rejected and a new one is drawn.

        :param ntries_limit: maximum number of candidates to try
        :raises: :py:exc:`GrondError` if no candidate was accepted
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = self.random_uniform(xbounds, rstate=g_rstate)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))
class ProblemInfoNotAvailable(GrondError):
    '''
    Raised by :py:func:`load_problem_info` if ``problem.yaml`` cannot be read.
    '''
class ProblemDataNotAvailable(GrondError):
    '''
    Raised if the rundir or its binary data files are missing or unreadable.
    '''
class NoSuchAttribute(GrondError):
    '''
    Raised by :py:meth:`ModelHistory.get_attribute` for unknown attributes.
    '''
class InvalidAttributeName(GrondError):
    '''
    Raised by :py:meth:`ModelHistory.set_attribute` for names not matching
    ``StringID.regex``.
    '''
class ModelHistory(object):
    '''
    Write, read and follow sequences of models produced in an optimisation run.

    :param problem: :class:`grond.Problem` instance
    :param nchains: number of bootstrap chains, defaults to None (no
        bootstrap misfits are buffered or loaded)
    :type nchains: int, optional
    :param path: path to rundir, defaults to None
    :type path: str, optional
    :param mode: open mode, 'r': read, 'w': write
    :type mode: str, optional
    '''

    nmodels_capacity_min = 1024

    def __init__(self, problem, nchains=None, path=None, mode='r'):
        self.mode = mode

        self.problem = problem
        self.path = path
        self.nchains = nchains

        # Pre-allocated buffers; the public arrays below are views into them.
        self._models_buffer = None
        self._misfits_buffer = None
        self._bootstraps_buffer = None
        self._sample_contexts_buffer = None

        # Cache of argsort results, keyed by chain index.
        self._sorted_misfit_idx = {}

        self.models = None
        self.misfits = None
        self.bootstrap_misfits = None

        self.sampler_contexts = None

        self.nmodels_capacity = self.nmodels_capacity_min
        self.listeners = []

        self._attributes = {}

        if mode == 'r':
            self.load()

    @staticmethod
    def verify_rundir(rundir):
        '''
        Check that ``rundir`` exists and holds the files ``misfits`` and
        ``models``.

        :raises: :py:exc:`ProblemDataNotAvailable` otherwise
        '''
        _rundir_files = ('misfits', 'models')

        if not op.exists(rundir):
            raise ProblemDataNotAvailable(
                'Directory does not exist: %s' % rundir)
        for f in _rundir_files:
            if not op.exists(op.join(rundir, f)):
                raise ProblemDataNotAvailable('File not found: %s' % f)

    @classmethod
    def follow(cls, path, nchains=None, wait=20.):
        '''
        Start following a rundir (constructor).

        :param path: the path to follow, a grond rundir
        :type path: str, optional
        :param wait: wait time until the folder become alive
        :type wait: number in seconds, optional
        :returns: A :py:class:`ModelHistory` instance, or ``None`` if the
            rundir did not become available within ``wait`` seconds
        '''
        start_watch = time.time()
        while (time.time() - start_watch) < wait:
            try:
                cls.verify_rundir(path)
                problem = load_problem_info(path)
                return cls(problem, nchains=nchains, path=path, mode='r')
            except (ProblemDataNotAvailable, OSError):
                time.sleep(.25)

        return None

    @property
    def nmodels(self):
        '''Number of models currently held in the history.'''
        if self.models is None:
            return 0
        else:
            return self.models.shape[0]

    @nmodels.setter
    def nmodels(self, nmodels_new):
        # Only shrinking is possible here; growing is done through extend().
        assert 0 <= nmodels_new <= self.nmodels
        self.models = self._models_buffer[:nmodels_new, :]
        self.misfits = self._misfits_buffer[:nmodels_new, :, :]
        if self.nchains is not None:
            # The bootstraps buffer is 2D: [imodel, ichain]
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels_new, :]
        if self._sample_contexts_buffer is not None:
            self.sampler_contexts = self._sample_contexts_buffer[:nmodels_new, :]  # noqa

    @property
    def nmodels_capacity(self):
        '''Number of models the buffers can hold without reallocation.'''
        if self._models_buffer is None:
            return 0
        else:
            return self._models_buffer.shape[0]

    @nmodels_capacity.setter
    def nmodels_capacity(self, nmodels_capacity_new):
        if self.nmodels_capacity != nmodels_capacity_new:

            models_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nparameters),
                dtype=float)
            misfits_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nmisfits, 2),
                dtype=float)
            sample_contexts_buffer = num.zeros(
                (nmodels_capacity_new, 4),
                dtype=int)
            sample_contexts_buffer.fill(-1)

            if self.nchains is not None:
                bootstraps_buffer = num.zeros(
                    (nmodels_capacity_new, self.nchains),
                    dtype=float)

            # Carry over the models already stored in the old buffers.
            ncopy = min(self.nmodels, nmodels_capacity_new)

            if self._models_buffer is not None:
                models_buffer[:ncopy, :] = \
                    self._models_buffer[:ncopy, :]
                misfits_buffer[:ncopy, :, :] = \
                    self._misfits_buffer[:ncopy, :, :]
                sample_contexts_buffer[:ncopy, :] = \
                    self._sample_contexts_buffer[:ncopy, :]

            self._models_buffer = models_buffer
            self._misfits_buffer = misfits_buffer
            self._sample_contexts_buffer = sample_contexts_buffer

            if self.nchains is not None:
                if self._bootstraps_buffer is not None:
                    bootstraps_buffer[:ncopy, :] = \
                        self._bootstraps_buffer[:ncopy, :]
                self._bootstraps_buffer = bootstraps_buffer

    def clear(self):
        '''
        Drop all models and reset the buffers to minimum capacity.
        '''
        assert self.mode != 'r', 'History is read-only, cannot clear.'
        self.nmodels = 0
        self.nmodels_capacity = self.nmodels_capacity_min

    def extend(
            self, models, misfits,
            bootstrap_misfits=None,
            sampler_contexts=None):
        '''
        Add a batch of models to the history.

        Buffers are resized to the next power of two fitting all models. If
        a path is set and the history is in write mode, the new models are
        also appended to the files in the rundir. Listeners are notified
        through their ``extend`` slot.

        :param models: 2D array ``models[imodel, iparameter]``
        :param misfits: 3D array ``misfits[imodel, imisfit, 0:2]``
        :param bootstrap_misfits: optional 2D array
            ``bootstrap_misfits[imodel, ichain]``
        :param sampler_contexts: optional 2D array
            ``sampler_contexts[imodel, 0:4]``
        '''
        nmodels = self.nmodels
        n = models.shape[0]

        nmodels_capacity_want = max(
            self.nmodels_capacity_min, nextpow2(nmodels + n))

        if nmodels_capacity_want != self.nmodels_capacity:
            self.nmodels_capacity = nmodels_capacity_want

        self._models_buffer[nmodels:nmodels+n, :] = models
        self._misfits_buffer[nmodels:nmodels+n, :, :] = misfits

        self.models = self._models_buffer[:nmodels+n, :]
        self.misfits = self._misfits_buffer[:nmodels+n, :, :]

        if bootstrap_misfits is not None:
            self._bootstraps_buffer[nmodels:nmodels+n, :] = bootstrap_misfits
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels+n, :]

        if sampler_contexts is not None:
            self._sample_contexts_buffer[nmodels:nmodels+n, :] \
                = sampler_contexts
            self.sampler_contexts = self._sample_contexts_buffer[:nmodels+n, :]

        if self.path and self.mode == 'w':
            for i in range(n):
                self.problem.dump_problem_data(
                    self.path, models[i, :], misfits[i, :, :],
                    bootstrap_misfits[i, :]
                    if bootstrap_misfits is not None else None,
                    sampler_contexts[i, :]
                    if sampler_contexts is not None else None)

        # Sort order is invalid once new models have been added.
        self._sorted_misfit_idx.clear()

        self.emit('extend', nmodels, n, models, misfits, sampler_contexts)

    def append(
            self, model, misfits,
            bootstrap_misfits=None,
            sampler_context=None):
        '''
        Add a single model to the history, see :py:meth:`extend`.
        '''
        if bootstrap_misfits is not None:
            bootstrap_misfits = bootstrap_misfits[num.newaxis, :]

        if sampler_context is not None:
            sampler_context = sampler_context[num.newaxis, :]

        return self.extend(
            model[num.newaxis, :], misfits[num.newaxis, :, :],
            bootstrap_misfits, sampler_context)

    def load(self):
        '''
        Switch to read mode and load all models available in the rundir.
        '''
        self.mode = 'r'
        self.verify_rundir(self.path)
        models, misfits, bootstraps, sampler_contexts = load_problem_data(
            self.path, self.problem, nchains=self.nchains)
        self.extend(models, misfits, bootstraps, sampler_contexts)

    def update(self):
        ''' Update history from path '''
        nmodels_available = get_nmodels(self.path, self.problem)
        if self.nmodels == nmodels_available:
            return

        try:
            new_models, new_misfits, new_bootstraps, new_sampler_contexts = \
                load_problem_data(
                    self.path,
                    self.problem,
                    nmodels_skip=self.nmodels,
                    nchains=self.nchains)

        except ValueError:
            # Data could not be brought into shape; skip this update.
            return

        self.extend(
            new_models,
            new_misfits,
            new_bootstraps,
            new_sampler_contexts)

    def add_listener(self, listener):
        ''' Add a listener to the history

        The listening class can implement the following methods:
        * ``extend``
        '''
        self.listeners.append(listener)

    def emit(self, event_name, *args, **kwargs):
        '''
        Call the method named ``event_name`` on all listeners providing it.
        '''
        for listener in self.listeners:
            slot = getattr(listener, event_name, None)
            if callable(slot):
                slot(*args, **kwargs)

    @property
    def attribute_names(self):
        '''
        Names of the files in ``<path>/attributes`` matching
        ``StringID.regex``; empty if the directory does not exist.
        '''
        apath = op.join(self.path, 'attributes')
        if not os.path.exists(apath):
            return []

        return [fn for fn in os.listdir(apath)
                if StringID.regex.match(fn)]

    def get_attribute(self, name):
        '''
        Get an integer attribute array with one entry per model.

        The attribute is read from ``<path>/attributes/<name>`` (little-endian
        32 bit integers) on first access and cached afterwards.

        :raises: :py:exc:`NoSuchAttribute` if no such attribute file exists
        '''
        if name not in self._attributes:
            if name not in self.attribute_names:
                raise NoSuchAttribute(name)

            path = op.join(self.path, 'attributes', name)

            with open(path, 'rb') as f:
                self._attributes[name] = num.fromfile(
                    f, dtype='<i4',
                    count=self.nmodels).astype(int)

        assert self._attributes[name].shape == (self.nmodels,)

        return self._attributes[name]

    def set_attribute(self, name, attribute):
        '''
        Store an integer attribute array with one entry per model.

        The attribute is written to ``<path>/attributes/<name>`` as
        little-endian 32 bit integers and cached.

        :raises: :py:exc:`InvalidAttributeName` if ``name`` does not match
            ``StringID.regex``
        '''
        if not StringID.regex.match(name):
            raise InvalidAttributeName(name)

        attribute = attribute.astype(int)
        assert attribute.shape == (self.nmodels,)

        apath = op.join(self.path, 'attributes')

        if not os.path.exists(apath):
            os.mkdir(apath)

        path = op.join(apath, name)

        with open(path, 'wb') as f:
            attribute.astype('<i4').tofile(f)

        self._attributes[name] = attribute

    def ensure_bootstrap_misfits(self, optimiser):
        '''
        Compute bootstrap misfits from the misfits if they are not available.

        Weights and residuals are obtained from the optimiser.
        '''
        if self.bootstrap_misfits is None:
            problem = self.problem
            self.bootstrap_misfits = problem.combine_misfits(
                self.misfits,
                extra_weights=optimiser.get_bootstrap_weights(problem),
                extra_residuals=optimiser.get_bootstrap_residuals(problem))

    def imodels_by_cluster(self, cluster_attribute):
        '''
        Get model indices grouped by the values of a cluster attribute.

        :param cluster_attribute: name of the attribute, or ``None`` to get
            a single group with id ``-1`` holding all models
        :returns: list of tuples ``(icluster, percentage, imodels)``; a
            leading cluster with id ``-1`` is moved to the end. The list is
            empty if the attribute is not set.
        '''
        if cluster_attribute is None:
            return [(-1, 100.0, num.arange(self.nmodels))]

        by_cluster = []
        try:
            iclusters = self.get_attribute(cluster_attribute)
            iclusters_avail = num.unique(iclusters)

            for icluster in iclusters_avail:
                imodels = num.where(iclusters == icluster)[0]
                by_cluster.append(
                    (icluster,
                     (100.0 * imodels.size) / self.nmodels,
                     imodels))

            if by_cluster and by_cluster[0][0] == -1:
                by_cluster.append(by_cluster.pop(0))

        except NoSuchAttribute:
            logger.warning(
                'Attribute %s not set in run %s.\n'
                ' Skipping model retrieval by clusters.' % (
                    cluster_attribute, self.problem.name))

        return by_cluster

    def models_by_cluster(self, cluster_attribute):
        '''
        Like :py:meth:`imodels_by_cluster` but with models in place of
        model indices.
        '''
        if cluster_attribute is None:
            return [(-1, 100.0, self.models)]

        return [
            (icluster, percentage, self.models[imodels])
            for (icluster, percentage, imodels)
            in self.imodels_by_cluster(cluster_attribute)]

    def mean_sources_by_cluster(self, cluster_attribute):
        '''
        Like :py:meth:`models_by_cluster` but with the mean source of each
        cluster, as given by ``stats.get_mean_source``, in place of the models.
        '''
        return [
            (icluster, percentage, stats.get_mean_source(self.problem, models))
            for (icluster, percentage, models)
            in self.models_by_cluster(cluster_attribute)]

    def get_sorted_misfits_idx(self, chain=0):
        '''
        Get model indices sorted by ascending bootstrap misfit of ``chain``.

        The result is cached until the history is extended.
        '''
        if chain not in self._sorted_misfit_idx.keys():
            self._sorted_misfit_idx[chain] = num.argsort(
                self.bootstrap_misfits[:, chain])

        return self._sorted_misfit_idx[chain]

    def get_sorted_misfits(self, chain=0):
        isort = self.get_sorted_misfits_idx(chain)
        return self.bootstrap_misfits[:, chain][isort]

    def get_sorted_models(self, chain=0):
        '''
        Get models sorted by ascending bootstrap misfit of ``chain``.
        '''
        # Sort by the requested chain, not unconditionally by chain 0.
        isort = self.get_sorted_misfits_idx(chain)
        return self.models[isort, :]

    def get_sorted_primary_misfits(self):
        return self.get_sorted_misfits(chain=0)

    def get_sorted_primary_models(self):
        return self.get_sorted_models(chain=0)

    def get_best_model(self, chain=0):
        return self.get_sorted_models(chain)[0, ...]

    def get_best_misfit(self, chain=0):
        return self.get_sorted_misfits(chain)[0]

    def get_mean_model(self):
        return num.mean(self.models, axis=0)

    def get_mean_misfit(self, chain=0):
        return num.mean(self.bootstrap_misfits[:, chain])

    def get_best_source(self, chain=0):
        return self.problem.get_source(self.get_best_model(chain))

    def get_mean_source(self, chain=0):
        # ``chain`` is not used: the mean model is taken over all models.
        return self.problem.get_source(self.get_mean_model())

    def get_chain_misfits(self, chain=0):
        return self.bootstrap_misfits[:, chain]

    def get_primary_chain_misfits(self):
        return self.get_chain_misfits(chain=0)
def get_nmodels(dirname, problem):
    '''
    Get the number of complete models stored in a rundir.

    The number is derived from the sizes of the binary files ``models``
    (``problem.nparameters`` 8 byte values per model) and ``misfits``
    (``problem.nmisfits * 2`` 8 byte values per model). The smaller of the
    two counts is returned, so a model is only counted once both files hold
    it completely.

    :param dirname: path of the rundir
    :param problem: object providing ``nparameters`` and ``nmisfits``
    :returns: number of models as :class:`int`
    :raises: :py:exc:`OSError` if one of the files cannot be accessed
    '''
    # Only the sizes are needed, so there is no reason to open the files.
    nbytes_models = os.stat(op.join(dirname, 'models')).st_size
    nmodels1 = nbytes_models // (problem.nparameters * 8)

    nbytes_misfits = os.stat(op.join(dirname, 'misfits')).st_size
    nmodels2 = nbytes_misfits // (problem.nmisfits * 2 * 8)

    return min(nmodels1, nmodels2)
def load_problem_info_and_data(dirname, subset=None, nchains=None):
    '''
    Load the problem definition together with its model data.

    The problem is read from ``dirname``, the data from the directory given
    by ``xjoin(dirname, subset)``.

    :returns: tuple ``(problem, models, misfits, bootstraps,
        sampler_contexts)``
    '''
    problem = load_problem_info(dirname)
    data_dirname = xjoin(dirname, subset)

    data = load_problem_data(data_dirname, problem, nchains=nchains)
    models, misfits, bootstraps, sampler_contexts = data

    return problem, models, misfits, bootstraps, sampler_contexts
def load_optimiser_info(dirname):
    '''
    Load the object stored in ``<dirname>/optimiser.yaml`` with guts.
    '''
    return guts.load(filename=op.join(dirname, 'optimiser.yaml'))
def load_problem_info(dirname):
    '''
    Load the object stored in ``<dirname>/problem.yaml`` with guts.

    :raises: :py:exc:`ProblemInfoNotAvailable` if loading fails with an
        :py:exc:`OSError`; the original error is logged at debug level
    '''
    fn = op.join(dirname, 'problem.yaml')
    try:
        return guts.load(filename=fn)
    except OSError as e:
        logger.debug(e)
        raise ProblemInfoNotAvailable(
            'No problem info available (%s).' % dirname)
def load_problem_data(dirname, problem, nmodels_skip=0, nchains=None):
    '''
    Load model data from the binary files of a rundir.

    The files ``models``, ``misfits`` and the chains file (``bootstraps`` or
    ``chains``) hold little-endian 64 bit floats, the file ``choices`` holds
    four little-endian 64 bit integers per model.

    :param dirname: path of the rundir
    :param problem: object providing ``nparameters`` and ``nmisfits``
    :param nmodels_skip: number of models to skip at the beginning of the
        files
    :param nchains: number of bootstrap chains; if ``None``, no chains are
        loaded
    :returns: tuple ``(models, misfits, chains, sampler_contexts)`` with
        ``models[imodel, iparameter]``, ``misfits[imodel, imisfit, 0:2]``,
        ``chains[imodel, ichain]`` and ``sampler_contexts[imodel, 0:4]``;
        the last two are ``None`` if not available
    :raises: :py:exc:`ProblemDataNotAvailable` if a file cannot be accessed
    '''

    def get_chains_fn():
        for fn in (op.join(dirname, 'bootstraps'),
                   op.join(dirname, 'chains')):
            if op.exists(fn):
                return fn
        return False

    # The files are opened in binary mode: they hold raw numbers and are
    # seeked to byte offsets, which is not meaningful on text streams.
    try:
        nmodels = get_nmodels(dirname, problem) - nmodels_skip

        fn = op.join(dirname, 'models')
        with open(fn, 'rb') as f:
            f.seek(nmodels_skip * problem.nparameters * 8)
            models = num.fromfile(
                f, dtype='<f8',
                count=nmodels * problem.nparameters)\
                .astype(float)

        models = models.reshape((nmodels, problem.nparameters))

        fn = op.join(dirname, 'misfits')
        with open(fn, 'rb') as f:
            f.seek(nmodels_skip * problem.nmisfits * 2 * 8)
            misfits = num.fromfile(
                f, dtype='<f8',
                count=nmodels*problem.nmisfits*2)\
                .astype(float)
        misfits = misfits.reshape((nmodels, problem.nmisfits, 2))

        chains = None
        fn = get_chains_fn()
        if fn and nchains is not None:
            with open(fn, 'rb') as f:
                f.seek(nmodels_skip * nchains * 8)
                chains = num.fromfile(
                    f, dtype='<f8',
                    count=nmodels*nchains)\
                    .astype(float)

            chains = chains.reshape((nmodels, nchains))

        sampler_contexts = None
        fn = op.join(dirname, 'choices')
        if op.exists(fn):
            with open(fn, 'rb') as f:
                f.seek(nmodels_skip * 4 * 8)
                sampler_contexts = num.fromfile(
                    f, dtype='<i8',
                    count=nmodels*4).astype(int)

            sampler_contexts = sampler_contexts.reshape((nmodels, 4))

    except OSError as e:
        logger.debug(str(e))
        raise ProblemDataNotAvailable(
            'No problem data available (%s).' % dirname)

    return models, misfits, chains, sampler_contexts
# Public names of this module.
__all__ = [
    'ProblemConfig',
    'Problem',
    'ModelHistory',
    'ProblemInfoNotAvailable',
    'ProblemDataNotAvailable',
    'load_problem_info',
    'load_problem_info_and_data',
    'InvalidAttributeName',
    'NoSuchAttribute',
]