Coverage for /usr/local/lib/python3.11/dist-packages/grond/problems/base.py: 79%
757 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-25 10:12 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-25 10:12 +0000
1'''
2Base classes for Grond's problem definition and the model history container.
4Common behaviour of all source models offered by Grond is implemented here.
5Source model specific details are implemented in the respective submodules.
6'''
8import numpy as num
9import math
10import copy
11import logging
12import os.path as op
13import os
14import time
15import struct
16import threading
18from pyrocko import gf, util, guts, orthodrome as pod
19from pyrocko.guts import Object, String, List, Dict, Int
21from grond.meta import ADict, Parameter, GrondError, xjoin, Forbidden, \
22 StringID, has_get_plot_classes
23from ..targets import MisfitResult, MisfitTarget, TargetGroup, \
24 WaveformMisfitTarget, SatelliteMisfitTarget, GNSSCampaignMisfitTarget
26from grond import stats
28from grond.version import __version__
# Prefix handed to pyrocko.guts for the type tags of the classes defined in
# this module (presumably yielding ``!grond.<ClassName>`` YAML tags).
guts_prefix = 'grond'

logger = logging.getLogger('grond.problems.base')

# One kilometre in metres, plus keyword arguments for presenting a length
# parameter in kilometres.
km = 1e3
as_km = {'scale_factor': km, 'scale_unit': 'km'}

# Module-wide random state; Problem.get_random_model draws from it.
g_rstate = num.random.RandomState()
def nextpow2(i):
    '''
    Return the smallest power of two that is greater than or equal to *i*.

    Integer bit arithmetic is used instead of floating point logarithms:
    ``math.log(2**29) / math.log(2.)`` evaluates to slightly more than 29,
    so a log-based formula overshoots some exact powers of two by a factor
    of two. Inputs ``<= 1``, including 0 where the logarithm is undefined,
    yield 1.

    :param i: number to round up; non-integer values are rounded up first
    :returns: power of two as :py:class:`int`
    '''
    n = int(math.ceil(i))
    if n <= 1:
        return 1

    return 1 << (n - 1).bit_length()
def nextcapacity(i):
    '''
    Round *i* up to the next multiple of 1024.

    :returns: the rounded value as :py:class:`int`
    '''
    nblocks = math.ceil(i / 1024)
    return int(nblocks * 1024)
def correlated_weights(values, weight_matrix):
    '''
    Apply a correlated weight matrix to values.

    Computes the matrix product ``values @ weight_matrix``. The resulting
    weighted values still have to be squared! See
    :meth:`Problem.combine_misfits` for how the result is used.

    :param values: misfits or norms as :class:`numpy.ndarray`; the last axis
        must match the first axis of *weight_matrix*
    :param weight_matrix: weight matrix, commonly the inverse of a covariance
        matrix
    :returns: weighted values as :class:`numpy.ndarray`
    '''
    return num.matmul(values, weight_matrix)
class ProblemConfig(Object):
    '''
    Base class for the config section that sets up the objective function.

    Acts as a factory for :py:class:`Problem` objects; concrete subclasses
    implement :py:meth:`get_problem`.
    '''

    name_template = String.T()
    norm_exponent = Int.T(default=2)
    nthreads = Int.T(default=1)

    def get_problem(self, event, target_groups, targets):
        '''
        Instantiate the problem for a given event and set of targets.

        :param event: event the problem is built for
        :param target_groups: target group configurations
        :param targets: misfit targets
        :returns: :py:class:`Problem` object
        :raises NotImplementedError: always, in this base class
        '''
        raise NotImplementedError
@has_get_plot_classes
class Problem(Object):
    '''
    Base class for objective function setup.

    Defines the *problem* to be solved by the optimiser.

    Subclasses provide ``problem_parameters`` (and optionally
    ``problem_waveform_parameters``) and implement :py:meth:`get_source` and
    :py:meth:`pack`.
    '''

    name = String.T()
    ranges = Dict.T(String.T(), gf.Range.T())
    dependants = List.T(Parameter.T())
    norm_exponent = Int.T(default=2)
    base_source = gf.Source.T(optional=True)
    targets = List.T(MisfitTarget.T())
    target_groups = List.T(TargetGroup.T())
    grond_version = String.T(optional=True)
    nthreads = Int.T(default=1)

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)

        if self.grond_version is None:
            self.grond_version = __version__

        # Lazily filled caches / helpers, see the respective getters.
        self._target_weights = None
        self._engine = None
        self._family_mask = None
        self._rstate_manager = None

        if hasattr(self, 'problem_waveform_parameters') and self.has_waveforms:
            self.problem_parameters = \
                self.problem_parameters + self.problem_waveform_parameters

        # Drop optional parameters for which no range is configured. A new
        # list is built on purpose: without waveform targets,
        # self.problem_parameters is still the class-level list shared by all
        # instances, and removing items from it in place would silently drop
        # those parameters for every problem instantiated later.
        self.problem_parameters = [
            p for p in self.problem_parameters
            if not (p.optional and p._name not in self.ranges)]

        self.check()

    @classmethod
    def get_plot_classes(cls):
        from . import plot
        return plot.get_plot_classes()

    def check(self):
        '''
        Validate the target groups.

        :raises ValueError: if two target groups (other than those with path
            ``'all'``) share the same ``path``
        '''
        paths = set()
        for grp in self.target_groups:
            if grp.path == 'all':
                continue
            if grp.path in paths:
                raise ValueError('Path %s defined more than once! In %s'
                                 % (grp.path, grp.__class__.__name__))
            paths.add(grp.path)
        logger.debug('TargetGroup check OK.')

    def copy(self):
        '''
        Return a shallow copy of the problem with the target weight cache
        reset.
        '''
        o = copy.copy(self)
        o._target_weights = None
        return o

    def set_target_parameter_values(self, x):
        '''
        Hand the target-specific part of model vector *x* to the targets.

        Target parameters follow the problem parameters in *x*, in the order
        of :py:attr:`targets`.
        '''
        nprob = len(self.problem_parameters)
        for target in self.targets:
            target.set_parameter_values(x[nprob:nprob + target.nparameters])
            nprob += target.nparameters

    def get_parameter_dict(self, model, group=None):
        '''
        Map parameter names to the values in *model*.

        :param group: if given, only parameters belonging to this group are
            included
        :returns: :py:class:`grond.meta.ADict`
        '''
        params = [(p.name, model[ip])
                  for ip, p in enumerate(self.parameters)
                  if group in p.groups or group is None]
        return ADict(params)

    def get_parameter_array(self, d):
        '''
        Build a model vector from a name-to-value mapping; parameters missing
        from *d* are set to zero.
        '''
        arr = num.zeros(self.nparameters, dtype=float)
        for ip, p in enumerate(self.parameters):
            if p.name in d.keys():
                arr[ip] = d[p.name]
        return arr

    def get_parameter_index(self, param_name):
        '''
        Position of *param_name* in :py:attr:`parameters`.

        :raises KeyError: if there is no such parameter
        '''
        return {k.name: ik for ik, k in enumerate(self.parameters)}[param_name]

    def get_rstate_manager(self):
        '''Return the :py:class:`RandomStateManager`, creating it on demand.'''
        if self._rstate_manager is None:
            self._rstate_manager = RandomStateManager()
        return self._rstate_manager

    def dump_problem_info(self, dirname):
        '''Write the problem description to ``<dirname>/problem.yaml``.'''
        fn = op.join(dirname, 'problem.yaml')
        util.ensuredirs(fn)
        guts.dump(self, filename=fn)

    def dump_problem_data(
            self, dirname, x, misfits, chains=None,
            sampler_context=None):
        '''
        Append one model and its misfits (and optional chain misfits and
        sampler context) to the binary files in *dirname*, and save the
        current random states to ``<dirname>/rstate``.

        Floating point data is stored as little-endian float64
        (``'<f8'``), sampler contexts as little-endian int64 (``'<i8'``).
        '''
        x = num.asarray(x)

        with open(op.join(dirname, 'models'), 'ab') as f:
            x.astype('<f8').tofile(f)

        with open(op.join(dirname, 'misfits'), 'ab') as f:
            misfits.astype('<f8').tofile(f)

        if chains is not None:
            with open(op.join(dirname, 'chains'), 'ab') as f:
                chains.astype('<f8').tofile(f)

        if sampler_context is not None:
            with open(op.join(dirname, 'choices'), 'ab') as f:
                num.array(sampler_context, dtype='<i8').tofile(f)

        self.get_rstate_manager().save_state(op.join(dirname, 'rstate'))

    def name_to_index(self, name):
        '''Position of *name* among the combined parameters and dependants.'''
        pnames = [p.name for p in self.combined]
        return pnames.index(name)

    @property
    def parameters(self):
        '''Problem parameters followed by the parameters of all targets.'''
        target_parameters = []
        for target in self.targets:
            target_parameters.extend(target.target_parameters)
        return self.problem_parameters + target_parameters

    @property
    def parameter_names(self):
        '''Names of parameters and dependants (see :py:attr:`combined`).'''
        return [p.name for p in self.combined]

    @property
    def dependant_names(self):
        return [p.name for p in self.dependants]

    @property
    def nparameters(self):
        return len(self.parameters)

    @property
    def ntargets(self):
        return len(self.targets)

    @property
    def nwaveform_targets(self):
        return len(self.waveform_targets)

    @property
    def nsatellite_targets(self):
        return len(self.satellite_targets)

    @property
    def ngnss_targets(self):
        return len(self.gnss_targets)

    @property
    def nmisfits(self):
        '''Total number of misfit contributions over all targets.'''
        return sum(target.nmisfits for target in self.targets)

    @property
    def ndependants(self):
        return len(self.dependants)

    @property
    def ncombined(self):
        return len(self.parameters) + len(self.dependants)

    @property
    def combined(self):
        '''Parameters followed by dependants.'''
        return self.parameters + self.dependants

    @property
    def satellite_targets(self):
        return [t for t in self.targets
                if isinstance(t, SatelliteMisfitTarget)]

    @property
    def gnss_targets(self):
        return [t for t in self.targets
                if isinstance(t, GNSSCampaignMisfitTarget)]

    @property
    def waveform_targets(self):
        return [t for t in self.targets
                if isinstance(t, WaveformMisfitTarget)]

    @property
    def has_satellite(self):
        return bool(self.satellite_targets)

    @property
    def has_waveforms(self):
        return bool(self.waveform_targets)

    def set_engine(self, engine):
        self._engine = engine

    def get_engine(self):
        return self._engine

    def get_source(self, x):
        raise NotImplementedError

    def pack(self, source):
        raise NotImplementedError

    def source_to_x(self, source):
        '''
        Convert *source* into a model vector relative to
        :py:attr:`base_source`.

        *source* is modified in place: its location is re-expressed as
        north/east shifts from the base source. If its time lies outside the
        configured ``time`` range, a random time within that range is drawn
        from the ``'source2x'`` random state.
        '''
        bs = self.base_source

        n, e = pod.latlon_to_ne(
            bs.lat, bs.lon,
            source.effective_lat, source.effective_lon)

        source.lat, source.lon = bs.lat, bs.lon
        source.north_shift = n
        source.east_shift = e

        tmin, tmax = self.ranges['time'].start, self.ranges['time'].stop

        dt = source.time - bs.time
        if dt < tmin or dt > tmax:
            rstate = self.get_rstate_manager().get_rstate(name='source2x')
            # Draw a scalar (same single draw as size=1) so that the time
            # stays a plain float instead of a one-element array.
            source.time = bs.time + rstate.uniform(low=tmin, high=tmax)

        p = {}
        for k in self.base_source.keys():
            if k == 'time':
                p[k] = float(source[k])

            elif k in self.ranges:
                rng = self.ranges[k]
                val = source[k]
                if rng.relative == 'add':
                    val = rng.make_relative(-self.base_source[k], source[k])
                elif rng.relative == 'mult':
                    val = rng.make_relative(
                        1./self.base_source[k], source[k])

                p[k] = float(val)

        source = source.clone(**p)

        return self.pack(source)

    def get_gf_store_ids(self):
        '''Distinct GF store IDs used by the targets.'''
        return tuple({t.store_id for t in self.targets})

    def get_gf_store(self, target):
        '''
        Get the GF store for *target*, either a target or a store ID string.

        :raises GrondError: if no engine has been set
        '''
        if not isinstance(target, str):
            target = target.store_id
        if self.get_engine() is None:
            raise GrondError('Cannot get GF Store, modelling is not set up.')
        return self.get_engine().get_store(target)

    def _reject_fixed_magnitude(self, fixed_magnitude):
        '''Raise :py:exc:`GrondError` if a fixed magnitude is requested.'''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

    def random_uniform(self, xbounds, rstate, fixed_magnitude=None):
        '''
        Draw a model uniformly within *xbounds* (array of ``(min, max)``
        rows, one per parameter).
        '''
        self._reject_fixed_magnitude(fixed_magnitude)

        x = rstate.uniform(0., 1., self.nparameters)
        x *= (xbounds[:, 1] - xbounds[:, 0])
        x += xbounds[:, 0]
        return x

    def minimum(self, xbounds, fixed_magnitude=None):
        '''Lower bounds column of *xbounds* (a view, not a copy).'''
        self._reject_fixed_magnitude(fixed_magnitude)
        return xbounds[:, 0]

    def maximum(self, xbounds, fixed_magnitude=None):
        '''Upper bounds column of *xbounds* (a view, not a copy).'''
        self._reject_fixed_magnitude(fixed_magnitude)
        return xbounds[:, 1]

    def preconstrain(self, x, optimizer=None):
        return x

    def extract(self, xs, i):
        '''
        Values of combined parameter *i* for all models in *xs*.

        Indices beyond the parameters address dependants, which are computed
        with ``make_dependant``. A 1D *xs* is treated as a single model.
        '''
        if xs.ndim == 1:
            return self.extract(xs[num.newaxis, :], i)[0]

        if i < self.nparameters:
            return xs[:, i]

        return self.make_dependant(
            xs, self.dependants[i - self.nparameters].name)

    def get_target_weights(self):
        '''Concatenated combined weights of all targets (cached).'''
        if self._target_weights is None:
            self._target_weights = num.concatenate(
                [target.get_combined_weight() for target in self.targets])

        return self._target_weights

    def get_target_residuals(self):
        '''Placeholder; does nothing and returns None in this base class.'''
        pass

    def inter_family_weights(self, ns):
        '''
        :param ns: 1D array with one normalisation value per misfit
        :returns: 1D array of weights ``1 / root(nansum(exp(ns_family)))``,
            computed per normalisation family
        '''
        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()

        ws = num.zeros(self.nmisfits)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[mask] = 1.0 / root(num.nansum(exp(ns[mask])))

        return ws

    def inter_family_weights2(self, ns):
        '''
        :param ns: 2D array with normalization factors ``ns[imodel, itarget]``
        :returns: 2D array ``weights[imodel, itarget]``
        '''
        exp, root = self.get_norm_functions()
        family, nfamilies = self.get_family_mask()

        ws = num.zeros(ns.shape)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[:, mask] = (1.0 / root(
                num.nansum(exp(ns[:, mask]), axis=1)))[:, num.newaxis]

        return ws

    def get_reference_model(self):
        '''
        Model vector of :py:attr:`base_source`; entries after the packed
        source parameters are zero.
        '''
        model = num.zeros(self.nparameters)
        model_source_params = self.pack(self.base_source)
        model[:model_source_params.size] = model_source_params
        return model

    def get_parameter_bounds(self):
        '''
        Array of ``(start, stop)`` rows: problem parameters first, then each
        target's parameters.
        '''
        out = []
        for p in self.problem_parameters:
            r = self.ranges[p.name]
            out.append((r.start, r.stop))

        for target in self.targets:
            for p in target.target_parameters:
                r = target.target_ranges[p.name_nogroups]
                out.append((r.start, r.stop))

        return num.array(out, dtype=float)

    def get_dependant_bounds(self):
        return num.zeros((0, 2))

    def get_combined_bounds(self):
        return num.vstack((
            self.get_parameter_bounds(),
            self.get_dependant_bounds()))

    def raise_invalid_norm_exponent(self):
        raise GrondError('Invalid norm exponent: %f' % self.norm_exponent)

    def get_norm_functions(self):
        '''
        Return the ``(exp, root)`` function pair for :py:attr:`norm_exponent`:
        ``(x**2, sqrt)`` for 2 and ``(identity, abs)`` for 1.

        :raises GrondError: for any other exponent
        '''
        if self.norm_exponent == 2:
            def sqr(x):
                return x**2

            return sqr, num.sqrt

        elif self.norm_exponent == 1:
            def noop(x):
                return x

            return noop, num.abs

        else:
            self.raise_invalid_norm_exponent()

    def combine_misfits(
            self, misfits,
            extra_weights=None,
            extra_residuals=None,
            extra_correlated_weights=None,
            get_contributions=False):
        '''
        Combine misfit contributions (residuals) to global or bootstrap misfits

        :param misfits: 3D array ``misfits[imodel, iresidual, 0]`` are the
            misfit contributions (residuals) ``misfits[imodel, iresidual, 1]``
            are the normalisation contributions. It is also possible to give
            the misfit and normalisation contributions for a single model as
            ``misfits[iresidual, 0]`` and ``misfits[iresidual, 1]`` in which
            case, the first dimension (imodel) of the result will be stripped
            off.

        :param extra_weights: if given, 2D array of extra weights to be applied
            to the contributions, indexed as
            ``extra_weights[ibootstrap, iresidual]``.

        :param extra_residuals: if given, 2D array of perturbations to be added
            to the residuals, indexed as
            ``extra_residuals[ibootstrap, iresidual]``.

        :param extra_correlated_weights: if a dictionary of
            ``imisfit: correlated weight matrix`` is passed a correlated
            weight matrix is applied to the misfit and normalisation values.
            `imisfit` is the starting index in the misfits vector the
            correlated weight matrix applies to. ``None`` (the default) means
            no correlated weights.

        :param get_contributions: get the weighted and perturbed contributions
            (don't do the sum).

        :returns: if no *extra_weights* or *extra_residuals* are given, a 1D
            array indexed as ``misfits[imodel]`` containing the global misfit
            for each model is returned, otherwise a 2D array
            ``misfits[imodel, ibootstrap]`` with the misfit for every model and
            weighting/residual set is returned.

        :raises GrondError: if correlated weights are combined with a norm
            exponent other than 2
        '''
        # None instead of a shared mutable ``dict()`` default.
        if extra_correlated_weights is None:
            extra_correlated_weights = {}

        if misfits.ndim == 2:
            misfits = misfits[num.newaxis, :, :]
            return self.combine_misfits(
                misfits, extra_weights, extra_residuals,
                extra_correlated_weights, get_contributions)[0, ...]

        if extra_weights is None and extra_residuals is None:
            return self.combine_misfits(
                misfits, False, False,
                extra_correlated_weights, get_contributions)[:, 0]

        assert misfits.ndim == 3
        assert not num.any(extra_weights) or extra_weights.ndim == 2
        assert not num.any(extra_residuals) or extra_residuals.ndim == 2

        if self.norm_exponent != 2 and extra_correlated_weights:
            raise GrondError('Correlated weights can only be used '
                             ' with norm_exponent=2')

        exp, root = self.get_norm_functions()

        nmodels = misfits.shape[0]

        # mf[imodel, ibootstrap, imisfit, 0|1]; a copy, so the in-place
        # updates below do not touch the caller's array.
        mf = misfits[:, num.newaxis, :, :].copy()

        if num.any(extra_residuals):
            # NOTE(review): the perturbation is broadcast onto both the
            # residual and the normalisation column - confirm intended.
            mf = mf + extra_residuals[num.newaxis, :, :, num.newaxis]

        res = mf[..., 0]
        norms = mf[..., 1]

        for imisfit, corr_weight_mat in extra_correlated_weights.items():

            jmisfit = imisfit + corr_weight_mat.shape[0]

            for imodel in range(nmodels):
                corr_res = res[imodel, :, imisfit:jmisfit]
                corr_norms = norms[imodel, :, imisfit:jmisfit]

                res[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_res, corr_weight_mat)

                norms[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_norms, corr_weight_mat)

        # Apply normalization family weights (these weights depend on
        # the just calculated correlated norms!)
        weights_fam = \
            self.inter_family_weights2(norms[:, 0, :])[:, num.newaxis, :]

        weights_fam = exp(weights_fam)

        res = exp(res)
        norms = exp(norms)

        res *= weights_fam
        norms *= weights_fam

        weights_tar = self.get_target_weights()[num.newaxis, num.newaxis, :]
        if num.any(extra_weights):
            weights_tar = weights_tar * extra_weights[num.newaxis, :, :]

        weights_tar = exp(weights_tar)

        res = res * weights_tar
        norms = norms * weights_tar

        if get_contributions:
            return res / num.nansum(norms, axis=2)[:, :, num.newaxis]

        result = root(
            num.nansum(res, axis=2) /
            num.nansum(norms, axis=2))

        assert result[result < 0].size == 0
        return result

    def make_family_mask(self):
        '''
        Assign a family index to every misfit.

        Indices follow the order in which new normalisation families appear
        among the targets.

        :returns: tuple ``(families, nfamilies)``
        '''
        family_names = set()
        families = num.zeros(self.nmisfits, dtype=int)

        idx = 0
        for target in self.targets:
            family_names.add(target.normalisation_family)
            families[idx:idx + target.nmisfits] = len(family_names) - 1
            idx += target.nmisfits

        return families, len(family_names)

    def get_family_mask(self):
        '''Cached result of :py:meth:`make_family_mask`.'''
        if self._family_mask is None:
            self._family_mask = self.make_family_mask()

        return self._family_mask

    def evaluate(self, x, mask=None, result_mode='full', targets=None,
                 nthreads=1):
        '''
        Forward-model *x* and compute one result per target.

        :param x: model vector
        :param mask: optional per-target boolean sequence. Targets with a
            false entry are not modelled and get a
            :py:class:`pyrocko.gf.SeismosizerError` as result. Cannot be
            combined with *targets*.
        :param result_mode: passed to each target's ``set_result_mode``
        :param targets: targets to evaluate, defaults to :py:attr:`targets`
        :param nthreads: passed to ``engine.process``
        :returns: list with one result per target
        :raises ValueError: if both *mask* and *targets* are given
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        self.set_target_parameter_values(x)

        if mask is not None and targets is not None:
            raise ValueError('Mask cannot be defined with targets set.')
        targets = targets if targets is not None else self.targets

        for target in targets:
            target.set_result_mode(result_mode)

        modelling_targets = []
        t2m_map = {}
        for itarget, target in enumerate(targets):
            t2m_map[target] = target.prepare_modelling(engine, source, targets)
            if mask is None or mask[itarget]:
                modelling_targets.extend(t2m_map[target])

        # Identical modelling targets requested by several targets are
        # processed only once; u2m_map remembers all positions they fill.
        u2m_map = {}
        for imtarget, mtarget in enumerate(modelling_targets):
            u2m_map.setdefault(mtarget, []).append(imtarget)

        modelling_targets_unique = list(u2m_map.keys())
        resp = engine.process(source, modelling_targets_unique,
                              nthreads=nthreads)
        modelling_results_unique = list(resp.results_list[0])
        modelling_results = [None] * len(modelling_targets)

        for mtarget, mresult in zip(
                modelling_targets_unique, modelling_results_unique):

            for imtarget in u2m_map[mtarget]:
                modelling_results[imtarget] = mresult

        imt = 0
        results = []
        for itarget, target in enumerate(targets):
            nmt_this = len(t2m_map[target])
            if mask is None or mask[itarget]:
                result = target.finalize_modelling(
                    engine, source,
                    t2m_map[target],
                    modelling_results[imt:imt + nmt_this])

                imt += nmt_this
            else:
                result = gf.SeismosizerError(
                    'target was excluded from modelling')

            results.append(result)

        return results

    def misfits(self, x, mask=None, nthreads=1, raise_bad=False):
        '''
        Evaluate model *x* and gather the misfit contributions of all targets.

        :param raise_bad: if true, results that are not a
            :py:class:`MisfitResult` are logged as errors (nothing is raised)
        :returns: array of shape ``(nmisfits, 2)``; rows belonging to targets
            without a :py:class:`MisfitResult` are NaN
        '''
        results = self.evaluate(
            x, mask=mask, result_mode='sparse', nthreads=nthreads)
        misfits = num.full((self.nmisfits, 2), num.nan)

        imisfit = 0
        for target, result in zip(self.targets, results):
            if isinstance(result, MisfitResult):
                misfits[imisfit:imisfit + target.nmisfits, :] = result.misfits
            elif raise_bad:
                logger.error(result)

            imisfit += target.nmisfits

        return misfits

    def forward(self, x):
        '''
        Model *x* for the plain targets of all misfit targets.

        :returns: successful modelling results; failures are logged at debug
            level and left out
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        plain_targets = []
        for target in self.targets:
            plain_targets.extend(target.get_plain_targets(engine, source))

        resp = engine.process(source, plain_targets)

        results = []
        for target, result in zip(plain_targets, resp.results_list[0]):
            if isinstance(result, gf.SeismosizerError):
                logger.debug(
                    '%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
            else:
                results.append(result)

        return results

    def _first_admissible_model(self, make_model, ntries_limit):
        '''
        Return ``preconstrain(make_model())`` for the first candidate not
        rejected with :py:exc:`Forbidden`.

        :raises GrondError: if all *ntries_limit* candidates are rejected
        '''
        for _ in range(ntries_limit):
            x = make_model()
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))

    def get_random_model(self, ntries_limit=100):
        '''Random admissible model drawn from the module-wide random state.'''
        xbounds = self.get_parameter_bounds()
        return self._first_admissible_model(
            lambda: self.random_uniform(xbounds, rstate=g_rstate),
            ntries_limit)

    def get_min_model(self, ntries_limit=100):
        '''Model built from the lower parameter bounds.'''
        xbounds = self.get_parameter_bounds()
        return self._first_admissible_model(
            lambda: self.minimum(xbounds), ntries_limit)

    def get_max_model(self, ntries_limit=100):
        '''Model built from the upper parameter bounds.'''
        xbounds = self.get_parameter_bounds()
        return self._first_admissible_model(
            lambda: self.maximum(xbounds), ntries_limit)
class ProblemInfoNotAvailable(GrondError):
    '''Error signalling that the problem description is not available.'''


class ProblemDataNotAvailable(GrondError):
    '''Error raised when a rundir or one of its data files is missing.'''


class NoSuchAttribute(GrondError):
    '''Error raised when a requested model attribute is not stored.'''


class InvalidAttributeName(GrondError):
    '''Error raised when an attribute name is not a valid string ID.'''
class ModelHistory(object):
    '''
    Write, read and follow sequences of models produced in an optimisation run.

    Data are kept in pre-allocated buffers whose capacity is at least
    :py:attr:`nmodels_capacity_min` and grows in powers of two. The public
    attributes :py:attr:`models`, :py:attr:`misfits`,
    :py:attr:`bootstrap_misfits` and :py:attr:`sampler_contexts` are views
    onto the filled part of these buffers.

    :param problem: :class:`grond.Problem` instance
    :param nchains: number of bootstrap chains, defaults to None
    :type nchains: int, optional
    :param path: path to rundir, defaults to None
    :type path: str, optional
    :param mode: open mode, 'r': read, 'w': write
    :type mode: str, optional
    '''

    nmodels_capacity_min = 1024

    def __init__(self, problem, nchains=None, path=None, mode='r'):
        self.mode = mode

        self.problem = problem
        self.path = path
        self.nchains = nchains

        self._models_buffer = None
        self._misfits_buffer = None
        self._bootstraps_buffer = None
        self._sample_contexts_buffer = None

        # chain -> argsort of that chain's bootstrap misfits
        self._sorted_misfit_idx = {}

        self.models = None
        self.misfits = None
        self.bootstrap_misfits = None

        self.sampler_contexts = None

        # Allocates the buffers (see the nmodels_capacity setter).
        self.nmodels_capacity = self.nmodels_capacity_min
        self.listeners = []

        self._attributes = {}

        if mode == 'r':
            self.load()

    @staticmethod
    def verify_rundir(rundir):
        '''
        :raises ProblemDataNotAvailable: if *rundir* or its ``misfits`` or
            ``models`` file does not exist
        '''
        _rundir_files = ('misfits', 'models')

        if not op.exists(rundir):
            raise ProblemDataNotAvailable(
                'Directory does not exist: %s' % rundir)
        for f in _rundir_files:
            if not op.exists(op.join(rundir, f)):
                raise ProblemDataNotAvailable('File not found: %s' % f)

    @classmethod
    def follow(cls, path, nchains=None, wait=20.):
        '''
        Start following a rundir (constructor).

        Polls every 0.25 s until the rundir can be loaded.

        :param path: the path to follow, a grond rundir
        :type path: str, optional
        :param wait: wait time until the folder become alive
        :type wait: number in seconds, optional
        :returns: A :py:class:`ModelHistory` instance, or ``None`` if the
            rundir could not be loaded within *wait* seconds
        '''
        start_watch = time.time()
        while (time.time() - start_watch) < wait:
            try:
                cls.verify_rundir(path)
                problem = load_problem_info(path)
                return cls(problem, nchains=nchains, path=path, mode='r')
            except (ProblemDataNotAvailable, OSError):
                time.sleep(.25)

    @property
    def nmodels(self):
        '''Number of models currently held.'''
        if self.models is None:
            return 0
        else:
            return self.models.shape[0]

    @nmodels.setter
    def nmodels(self, nmodels_new):
        '''Truncate the history to *nmodels_new* models.'''
        assert 0 <= nmodels_new <= self.nmodels
        self.models = self._models_buffer[:nmodels_new, :]
        self.misfits = self._misfits_buffer[:nmodels_new, :, :]
        if self.nchains is not None:
            # The bootstrap buffer is 2D: (capacity, nchains).
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels_new, :]
        if self._sample_contexts_buffer is not None:
            self.sampler_contexts = \
                self._sample_contexts_buffer[:nmodels_new, :]

        # Cached sort orders refer to the old model count.
        self._sorted_misfit_idx.clear()

    @property
    def nmodels_capacity(self):
        '''Number of models the current buffers can hold.'''
        if self._models_buffer is None:
            return 0
        else:
            return self._models_buffer.shape[0]

    @nmodels_capacity.setter
    def nmodels_capacity(self, nmodels_capacity_new):
        '''Reallocate all buffers, keeping as many existing models as fit.'''
        if self.nmodels_capacity != nmodels_capacity_new:

            models_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nparameters),
                dtype=float)
            misfits_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nmisfits, 2),
                dtype=float)
            sample_contexts_buffer = num.zeros(
                (nmodels_capacity_new, 4),
                dtype=int)
            sample_contexts_buffer.fill(-1)

            if self.nchains is not None:
                bootstraps_buffer = num.zeros(
                    (nmodels_capacity_new, self.nchains),
                    dtype=float)

            ncopy = min(self.nmodels, nmodels_capacity_new)

            if self._models_buffer is not None:
                models_buffer[:ncopy, :] = \
                    self._models_buffer[:ncopy, :]
                misfits_buffer[:ncopy, :, :] = \
                    self._misfits_buffer[:ncopy, :, :]
                sample_contexts_buffer[:ncopy, :] = \
                    self._sample_contexts_buffer[:ncopy, :]

            self._models_buffer = models_buffer
            self._misfits_buffer = misfits_buffer
            self._sample_contexts_buffer = sample_contexts_buffer

            if self.nchains is not None:
                if self._bootstraps_buffer is not None:
                    bootstraps_buffer[:ncopy, :] = \
                        self._bootstraps_buffer[:ncopy, :]
                self._bootstraps_buffer = bootstraps_buffer

    def clear(self):
        '''Drop all models and shrink the buffers to the minimum capacity.'''
        assert self.mode != 'r', 'History is read-only, cannot clear.'
        self.nmodels = 0
        self.nmodels_capacity = self.nmodels_capacity_min

    def extend(
            self, models, misfits,
            bootstrap_misfits=None,
            sampler_contexts=None):
        '''
        Append a batch of models.

        In write mode with a path set, each model is also written to the
        rundir via :py:meth:`Problem.dump_problem_data`. Listeners are
        notified through their ``extend`` slot.
        '''
        nmodels = self.nmodels
        n = models.shape[0]

        # max(..., 1) keeps an empty batch on an empty history away from
        # nextpow2(0), which is undefined (log of zero).
        nmodels_capacity_want = max(
            self.nmodels_capacity_min, nextpow2(max(nmodels + n, 1)))

        if nmodels_capacity_want != self.nmodels_capacity:
            self.nmodels_capacity = nmodels_capacity_want

        self._models_buffer[nmodels:nmodels + n, :] = models
        self._misfits_buffer[nmodels:nmodels + n, :, :] = misfits

        self.models = self._models_buffer[:nmodels + n, :]
        self.misfits = self._misfits_buffer[:nmodels + n, :, :]

        if bootstrap_misfits is not None:
            self._bootstraps_buffer[nmodels:nmodels + n, :] = bootstrap_misfits
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels + n, :]

        if sampler_contexts is not None:
            self._sample_contexts_buffer[nmodels:nmodels + n, :] \
                = sampler_contexts
            self.sampler_contexts = self._sample_contexts_buffer[
                :nmodels + n, :]

        if self.path and self.mode == 'w':
            for i in range(n):
                self.problem.dump_problem_data(
                    self.path, models[i, :], misfits[i, :, :],
                    bootstrap_misfits[i, :]
                    if bootstrap_misfits is not None else None,
                    sampler_contexts[i, :]
                    if sampler_contexts is not None else None)

        self._sorted_misfit_idx.clear()
        self.emit('extend', nmodels, n, models, misfits, sampler_contexts)

    def append(
            self, model, misfits,
            bootstrap_misfits=None,
            sampler_context=None):
        '''Append a single model; see :py:meth:`extend`.'''

        if bootstrap_misfits is not None:
            bootstrap_misfits = bootstrap_misfits[num.newaxis, :]

        if sampler_context is not None:
            sampler_context = sampler_context[num.newaxis, :]

        return self.extend(
            model[num.newaxis, :], misfits[num.newaxis, :, :],
            bootstrap_misfits, sampler_context)

    def load(self):
        '''Load all data from :py:attr:`path`; switches to read mode.'''
        self.mode = 'r'
        self.verify_rundir(self.path)
        models, misfits, bootstraps, sampler_contexts = load_problem_data(
            self.path, self.problem, nchains=self.nchains)
        self.extend(models, misfits, bootstraps, sampler_contexts)

    def update(self):
        '''
        Update history from path.

        Best effort: if the new data cannot be read yet (``ValueError``),
        the history is left unchanged.
        '''
        nmodels_available = get_nmodels(self.path, self.problem)
        if self.nmodels == nmodels_available:
            return

        try:
            new_models, new_misfits, new_bootstraps, new_sampler_contexts = \
                load_problem_data(
                    self.path,
                    self.problem,
                    nmodels_skip=self.nmodels,
                    nchains=self.nchains)

        except ValueError:
            return

        self.extend(
            new_models,
            new_misfits,
            new_bootstraps,
            new_sampler_contexts)

    def add_listener(self, listener):
        ''' Add a listener to the history

        The listening class can implement the following methods:
        * ``extend``
        '''
        self.listeners.append(listener)

    def emit(self, event_name, *args, **kwargs):
        '''Call ``listener.<event_name>(...)`` on listeners providing it.'''
        for listener in self.listeners:
            slot = getattr(listener, event_name, None)
            if callable(slot):
                slot(*args, **kwargs)

    @property
    def attribute_names(self):
        '''Names of the attribute files stored in ``<path>/attributes``.'''
        apath = op.join(self.path, 'attributes')
        if not op.exists(apath):
            return []

        return [fn for fn in os.listdir(apath)
                if StringID.regex.match(fn)]

    def get_attribute(self, name):
        '''
        Per-model integer attribute *name* (cached after the first read).

        :raises NoSuchAttribute: if the attribute is not stored
        '''
        if name not in self._attributes:
            if name not in self.attribute_names:
                raise NoSuchAttribute(name)

            path = op.join(self.path, 'attributes', name)

            with open(path, 'rb') as f:
                self._attributes[name] = num.fromfile(
                    f, dtype='<i4',
                    count=self.nmodels).astype(int)

        assert self._attributes[name].shape == (self.nmodels,)

        return self._attributes[name]

    def set_attribute(self, name, attribute):
        '''
        Store a per-model integer attribute as little-endian int32.

        :raises InvalidAttributeName: if *name* is not a valid string ID
        '''
        if not StringID.regex.match(name):
            raise InvalidAttributeName(name)

        attribute = attribute.astype(int)
        assert attribute.shape == (self.nmodels,)

        apath = op.join(self.path, 'attributes')

        # exist_ok avoids a race when the directory is created concurrently.
        os.makedirs(apath, exist_ok=True)

        path = op.join(apath, name)

        with open(path, 'wb') as f:
            attribute.astype('<i4').tofile(f)

        self._attributes[name] = attribute

    def ensure_bootstrap_misfits(self, optimiser):
        '''Compute bootstrap misfits with *optimiser* if not yet available.'''
        if self.bootstrap_misfits is None:
            problem = self.problem
            self.bootstrap_misfits = problem.combine_misfits(
                self.misfits,
                extra_weights=optimiser.get_bootstrap_weights(problem),
                extra_residuals=optimiser.get_bootstrap_residuals(problem))

    def imodels_by_cluster(self, cluster_attribute):
        '''
        Group model indices by the values of *cluster_attribute*.

        :returns: list of ``(icluster, percentage, imodels)``; the
            unclustered group (-1), if present, is moved to the end. Empty if
            the attribute is not stored.
        '''
        if cluster_attribute is None:
            return [(-1, 100.0, num.arange(self.nmodels))]

        by_cluster = []
        try:
            iclusters = self.get_attribute(cluster_attribute)
            iclusters_avail = num.unique(iclusters)

            for icluster in iclusters_avail:
                imodels = num.where(iclusters == icluster)[0]
                by_cluster.append(
                    (icluster,
                     (100.0 * imodels.size) / self.nmodels,
                     imodels))

            if by_cluster and by_cluster[0][0] == -1:
                by_cluster.append(by_cluster.pop(0))

        except NoSuchAttribute:
            logger.warning(
                'Attribute %s not set in run %s.\n'
                '  Skipping model retrieval by clusters.' % (
                    cluster_attribute, self.problem.name))

        return by_cluster

    def models_by_cluster(self, cluster_attribute):
        if cluster_attribute is None:
            return [(-1, 100.0, self.models)]

        return [
            (icluster, percentage, self.models[imodels])
            for (icluster, percentage, imodels)
            in self.imodels_by_cluster(cluster_attribute)]

    def mean_sources_by_cluster(self, cluster_attribute):
        return [
            (icluster, percentage, stats.get_mean_source(self.problem, models))
            for (icluster, percentage, models)
            in self.models_by_cluster(cluster_attribute)]

    def get_sorted_misfits_idx(self, chain=0):
        '''Model indices sorted by misfit in bootstrap *chain* (cached).'''
        if chain not in self._sorted_misfit_idx:
            self._sorted_misfit_idx[chain] = num.argsort(
                self.bootstrap_misfits[:, chain])

        return self._sorted_misfit_idx[chain]

    def get_sorted_misfits(self, chain=0):
        isort = self.get_sorted_misfits_idx(chain)
        return self.bootstrap_misfits[:, chain][isort]

    def get_sorted_models(self, chain=0):
        '''Models sorted by misfit in bootstrap *chain*, best first.'''
        isort = self.get_sorted_misfits_idx(chain=chain)
        return self.models[isort, :]

    def get_sorted_primary_misfits(self):
        return self.get_sorted_misfits(chain=0)

    def get_sorted_primary_models(self):
        return self.get_sorted_models(chain=0)

    def get_best_model(self, chain=0):
        return self.get_sorted_models(chain)[0, ...]

    def get_best_misfit(self, chain=0):
        return self.get_sorted_misfits(chain)[0]

    def get_mean_model(self):
        return num.mean(self.models, axis=0)

    def get_mean_misfit(self, chain=0):
        return num.mean(self.bootstrap_misfits[:, chain])

    def get_best_source(self, chain=0):
        return self.problem.get_source(self.get_best_model(chain))

    def get_mean_source(self, chain=0):
        # The mean model does not depend on the chain.
        return self.problem.get_source(self.get_mean_model())

    def get_chain_misfits(self, chain=0):
        return self.bootstrap_misfits[:, chain]

    def get_primary_chain_misfits(self):
        return self.get_chain_misfits(chain=0)
class RandomStateManager(object):
    '''
    Registry of named :py:class:`numpy.random.RandomState` instances whose
    states can be saved to and restored from a binary file.

    Each state is stored as one fixed-size record (:py:attr:`save_struct`):
    the name (at most :py:attr:`MAX_LEN` bytes, NUL padded), the 7-byte
    generator name, the 624 uint32 key words (2496 bytes), the position, the
    has-gauss flag and the cached Gaussian value.
    '''

    MAX_LEN = 64
    save_struct = struct.Struct('%ds7s2496sqqd' % MAX_LEN)

    def __init__(self):
        self.rstates = {}
        self.lock = threading.Lock()

    def get_rstate(self, name, seed=None):
        '''
        Return the random state registered under *name*, creating it with
        *seed* on first use (*seed* is ignored for existing states).
        '''
        # The name goes into a fixed-width byte field. Check the encoded
        # length, not the character count: struct.pack would silently
        # truncate longer multi-byte names on save.
        assert len(name.encode()) <= self.MAX_LEN

        if name not in self.rstates:
            self.rstates[name] = num.random.RandomState(seed)
        return self.rstates[name]

    @property
    def nstates(self):
        return len(self.rstates)

    def save_state(self, fname):
        '''Write all registered states to *fname*, replacing its contents.'''
        with self.lock:
            with open(fname, 'wb') as f:

                for name, rstate in self.rstates.items():
                    s, arr, pos, has_gauss, cached_gauss = rstate.get_state()
                    f.write(
                        self.save_struct.pack(
                            name.encode(), s.encode(), arr.tobytes(),
                            pos, has_gauss, cached_gauss))

    def load_state(self, fname):
        '''Restore states from *fname*, replacing same-named entries.'''
        with self.lock:
            with open(fname, 'rb') as f:
                while True:
                    buff = f.read(self.save_struct.size)
                    if not buff:
                        break

                    name, s, arr, pos, has_gauss, cached_gauss = \
                        self.save_struct.unpack(buff)

                    # Strip the NUL padding added by the fixed-width fields.
                    name = name.replace(b'\x00', b'').decode()
                    s = s.replace(b'\x00', b'').decode()
                    arr = num.frombuffer(arr, dtype=num.uint32)

                    rstate = num.random.RandomState()
                    rstate.set_state((s, arr, pos, has_gauss, cached_gauss))
                    self.rstates[name] = rstate
def get_nmodels(dirname, problem):
    '''
    Count the complete model records stored in a run directory.

    The count is derived from the sizes of the ``models`` file
    (``problem.nparameters`` 8-byte values per model) and the ``misfits``
    file (``problem.nmisfits * 2`` 8-byte values per model). The smaller of
    the two counts is returned.

    :param dirname: run directory
    :param problem: problem definition providing ``nparameters`` and
        ``nmisfits``
    :returns: number of models present in both files
    '''

    def file_size(basename):
        with open(op.join(dirname, basename), 'r') as f:
            return os.fstat(f.fileno()).st_size

    n_from_models = file_size('models') // (problem.nparameters * 8)
    n_from_misfits = file_size('misfits') // (problem.nmisfits * 2 * 8)
    return min(n_from_models, n_from_misfits)
def load_problem_info_and_data(dirname, subset=None, nchains=None):
    '''
    Load the problem definition and its stored model history.

    :param dirname: run directory containing ``problem.yaml``
    :param subset: optional subdirectory holding the data files, joined to
        ``dirname`` with :func:`grond.meta.xjoin`
    :param nchains: number of bootstrap chains, passed to
        :func:`load_problem_data`
    :returns: tuple ``(problem, models, misfits, bootstraps,
        sampler_contexts)``
    '''
    problem = load_problem_info(dirname)
    data_dirname = xjoin(dirname, subset)

    models, misfits, bootstraps, sampler_contexts = load_problem_data(
        data_dirname, problem, nchains=nchains)

    return problem, models, misfits, bootstraps, sampler_contexts
def load_optimiser_info(dirname):
    '''
    Load the optimiser configuration saved in a run directory.

    :param dirname: run directory containing ``optimiser.yaml``
    :returns: object deserialized by :func:`pyrocko.guts.load`
    '''
    return guts.load(filename=op.join(dirname, 'optimiser.yaml'))
def load_problem_info(dirname):
    '''
    Load the problem definition saved in a run directory.

    :param dirname: run directory containing ``problem.yaml``
    :returns: problem object deserialized by :func:`pyrocko.guts.load`
    :raises ProblemInfoNotAvailable: if ``problem.yaml`` cannot be read
        (any :class:`OSError`)
    '''
    problem_fn = op.join(dirname, 'problem.yaml')
    try:
        return guts.load(filename=problem_fn)

    except OSError as e:
        logger.debug(e)
        raise ProblemInfoNotAvailable(
            'No problem info available (%s).' % dirname)
def load_problem_data(dirname, problem, nmodels_skip=0, nchains=None):
    '''
    Memory-map the model history stored in a run directory.

    The binary files are opened read-only with copy-on-write mapping. The
    returned arrays may be modified in memory, but such changes are never
    written back to the run directory.

    :param dirname: directory containing the ``models`` and ``misfits``
        files, and optionally ``bootstraps`` (or ``chains``), ``choices``
        and ``rstate``
    :param problem: problem definition; ``nparameters`` and ``nmisfits``
        determine the record sizes. If an ``rstate`` file exists, the
        problem's random state manager is restored from it.
    :param nmodels_skip: number of leading models to skip
    :param nchains: number of bootstrap chains; the chains file is only
        read if this is given
    :returns: tuple ``(models, misfits, chains, sampler_contexts)`` with
        shapes ``(nmodels, nparameters)``, ``(nmodels, nmisfits, 2)``,
        ``(nmodels, nchains)`` or ``None``, and ``(nmodels, 4)`` or ``None``
    :raises ProblemDataNotAvailable: if a file cannot be opened
        (any :class:`OSError`)
    '''

    def get_chains_fn():
        # 'bootstraps' takes precedence over 'chains'.
        for fn in (op.join(dirname, 'bootstraps'),
                   op.join(dirname, 'chains')):
            if op.exists(fn):
                return fn
        return False

    # All files are mapped with mode='c' (file opened read-only,
    # copy-on-write in memory). numpy's default mode 'r+' needs write
    # permission, writes in-place modifications of the returned arrays back
    # to the run directory, and extends a file that is shorter than the
    # requested shape (e.g. a chains file that is still being appended to),
    # which can corrupt the files for a concurrent writer. With 'c', such a
    # short file presumably makes numpy/mmap raise ValueError instead.
    try:
        nmodels = get_nmodels(dirname, problem) - nmodels_skip

        fn = op.join(dirname, 'models')
        offset = nmodels_skip * problem.nparameters * 8
        models = num.memmap(
            fn, dtype='<f8', mode='c',
            offset=offset,
            shape=(nmodels, problem.nparameters))\
            .astype(float, copy=False)

        fn = op.join(dirname, 'misfits')
        offset = nmodels_skip * problem.nmisfits * 2 * 8
        misfits = num.memmap(
            fn, dtype='<f8', mode='c',
            offset=offset,
            shape=(nmodels, problem.nmisfits, 2))\
            .astype(float, copy=False)

        chains = None
        fn = get_chains_fn()
        if fn and nchains is not None:
            offset = nmodels_skip * nchains * 8
            chains = num.memmap(
                fn, dtype='<f8', mode='c',
                offset=offset,
                shape=(nmodels, nchains))\
                .astype(float, copy=False)

        sampler_contexts = None
        fn = op.join(dirname, 'choices')
        if op.exists(fn):
            offset = nmodels_skip * 4 * 8
            sampler_contexts = num.memmap(
                fn, dtype='<i8', mode='c',
                offset=offset,
                shape=(nmodels, 4))\
                .astype(int, copy=False)

        fn = op.join(dirname, 'rstate')
        if op.exists(fn):
            problem.get_rstate_manager().load_state(fn)

    except OSError as e:
        logger.debug(str(e))
        raise ProblemDataNotAvailable(
            'No problem data available (%s).' % dirname)

    return models, misfits, chains, sampler_contexts
# Names exported by a star-import of this module.
__all__ = [
    'ProblemConfig',
    'Problem',
    'ModelHistory',
    'ProblemInfoNotAvailable',
    'ProblemDataNotAvailable',
    'load_problem_info',
    'load_problem_info_and_data',
    'InvalidAttributeName',
    'NoSuchAttribute',
]