Coverage for /usr/local/lib/python3.11/dist-packages/grond/problems/base.py: 77%
759 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
'''
Base classes for Grond's problem definition and the model history container.

Common behaviour of all source models offered by Grond is implemented here.
Source model specific details are implemented in the respective submodules.
'''
8import numpy as num
9import math
10import copy
11import logging
12import os.path as op
13import os
14import time
15import struct
16import threading
18from pyrocko import gf, util, guts, orthodrome as pod
19from pyrocko.guts import Object, String, List, Dict, Int
21from grond.meta import ADict, Parameter, GrondError, xjoin, Forbidden, \
22 StringID, has_get_plot_classes
23from ..targets import MisfitResult, MisfitTarget, TargetGroup, \
24 WaveformMisfitTarget, SatelliteMisfitTarget, GNSSCampaignMisfitTarget
26from grond import stats
28from grond.version import __version__
guts_prefix = 'grond'
logger = logging.getLogger('grond.problems.base')

# One kilometre in metres, and keyword arguments describing a km scale
# (``scale_factor`` / ``scale_unit``).
km = 1e3
as_km = {'scale_factor': km, 'scale_unit': 'km'}

# Module-wide random state, used by ``Problem.get_random_model``.
g_rstate = num.random.RandomState()
def nextpow2(i):
    '''
    Smallest power of two greater than or equal to ``i``.

    Integer arithmetic is used so that exact powers of two map onto
    themselves. A logarithm-ratio formula can overshoot by a factor of two
    because of floating point rounding, and fails for ``i == 0``.

    :param i: number of items; non-integer values are rounded up first.
        Values ``<= 1`` (including ``0``) yield ``1``.
    :returns: power of two as :py:class:`int`
    '''
    n = int(math.ceil(i))
    if n <= 1:
        return 1

    return 1 << (n - 1).bit_length()
def nextcapacity(i):
    '''
    Round ``i`` up to the next multiple of 1024.

    :returns: capacity as :py:class:`int`
    '''
    nblocks = math.ceil(i / 1024)
    return int(nblocks * 1024)
def correlated_weights(values, weight_matrix):
    '''
    Apply a correlated weight matrix to values.

    The resulting weighted values have to be squared! Check out
    :meth:`Problem.combine_misfits` for more information.

    :param values: misfits or norms as :class:`numpy.ndarray`; the last axis
        is contracted with the first axis of ``weight_matrix``
    :param weight_matrix: weight matrix, commonly derived from the inverse of
        the covariance matrix
    :returns: weighted values ``matmul(values, weight_matrix)`` as
        :class:`numpy.ndarray`
    '''
    return num.matmul(values, weight_matrix)
class ProblemConfig(Object):
    '''
    Base class for config section defining the objective function setup.

    Factory for :py:class:`Problem` objects. Subclasses implement
    :py:meth:`get_problem`.
    '''
    name_template = String.T()
    norm_exponent = Int.T(default=2)
    nthreads = Int.T(
        default=1,
        optional=True,
        help='Deprecated: use command line argument or global config to set '
             'number of allowed threads.')

    def check_deprecations(self):
        '''
        Log a warning if the deprecated ``nthreads`` option differs from its
        default of ``1``; the value itself is not used.
        '''
        if self.nthreads == 1:
            return

        logger.warning(
            'The `nthreads` parameter in `ProblemConfig` has been '
            'deprecated and is ignored now. Please use the `--nthreads` '
            'command line option to set it.')

    def get_problem(self, event, target_groups, targets):
        '''
        Instantiate the problem with a given event and targets.

        The base implementation only runs the deprecation check and raises.

        :returns: :py:class:`Problem` object
        :raises NotImplementedError: always, in this base class
        '''
        self.check_deprecations()
        raise NotImplementedError
@has_get_plot_classes
class Problem(Object):
    '''
    Base class for objective function setup.

    Defines the *problem* to be solved by the optimiser.

    Subclasses provide a ``problem_parameters`` list (and optionally
    ``problem_waveform_parameters``) and implement :py:meth:`get_source`
    and :py:meth:`pack`.
    '''
    name = String.T()
    ranges = Dict.T(String.T(), gf.Range.T())
    dependants = List.T(Parameter.T())
    norm_exponent = Int.T(default=2)
    base_source = gf.Source.T(optional=True)
    targets = List.T(MisfitTarget.T())
    target_groups = List.T(TargetGroup.T())
    grond_version = String.T(optional=True)
    nthreads = Int.T(
        default=1,
        optional=True,
        help='Deprecated: use command line argument or global config to set '
             'number of allowed threads.')

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)

        if self.grond_version is None:
            self.grond_version = __version__

        # Lazily computed caches and runtime handles.
        self._target_weights = None
        self._engine = None
        self._family_mask = None
        self._rstate_manager = None

        if hasattr(self, 'problem_waveform_parameters') and self.has_waveforms:
            self.problem_parameters = \
                self.problem_parameters + self.problem_waveform_parameters

        # Drop optional parameters for which no range is configured. A new
        # list is bound to the instance on purpose: removing items in place
        # would mutate a ``problem_parameters`` list defined on the class
        # and thereby leak into every other instance of that problem type.
        self.problem_parameters = [
            p for p in self.problem_parameters
            if not (p.optional and p._name not in self.ranges)]

        self.check()

    @classmethod
    def get_plot_classes(cls):
        from . import plot
        return plot.get_plot_classes()

    def check(self):
        '''
        Validate the target group setup.

        :raises ValueError: if a target group path other than ``'all'`` is
            used more than once.
        '''
        paths = set()
        for grp in self.target_groups:
            if grp.path == 'all':
                continue
            if grp.path in paths:
                raise ValueError('Path %s defined more than once! In %s'
                                 % (grp.path, grp.__class__.__name__))
            paths.add(grp.path)
        logger.debug('TargetGroup check OK.')

    def copy(self):
        '''Shallow copy of the problem with the target weight cache reset.'''
        o = copy.copy(self)
        o._target_weights = None
        return o

    def set_target_parameter_values(self, x):
        '''
        Hand the target-specific slices of model vector ``x`` to the targets.

        Target parameters follow the problem parameters in ``x``, in target
        order.
        '''
        nprob = len(self.problem_parameters)
        for target in self.targets:
            target.set_parameter_values(x[nprob:nprob + target.nparameters])
            nprob += target.nparameters

    def get_parameter_dict(self, model, group=None):
        '''
        Map parameter names to the values in ``model``.

        :param group: if given, only parameters belonging to this group are
            included.
        :returns: :py:class:`grond.meta.ADict`
        '''
        params = [(p.name, model[ip])
                  for ip, p in enumerate(self.parameters)
                  if group in p.groups or group is None]
        return ADict(params)

    def get_parameter_array(self, d):
        '''
        Build a model vector from the name-to-value mapping ``d``.

        Parameters missing from ``d`` are set to zero.
        '''
        arr = num.zeros(self.nparameters, dtype=float)
        for ip, p in enumerate(self.parameters):
            if p.name in d.keys():
                arr[ip] = d[p.name]
        return arr

    def get_parameter_index(self, param_name):
        '''
        Index of parameter ``param_name`` in the model vector.

        :raises KeyError: if there is no such parameter.
        '''
        index = {p.name: ip for ip, p in enumerate(self.parameters)}
        return index[param_name]

    def get_rstate_manager(self):
        '''Lazily created :py:class:`RandomStateManager` of this problem.'''
        if self._rstate_manager is None:
            self._rstate_manager = RandomStateManager()
        return self._rstate_manager

    def dump_problem_info(self, dirname):
        '''Write the problem definition to ``<dirname>/problem.yaml``.'''
        fn = op.join(dirname, 'problem.yaml')
        util.ensuredirs(fn)
        guts.dump(self, filename=fn)

    def dump_problem_data(
            self, dirname, x, misfits, chains=None,
            sampler_context=None):
        '''
        Append one model and its misfits to the binary files in ``dirname``.

        ``x``, ``misfits`` and the optional ``chains`` are appended as
        little-endian float64 to the files ``models``, ``misfits`` and
        ``chains``. The optional ``sampler_context`` is appended as
        little-endian int64 to ``choices``. The current random states are
        written to ``rstate``, replacing its previous content.
        '''
        fn = op.join(dirname, 'models')
        if not isinstance(x, num.ndarray):
            x = num.array(x)
        with open(fn, 'ab') as f:
            x.astype('<f8').tofile(f)

        fn = op.join(dirname, 'misfits')
        with open(fn, 'ab') as f:
            misfits.astype('<f8').tofile(f)

        if chains is not None:
            fn = op.join(dirname, 'chains')
            with open(fn, 'ab') as f:
                chains.astype('<f8').tofile(f)

        if sampler_context is not None:
            fn = op.join(dirname, 'choices')
            with open(fn, 'ab') as f:
                num.array(sampler_context, dtype='<i8').tofile(f)

        fn = op.join(dirname, 'rstate')
        self.get_rstate_manager().save_state(fn)

    def name_to_index(self, name):
        '''Index of ``name`` among the parameters followed by dependants.'''
        pnames = [p.name for p in self.combined]
        return pnames.index(name)

    @property
    def parameters(self):
        '''Problem parameters followed by all target parameters.'''
        target_parameters = []
        for target in self.targets:
            target_parameters.extend(target.target_parameters)
        return self.problem_parameters + target_parameters

    @property
    def parameter_names(self):
        '''Names of parameters *and* dependants (see :py:attr:`combined`).'''
        return [p.name for p in self.combined]

    @property
    def dependant_names(self):
        return [p.name for p in self.dependants]

    @property
    def nparameters(self):
        return len(self.parameters)

    @property
    def ntargets(self):
        return len(self.targets)

    @property
    def nwaveform_targets(self):
        return len(self.waveform_targets)

    @property
    def nsatellite_targets(self):
        return len(self.satellite_targets)

    @property
    def ngnss_targets(self):
        return len(self.gnss_targets)

    @property
    def nmisfits(self):
        '''Total number of misfit values contributed by all targets.'''
        nmisfits = 0
        for target in self.targets:
            nmisfits += target.nmisfits
        return nmisfits

    @property
    def ndependants(self):
        return len(self.dependants)

    @property
    def ncombined(self):
        return len(self.parameters) + len(self.dependants)

    @property
    def combined(self):
        '''Parameters followed by dependants.'''
        return self.parameters + self.dependants

    @property
    def satellite_targets(self):
        return [t for t in self.targets
                if isinstance(t, SatelliteMisfitTarget)]

    @property
    def gnss_targets(self):
        return [t for t in self.targets
                if isinstance(t, GNSSCampaignMisfitTarget)]

    @property
    def waveform_targets(self):
        return [t for t in self.targets
                if isinstance(t, WaveformMisfitTarget)]

    @property
    def has_satellite(self):
        return bool(self.satellite_targets)

    @property
    def has_waveforms(self):
        return bool(self.waveform_targets)

    def set_engine(self, engine):
        self._engine = engine

    def get_engine(self):
        return self._engine

    def get_source(self, x):
        raise NotImplementedError

    def pack(self, source):
        raise NotImplementedError

    def source_to_x(self, source):
        '''
        Convert a source into a model vector of this problem.

        The source is modified in place: its location is re-expressed as
        north/east shifts relative to :py:attr:`base_source`. If its time
        lies outside the ``time`` range (relative to the base source), a
        random time within the range is drawn from the ``source2x`` random
        state. Ranged parameters are converted according to their
        ``relative`` mode before :py:meth:`pack` is applied.
        '''
        bs = self.base_source

        n, e = pod.latlon_to_ne(
            bs.lat, bs.lon,
            source.effective_lat, source.effective_lon)

        source.lat, source.lon = bs.lat, bs.lon
        source.north_shift = n
        source.east_shift = e

        tmin, tmax = self.ranges['time'].start, self.ranges['time'].stop

        if (source.time - bs.time < tmin) or (source.time - bs.time > tmax):
            rstatem = self.get_rstate_manager()
            rstate = rstatem.get_rstate(name='source2x')

            # Draw a scalar rather than a size-1 array so that
            # ``source.time`` stays a plain float; the draw consumes the
            # random stream in the same way.
            source.time = bs.time + rstate.uniform(low=tmin, high=tmax)

        p = {}
        for k in self.base_source.keys():
            val = source[k]
            if k == 'time':
                p[k] = float(val)

            elif k in self.ranges:
                if self.ranges[k].relative == 'add':
                    val = self.ranges[k].make_relative(
                        -self.base_source[k], source[k])
                elif self.ranges[k].relative == 'mult':
                    val = self.ranges[k].make_relative(
                        1./self.base_source[k], source[k])

                p[k] = float(val)

        source = source.clone(**p)

        return self.pack(source)

    def get_gf_store_ids(self):
        '''Unique GF store IDs used by the targets.'''
        return tuple(set([t.store_id for t in self.targets]))

    def get_gf_store(self, target):
        '''
        Get the GF store for a target or a store ID string.

        :raises GrondError: if no engine has been set.
        '''
        if not isinstance(target, str):
            target = target.store_id
        if self.get_engine() is None:
            raise GrondError('Cannot get GF Store, modelling is not set up.')
        return self.get_engine().get_store(target)

    def random_uniform(self, xbounds, rstate, fixed_magnitude=None):
        '''
        Draw a model uniformly within ``xbounds`` (``(nparameters, 2)``
        array of lower/upper bounds).

        :raises GrondError: if ``fixed_magnitude`` is given; not supported
            by this base class.
        '''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

        x = rstate.uniform(0., 1., self.nparameters)
        x *= (xbounds[:, 1] - xbounds[:, 0])
        x += xbounds[:, 0]
        return x

    def minimum(self, xbounds, fixed_magnitude=None):
        '''Model at the lower bounds (a view into ``xbounds``).'''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

        x = xbounds[:, 0]
        return x

    def maximum(self, xbounds, fixed_magnitude=None):
        '''Model at the upper bounds (a view into ``xbounds``).'''
        if fixed_magnitude is not None:
            raise GrondError(
                'Setting fixed magnitude in random model generation not '
                'supported for this type of problem.')

        x = xbounds[:, 1]
        return x

    def preconstrain(self, x, optimizer=None):
        '''Hook to adjust or reject (raise ``Forbidden``) a model; no-op.'''
        return x

    def extract(self, xs, i):
        '''
        Column ``i`` of model array ``xs``.

        Indices ``>= nparameters`` refer to dependants, which are computed
        with ``make_dependant``. A 1D ``xs`` is treated as a single model.
        '''
        if xs.ndim == 1:
            return self.extract(xs[num.newaxis, :], i)[0]

        if i < self.nparameters:
            return xs[:, i]
        else:
            return self.make_dependant(
                xs, self.dependants[i - self.nparameters].name)

    def get_target_weights(self):
        '''Concatenated combined weights of all targets (cached).'''
        if self._target_weights is None:
            self._target_weights = num.concatenate(
                [target.get_combined_weight() for target in self.targets])

        return self._target_weights

    def get_target_residuals(self):
        pass

    def inter_family_weights(self, ns):
        '''
        Per-misfit weights ``1 / root(nansum(exp(ns)))`` where the sum runs
        over the misfits of the same normalisation family.

        :param ns: 1D array of normalisation contributions ``ns[imisfit]``
        '''
        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()

        ws = num.zeros(self.nmisfits)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[mask] = 1.0 / root(num.nansum(exp(ns[mask])))

        return ws

    def inter_family_weights2(self, ns):
        '''
        :param ns: 2D array with normalization factors ``ns[imodel, itarget]``
        :returns: 2D array ``weights[imodel, itarget]``
        '''

        exp, root = self.get_norm_functions()
        family, nfamilies = self.get_family_mask()

        ws = num.zeros(ns.shape)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[:, mask] = (1.0 / root(
                num.nansum(exp(ns[:, mask]), axis=1)))[:, num.newaxis]

        return ws

    def get_reference_model(self):
        '''
        Model vector of :py:attr:`base_source`; target parameters are zero.
        '''
        model = num.zeros(self.nparameters)
        model_source_params = self.pack(self.base_source)
        model[:model_source_params.size] = model_source_params
        return model

    def get_parameter_bounds(self):
        '''
        ``(nparameters, 2)`` array of lower/upper bounds, taken from
        :py:attr:`ranges` and the targets' ``target_ranges``.
        '''
        out = []
        for p in self.problem_parameters:
            r = self.ranges[p.name]
            out.append((r.start, r.stop))

        for target in self.targets:
            for p in target.target_parameters:
                r = target.target_ranges[p.name_nogroups]
                out.append((r.start, r.stop))

        return num.array(out, dtype=float)

    def get_dependant_bounds(self):
        return num.zeros((0, 2))

    def get_combined_bounds(self):
        return num.vstack((
            self.get_parameter_bounds(),
            self.get_dependant_bounds()))

    def raise_invalid_norm_exponent(self):
        raise GrondError('Invalid norm exponent: %f' % self.norm_exponent)

    def get_norm_functions(self):
        '''
        Return ``(exp, root)`` for the configured :py:attr:`norm_exponent`:
        ``(x**2, sqrt)`` for 2 and ``(identity, abs)`` for 1.

        :raises GrondError: for any other exponent.
        '''
        if self.norm_exponent == 2:
            def sqr(x):
                return x**2

            return sqr, num.sqrt

        elif self.norm_exponent == 1:
            def noop(x):
                return x

            return noop, num.abs

        else:
            self.raise_invalid_norm_exponent()

    def combine_misfits(
            self, misfits,
            extra_weights=None,
            extra_residuals=None,
            extra_correlated_weights=None,
            get_contributions=False):
        '''
        Combine misfit contributions (residuals) to global or bootstrap misfits

        :param misfits: 3D array ``misfits[imodel, iresidual, 0]`` are the
            misfit contributions (residuals) ``misfits[imodel, iresidual, 1]``
            are the normalisation contributions. It is also possible to give
            the misfit and normalisation contributions for a single model as
            ``misfits[iresidual, 0]`` and ``misfits[iresidual, 1]`` in which
            case, the first dimension (imodel) of the result will be stripped
            off.

        :param extra_weights: if given, 2D array of extra weights to be applied
            to the contributions, indexed as
            ``extra_weights[ibootstrap, iresidual]``.

        :param extra_residuals: if given, 2D array of perturbations to be added
            to the residuals, indexed as
            ``extra_residuals[ibootstrap, iresidual]``.

        :param extra_correlated_weights: if a dictionary of
            ``imisfit: correlated weight matrix`` is passed a correlated
            weight matrix is applied to the misfit and normalisation values.
            `imisfit` is the starting index in the misfits vector the
            correlated weight matrix applies to. Defaults to no correlated
            weights.

        :param get_contributions: get the weighted and perturbed contributions
            (don't do the sum).

        :returns: if no *extra_weights* or *extra_residuals* are given, a 1D
            array indexed as ``misfits[imodel]`` containing the global misfit
            for each model is returned, otherwise a 2D array
            ``misfits[imodel, ibootstrap]`` with the misfit for every model and
            weighting/residual set is returned.

        :raises GrondError: if correlated weights are combined with a norm
            exponent other than 2.
        '''
        # ``None`` instead of a shared mutable default argument.
        if extra_correlated_weights is None:
            extra_correlated_weights = {}

        if misfits.ndim == 2:
            misfits = misfits[num.newaxis, :, :]
            return self.combine_misfits(
                misfits, extra_weights, extra_residuals,
                extra_correlated_weights, get_contributions)[0, ...]

        if extra_weights is None and extra_residuals is None:
            return self.combine_misfits(
                misfits, False, False,
                extra_correlated_weights, get_contributions)[:, 0]

        assert misfits.ndim == 3
        assert not num.any(extra_weights) or extra_weights.ndim == 2
        assert not num.any(extra_residuals) or extra_residuals.ndim == 2

        if self.norm_exponent != 2 and extra_correlated_weights:
            raise GrondError('Correlated weights can only be used '
                             ' with norm_exponent=2')

        exp, root = self.get_norm_functions()

        nmodels = misfits.shape[0]

        # Axis 1 is the bootstrap/weighting set (size 1 unless perturbed).
        mf = misfits[:, num.newaxis, :, :].copy()

        if num.any(extra_residuals):
            mf = mf + extra_residuals[num.newaxis, :, :, num.newaxis]

        res = mf[..., 0]
        norms = mf[..., 1]

        for imisfit, corr_weight_mat in extra_correlated_weights.items():

            jmisfit = imisfit + corr_weight_mat.shape[0]

            for imodel in range(nmodels):
                corr_res = res[imodel, :, imisfit:jmisfit]
                corr_norms = norms[imodel, :, imisfit:jmisfit]

                res[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_res, corr_weight_mat)

                norms[imodel, :, imisfit:jmisfit] = \
                    correlated_weights(corr_norms, corr_weight_mat)

        # Apply normalization family weights (these weights depend on
        # the just calculated correlated norms!)
        weights_fam = \
            self.inter_family_weights2(norms[:, 0, :])[:, num.newaxis, :]

        weights_fam = exp(weights_fam)

        res = exp(res)
        norms = exp(norms)

        res *= weights_fam
        norms *= weights_fam

        weights_tar = self.get_target_weights()[num.newaxis, num.newaxis, :]
        if num.any(extra_weights):
            weights_tar = weights_tar * extra_weights[num.newaxis, :, :]

        weights_tar = exp(weights_tar)

        res = res * weights_tar
        norms = norms * weights_tar

        if get_contributions:
            return res / num.nansum(norms, axis=2)[:, :, num.newaxis]

        result = root(
            num.nansum(res, axis=2) /
            num.nansum(norms, axis=2))

        assert result[result < 0].size == 0
        return result

    def make_family_mask(self):
        '''
        Map each misfit slot to the index of its normalisation family.

        Family indices are assigned in order of first appearance among
        :py:attr:`targets`; targets of the same family share one index even
        if they are not adjacent.

        :returns: tuple ``(families, nfamilies)`` where ``families`` is an
            integer array of length :py:attr:`nmisfits`.
        '''
        family_index = {}
        families = num.zeros(self.nmisfits, dtype=int)

        idx = 0
        for target in self.targets:
            # setdefault returns the existing index for a known family and
            # assigns the next free index to a new one.
            ifamily = family_index.setdefault(
                target.normalisation_family, len(family_index))
            families[idx:idx + target.nmisfits] = ifamily
            idx += target.nmisfits

        return families, len(family_index)

    def get_family_mask(self):
        '''Cached result of :py:meth:`make_family_mask`.'''
        if self._family_mask is None:
            self._family_mask = self.make_family_mask()

        return self._family_mask

    def evaluate(self, x, mask=None, result_mode='full', targets=None):
        '''
        Forward model ``x`` and evaluate it for each target.

        :param mask: optional boolean sequence over :py:attr:`targets`;
            targets with a false entry are not modelled.
        :param result_mode: passed to each target's ``set_result_mode``.
        :param targets: optional list of targets to use instead of
            :py:attr:`targets`; cannot be combined with ``mask``.
        :returns: list with one entry per target, either the result of
            ``finalize_modelling`` or a :py:class:`pyrocko.gf.SeismosizerError`
            for masked-out targets.
        :raises ValueError: if both ``mask`` and ``targets`` are given.
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        self.set_target_parameter_values(x)

        if mask is not None and targets is not None:
            raise ValueError('Mask cannot be defined with targets set.')
        targets = targets if targets is not None else self.targets

        for target in targets:
            target.set_result_mode(result_mode)

        modelling_targets = []
        t2m_map = {}
        for itarget, target in enumerate(targets):
            t2m_map[target] = target.prepare_modelling(engine, source, targets)
            if mask is None or mask[itarget]:
                modelling_targets.extend(t2m_map[target])

        # Deduplicate modelling targets and remember every position at
        # which each unique one's result is needed.
        u2m_map = {}
        for imtarget, mtarget in enumerate(modelling_targets):
            u2m_map.setdefault(mtarget, []).append(imtarget)

        modelling_targets_unique = list(u2m_map.keys())

        resp = engine.process(source, modelling_targets_unique)

        modelling_results_unique = list(resp.results_list[0])
        modelling_results = [None] * len(modelling_targets)

        for mtarget, mresult in zip(
                modelling_targets_unique, modelling_results_unique):

            for itarget in u2m_map[mtarget]:
                modelling_results[itarget] = mresult

        imt = 0
        results = []
        for itarget, target in enumerate(targets):
            nmt_this = len(t2m_map[target])
            if mask is None or mask[itarget]:
                result = target.finalize_modelling(
                    engine, source,
                    t2m_map[target],
                    modelling_results[imt:imt + nmt_this])

                imt += nmt_this
            else:
                result = gf.SeismosizerError(
                    'target was excluded from modelling')

            results.append(result)

        return results

    def misfits(self, x, mask=None):
        '''
        Misfit and normalisation contributions of model ``x``.

        :returns: ``(nmisfits, 2)`` array; rows of targets that did not
            yield a :py:class:`MisfitResult` are NaN.
        '''
        results = self.evaluate(x, mask=mask, result_mode='sparse')
        misfits = num.full((self.nmisfits, 2), num.nan)

        imisfit = 0
        for target, result in zip(self.targets, results):
            if isinstance(result, MisfitResult):
                misfits[imisfit:imisfit + target.nmisfits, :] = result.misfits

            imisfit += target.nmisfits

        return misfits

    def forward(self, x):
        '''
        Plain forward modelling of model ``x`` for all targets' plain targets.

        Modelling errors are logged at debug level and left out of the
        returned list.
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        plain_targets = []
        for target in self.targets:
            plain_targets.extend(target.get_plain_targets(engine, source))

        resp = engine.process(source, plain_targets)

        results = []
        for target, result in zip(plain_targets, resp.results_list[0]):
            if isinstance(result, gf.SeismosizerError):
                logger.debug(
                    '%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
            else:
                results.append(result)

        return results

    def get_random_model(self, ntries_limit=100):
        '''
        Random model within the parameter bounds that passes
        :py:meth:`preconstrain`.

        :raises GrondError: if no candidate passes within ``ntries_limit``
            tries.
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = self.random_uniform(xbounds, rstate=g_rstate)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))

    def get_min_model(self, ntries_limit=100):
        '''
        Model at the lower parameter bounds, passed through
        :py:meth:`preconstrain`.

        :raises GrondError: if no candidate passes within ``ntries_limit``
            tries.
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = self.minimum(xbounds)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))

    def get_max_model(self, ntries_limit=100):
        '''
        Model at the upper parameter bounds, passed through
        :py:meth:`preconstrain`.

        :raises GrondError: if no candidate passes within ``ntries_limit``
            tries.
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = self.maximum(xbounds)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' % (
                ntries_limit))
class ProblemInfoNotAvailable(GrondError):
    '''Error signalling that problem information is not available.'''


class ProblemDataNotAvailable(GrondError):
    '''Error signalling that a run directory or its data files are missing.'''


class NoSuchAttribute(GrondError):
    '''Error signalling that a requested model attribute does not exist.'''


class InvalidAttributeName(GrondError):
    '''Error signalling that an attribute name is not a valid identifier.'''
class ModelHistory(object):
    '''
    Write, read and follow sequences of models produced in an optimisation run.

    :param problem: :class:`grond.Problem` instance
    :param nchains: number of bootstrap chains (columns of
        :py:attr:`bootstrap_misfits`), defaults to None
    :type nchains: int, optional
    :param path: path to rundir, defaults to None
    :type path: str, optional
    :param mode: open mode, 'r': read, 'w': write
    :type mode: str, optional
    '''

    nmodels_capacity_min = 1024

    def __init__(self, problem, nchains=None, path=None, mode='r'):
        self.mode = mode

        self.problem = problem
        self.path = path
        self.nchains = nchains

        # Preallocated storage; the public arrays below are views onto the
        # first ``nmodels`` rows of these buffers.
        self._models_buffer = None
        self._misfits_buffer = None
        self._bootstraps_buffer = None
        self._sample_contexts_buffer = None

        # argsort results per bootstrap chain, reset on every extend.
        self._sorted_misfit_idx = {}

        self.models = None
        self.misfits = None
        self.bootstrap_misfits = None

        self.sampler_contexts = None

        # Allocates the buffers through the ``nmodels_capacity`` setter.
        self.nmodels_capacity = self.nmodels_capacity_min
        self.listeners = []

        self._attributes = {}

        if mode == 'r':
            self.load()

    @staticmethod
    def verify_rundir(rundir):
        '''
        Check that ``rundir`` exists and contains ``misfits`` and ``models``.

        :raises ProblemDataNotAvailable: if anything is missing.
        '''
        _rundir_files = ('misfits', 'models')

        if not op.exists(rundir):
            raise ProblemDataNotAvailable(
                'Directory does not exist: %s' % rundir)
        for f in _rundir_files:
            if not op.exists(op.join(rundir, f)):
                raise ProblemDataNotAvailable('File not found: %s' % f)

    @classmethod
    def follow(cls, path, nchains=None, wait=20.):
        '''
        Start following a rundir (constructor).

        Retries every 0.25 s until the rundir can be loaded.

        :param path: the path to follow, a grond rundir
        :type path: str, optional
        :param nchains: number of bootstrap chains, passed to the constructor
        :type nchains: int, optional
        :param wait: wait time until the folder becomes alive
        :type wait: number in seconds, optional
        :returns: A :py:class:`ModelHistory` instance, or ``None`` if the
            rundir did not become available within ``wait`` seconds.
        '''
        start_watch = time.time()
        while (time.time() - start_watch) < wait:
            try:
                cls.verify_rundir(path)
                problem = load_problem_info(path)
                return cls(problem, nchains=nchains, path=path, mode='r')
            except (ProblemDataNotAvailable, OSError):
                time.sleep(.25)

    @property
    def nmodels(self):
        '''Number of models currently held.'''
        if self.models is None:
            return 0
        else:
            return self.models.shape[0]

    @nmodels.setter
    def nmodels(self, nmodels_new):
        # Only truncation is supported: re-slice the views onto the buffers.
        assert 0 <= nmodels_new <= self.nmodels
        self.models = self._models_buffer[:nmodels_new, :]
        self.misfits = self._misfits_buffer[:nmodels_new, :, :]
        if self.nchains is not None:
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels_new, :]  # noqa
        if self._sample_contexts_buffer is not None:
            self.sampler_contexts = self._sample_contexts_buffer[:nmodels_new, :]  # noqa

    @property
    def nmodels_capacity(self):
        '''Number of rows allocated in the internal buffers.'''
        if self._models_buffer is None:
            return 0
        else:
            return self._models_buffer.shape[0]

    @nmodels_capacity.setter
    def nmodels_capacity(self, nmodels_capacity_new):
        # Reallocate all buffers, copying over as many existing rows as fit.
        if self.nmodels_capacity != nmodels_capacity_new:

            models_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nparameters),
                dtype=float)
            misfits_buffer = num.zeros(
                (nmodels_capacity_new, self.problem.nmisfits, 2),
                dtype=float)
            sample_contexts_buffer = num.zeros(
                (nmodels_capacity_new, 4),
                dtype=int)
            sample_contexts_buffer.fill(-1)

            if self.nchains is not None:
                bootstraps_buffer = num.zeros(
                    (nmodels_capacity_new, self.nchains),
                    dtype=float)

            ncopy = min(self.nmodels, nmodels_capacity_new)

            if self._models_buffer is not None:
                models_buffer[:ncopy, :] = \
                    self._models_buffer[:ncopy, :]
                misfits_buffer[:ncopy, :, :] = \
                    self._misfits_buffer[:ncopy, :, :]
                sample_contexts_buffer[:ncopy, :] = \
                    self._sample_contexts_buffer[:ncopy, :]

            self._models_buffer = models_buffer
            self._misfits_buffer = misfits_buffer
            self._sample_contexts_buffer = sample_contexts_buffer

            if self.nchains is not None:
                if self._bootstraps_buffer is not None:
                    bootstraps_buffer[:ncopy, :] = \
                        self._bootstraps_buffer[:ncopy, :]
                self._bootstraps_buffer = bootstraps_buffer

    def clear(self):
        '''Drop all models and shrink the buffers to the minimum capacity.'''
        assert self.mode != 'r', 'History is read-only, cannot clear.'
        self.nmodels = 0
        self.nmodels_capacity = self.nmodels_capacity_min

    def extend(
            self, models, misfits,
            bootstrap_misfits=None,
            sampler_contexts=None):
        '''
        Append a batch of ``n`` models.

        :param models: ``(n, nparameters)`` array
        :param misfits: ``(n, nmisfits, 2)`` array
        :param bootstrap_misfits: optional ``(n, nchains)`` array
        :param sampler_contexts: optional ``(n, 4)`` integer array

        In write mode with a path set, each model is also dumped to disk.
        Listeners receive an ``extend`` event afterwards.
        '''
        nmodels = self.nmodels
        n = models.shape[0]

        # ``max(..., 1)`` avoids calling nextpow2 with zero when an empty
        # batch is added to an empty history (e.g. loading an empty rundir).
        nmodels_capacity_want = max(
            self.nmodels_capacity_min, nextpow2(max(nmodels + n, 1)))

        if nmodels_capacity_want != self.nmodels_capacity:
            self.nmodels_capacity = nmodels_capacity_want

        self._models_buffer[nmodels:nmodels + n, :] = models
        self._misfits_buffer[nmodels:nmodels + n, :, :] = misfits

        self.models = self._models_buffer[:nmodels + n, :]
        self.misfits = self._misfits_buffer[:nmodels + n, :, :]

        if bootstrap_misfits is not None:
            self._bootstraps_buffer[nmodels:nmodels + n, :] = bootstrap_misfits
            self.bootstrap_misfits = self._bootstraps_buffer[:nmodels + n, :]

        if sampler_contexts is not None:
            self._sample_contexts_buffer[nmodels:nmodels + n, :] \
                = sampler_contexts
            self.sampler_contexts = self._sample_contexts_buffer[
                :nmodels + n, :]

        if self.path and self.mode == 'w':
            for i in range(n):
                self.problem.dump_problem_data(
                    self.path, models[i, :], misfits[i, :, :],
                    bootstrap_misfits[i, :]
                    if bootstrap_misfits is not None else None,
                    sampler_contexts[i, :]
                    if sampler_contexts is not None else None)

        self._sorted_misfit_idx.clear()
        self.emit('extend', nmodels, n, models, misfits, sampler_contexts)

    def append(
            self, model, misfits,
            bootstrap_misfits=None,
            sampler_context=None):
        '''Append a single model; see :py:meth:`extend`.'''

        if bootstrap_misfits is not None:
            bootstrap_misfits = bootstrap_misfits[num.newaxis, :]

        if sampler_context is not None:
            sampler_context = sampler_context[num.newaxis, :]

        return self.extend(
            model[num.newaxis, :], misfits[num.newaxis, :, :],
            bootstrap_misfits, sampler_context)

    def load(self):
        '''Load all data from :py:attr:`path`; switches to read mode.'''
        self.mode = 'r'
        self.verify_rundir(self.path)
        models, misfits, bootstraps, sampler_contexts = load_problem_data(
            self.path, self.problem, nchains=self.nchains)
        self.extend(models, misfits, bootstraps, sampler_contexts)

    def update(self):
        ''' Update history from path '''
        nmodels_available = get_nmodels(self.path, self.problem)
        if self.nmodels == nmodels_available:
            return

        try:
            new_models, new_misfits, new_bootstraps, new_sampler_contexts = \
                load_problem_data(
                    self.path,
                    self.problem,
                    nmodels_skip=self.nmodels,
                    nchains=self.nchains)

        except ValueError:
            return

        self.extend(
            new_models,
            new_misfits,
            new_bootstraps,
            new_sampler_contexts)

    def add_listener(self, listener):
        ''' Add a listener to the history

        The listening class can implement the following methods:
        * ``extend``
        '''
        self.listeners.append(listener)

    def emit(self, event_name, *args, **kwargs):
        '''Call ``event_name`` on every listener that defines it.'''
        for listener in self.listeners:
            slot = getattr(listener, event_name, None)
            if callable(slot):
                slot(*args, **kwargs)

    @property
    def attribute_names(self):
        '''Names of attribute files stored under ``<path>/attributes``.'''
        apath = op.join(self.path, 'attributes')
        if not os.path.exists(apath):
            return []

        return [fn for fn in os.listdir(apath)
                if StringID.regex.match(fn)]

    def get_attribute(self, name):
        '''
        Per-model integer attribute ``name`` (read from disk and cached).

        :raises NoSuchAttribute: if the attribute does not exist.
        '''
        if name not in self._attributes:
            if name not in self.attribute_names:
                raise NoSuchAttribute(name)

            path = op.join(self.path, 'attributes', name)

            with open(path, 'rb') as f:
                self._attributes[name] = num.fromfile(
                    f, dtype='<i4',
                    count=self.nmodels).astype(int)

        assert self._attributes[name].shape == (self.nmodels,)

        return self._attributes[name]

    def set_attribute(self, name, attribute):
        '''
        Store a per-model integer attribute as little-endian int32.

        :raises InvalidAttributeName: if ``name`` is not a valid string ID.
        '''
        if not StringID.regex.match(name):
            raise InvalidAttributeName(name)

        attribute = attribute.astype(int)
        assert attribute.shape == (self.nmodels,)

        apath = op.join(self.path, 'attributes')

        if not os.path.exists(apath):
            os.mkdir(apath)

        path = op.join(apath, name)

        with open(path, 'wb') as f:
            attribute.astype('<i4').tofile(f)

        self._attributes[name] = attribute

    def ensure_bootstrap_misfits(self, optimiser):
        '''Compute bootstrap misfits with the optimiser's weights if unset.'''
        if self.bootstrap_misfits is None:
            problem = self.problem
            self.bootstrap_misfits = problem.combine_misfits(
                self.misfits,
                extra_weights=optimiser.get_bootstrap_weights(problem),
                extra_residuals=optimiser.get_bootstrap_residuals(problem))

    def imodels_by_cluster(self, cluster_attribute):
        '''
        Model indices grouped by the cluster attribute.

        :returns: list of ``(icluster, percentage, imodels)``; cluster ``-1``
            (if present) is moved to the end. Without ``cluster_attribute``
            all models form cluster ``-1``. If the attribute is missing, a
            warning is logged and an empty list is returned.
        '''
        if cluster_attribute is None:
            return [(-1, 100.0, num.arange(self.nmodels))]

        by_cluster = []
        try:
            iclusters = self.get_attribute(cluster_attribute)
            iclusters_avail = num.unique(iclusters)

            for icluster in iclusters_avail:
                imodels = num.where(iclusters == icluster)[0]
                by_cluster.append(
                    (icluster,
                     (100.0 * imodels.size) / self.nmodels,
                     imodels))

            if by_cluster and by_cluster[0][0] == -1:
                by_cluster.append(by_cluster.pop(0))

        except NoSuchAttribute:
            logger.warning(
                'Attribute %s not set in run %s.\n'
                '  Skipping model retrieval by clusters.' % (
                    cluster_attribute, self.problem.name))

        return by_cluster

    def models_by_cluster(self, cluster_attribute):
        '''Like :py:meth:`imodels_by_cluster` but yielding model arrays.'''
        if cluster_attribute is None:
            return [(-1, 100.0, self.models)]

        return [
            (icluster, percentage, self.models[imodels])
            for (icluster, percentage, imodels)
            in self.imodels_by_cluster(cluster_attribute)]

    def mean_sources_by_cluster(self, cluster_attribute):
        '''Mean source per cluster as ``(icluster, percentage, source)``.'''
        return [
            (icluster, percentage, stats.get_mean_source(self.problem, models))
            for (icluster, percentage, models)
            in self.models_by_cluster(cluster_attribute)]

    def get_sorted_misfits_idx(self, chain=0):
        '''Indices sorting the models by misfit of ``chain`` (cached).'''
        if chain not in self._sorted_misfit_idx.keys():
            self._sorted_misfit_idx[chain] = num.argsort(
                self.bootstrap_misfits[:, chain])

        return self._sorted_misfit_idx[chain]

    def get_sorted_misfits(self, chain=0):
        isort = self.get_sorted_misfits_idx(chain)
        return self.bootstrap_misfits[:, chain][isort]

    def get_sorted_models(self, chain=0):
        '''Models sorted by ascending misfit of bootstrap chain ``chain``.'''
        # Sort by the requested chain so that models line up with
        # get_sorted_misfits(chain) (and get_best_model with get_best_misfit).
        isort = self.get_sorted_misfits_idx(chain=chain)
        return self.models[isort, :]

    def get_sorted_primary_misfits(self):
        return self.get_sorted_misfits(chain=0)

    def get_sorted_primary_models(self):
        return self.get_sorted_models(chain=0)

    def get_best_model(self, chain=0):
        return self.get_sorted_models(chain)[0, ...]

    def get_best_misfit(self, chain=0):
        return self.get_sorted_misfits(chain)[0]

    def get_mean_model(self):
        return num.mean(self.models, axis=0)

    def get_mean_misfit(self, chain=0):
        return num.mean(self.bootstrap_misfits[:, chain])

    def get_best_source(self, chain=0):
        return self.problem.get_source(self.get_best_model(chain))

    def get_mean_source(self, chain=0):
        '''Source of the mean model; ``chain`` is accepted but not used.'''
        return self.problem.get_source(self.get_mean_model())

    def get_chain_misfits(self, chain=0):
        return self.bootstrap_misfits[:, chain]

    def get_primary_chain_misfits(self):
        return self.get_chain_misfits(chain=0)
class RandomStateManager(object):
    '''
    Registry of named :py:class:`numpy.random.RandomState` instances whose
    internal states can be saved to and restored from a binary file.

    Each state is written as one fixed-size record packed with
    ``save_struct``: the UTF-8 encoded name (at most ``MAX_LEN`` bytes,
    zero-padded), the 7-byte generator name, the key array as raw
    ``uint32`` bytes (2496 bytes) and the remaining legacy state fields
    (position, gauss flag, cached gauss value).
    '''

    MAX_LEN = 64
    save_struct = struct.Struct('%ds7s2496sqqd' % MAX_LEN)

    def __init__(self):
        self.rstates = {}
        self.lock = threading.Lock()

    def get_rstate(self, name, seed=None):
        '''
        Return the random state registered under ``name``, creating it if
        necessary.

        :param name: identifier of the state. Its UTF-8 encoding must not
            exceed ``MAX_LEN`` bytes, otherwise it would be truncated by
            the fixed-size record in :py:meth:`save_state`.
        :param seed: seed used only when the state is newly created.
        :raises ValueError: if the encoded name is longer than ``MAX_LEN``.
        '''
        # Check the *byte* length: non-ASCII characters take several bytes
        # and would otherwise be silently truncated on save (and restored
        # under a different name). An explicit check also survives -O.
        if len(name.encode()) > self.MAX_LEN:
            raise ValueError(
                'Random state name too long (max %d bytes): %s'
                % (self.MAX_LEN, name))

        if name not in self.rstates:
            self.rstates[name] = num.random.RandomState(seed)
        return self.rstates[name]

    @property
    def nstates(self):
        '''Number of registered random states.'''
        return len(self.rstates)

    def save_state(self, fname):
        '''
        Write all registered random states to ``fname`` (overwriting it),
        one ``save_struct`` record per state.
        '''
        with self.lock:
            with open(fname, 'wb') as f:
                for name, rstate in self.rstates.items():
                    s, arr, pos, has_gauss, cached_gauss = rstate.get_state()
                    f.write(
                        self.save_struct.pack(
                            name.encode(), s.encode(), arr.tobytes(),
                            pos, has_gauss, cached_gauss))

    def load_state(self, fname):
        '''
        Read records written by :py:meth:`save_state` from ``fname`` and
        register a restored random state for each of them, replacing
        existing states with the same name.
        '''
        with self.lock:
            with open(fname, 'rb') as f:
                while True:
                    buff = f.read(self.save_struct.size)
                    if not buff:
                        break

                    name, s, arr, pos, has_gauss, cached_gauss = \
                        self.save_struct.unpack(buff)

                    # Fixed-size string fields are zero-padded by struct.
                    name = name.replace(b'\x00', b'').decode()
                    s = s.replace(b'\x00', b'').decode()
                    arr = num.frombuffer(arr, dtype=num.uint32)

                    rstate = num.random.RandomState()
                    rstate.set_state((s, arr, pos, has_gauss, cached_gauss))
                    self.rstates[name] = rstate
def get_nmodels(dirname, problem):
    '''
    Number of complete models stored in a run directory.

    The count is derived from the sizes of the ``models`` file
    (``problem.nparameters`` 8-byte values per model) and the ``misfits``
    file (``problem.nmisfits`` pairs of 8-byte values per model); the
    smaller of the two counts is returned.
    '''
    def count_records(basename, record_nbytes):
        with open(op.join(dirname, basename), 'r') as f:
            return os.fstat(f.fileno()).st_size // record_nbytes

    n_from_models = count_records('models', problem.nparameters * 8)
    n_from_misfits = count_records('misfits', problem.nmisfits * 2 * 8)
    return min(n_from_models, n_from_misfits)
def load_problem_info_and_data(dirname, subset=None, nchains=None):
    '''
    Load the problem description of a run together with its model history.

    :param dirname: run directory holding ``problem.yaml``.
    :param subset: optional subdirectory (joined via ``xjoin``) from which
        the binary model data is read.
    :param nchains: forwarded to :py:func:`load_problem_data`.
    :returns: ``(problem, models, misfits, bootstraps, sampler_contexts)``
    '''
    problem = load_problem_info(dirname)
    data_dir = xjoin(dirname, subset)
    models, misfits, bootstraps, sampler_contexts = load_problem_data(
        data_dir, problem, nchains=nchains)
    return problem, models, misfits, bootstraps, sampler_contexts
def load_optimiser_info(dirname):
    '''
    Load the optimiser configuration stored as ``optimiser.yaml`` in
    ``dirname``.
    '''
    return guts.load(filename=op.join(dirname, 'optimiser.yaml'))
def load_problem_info(dirname):
    '''
    Load the problem description stored as ``problem.yaml`` in ``dirname``.

    :raises ProblemInfoNotAvailable: if the file cannot be read
        (any :py:exc:`OSError` raised while loading).
    '''
    fn = op.join(dirname, 'problem.yaml')
    try:
        return guts.load(filename=fn)
    except OSError as e:
        logger.debug(e)
        raise ProblemInfoNotAvailable(
            'No problem info available (%s).' % dirname)
def load_problem_data(dirname, problem, nmodels_skip=0, nchains=None):
    '''
    Load the binary model history of a run directory.

    :param dirname: directory containing the ``models`` and ``misfits``
        files and, optionally, ``bootstraps``/``chains``, ``choices`` and
        ``rstate``.
    :param problem: problem description; ``problem.nparameters`` and
        ``problem.nmisfits`` define the record sizes of ``models`` and
        ``misfits``.
    :param nmodels_skip: number of leading models to skip.
    :param nchains: number of columns of the ``bootstraps``/``chains``
        file; when ``None`` that file is not read.
    :returns: ``(models, misfits, chains, sampler_contexts)``; ``chains``
        and ``sampler_contexts`` are ``None`` when not available.
    :raises ProblemDataNotAvailable: if reading any file raises
        :py:exc:`OSError`.

    The files are memory-mapped in copy-on-write mode (``mode='c'``): they
    are opened read-only, so loading works on read-only run directories,
    and in-place changes to the returned arrays never reach the files on
    disk. (The numpy default ``'r+'`` needs write access and writes
    modifications straight back into the run's data.)
    '''

    def get_chains_fn():
        # Prefer 'bootstraps', fall back to the older 'chains' file name.
        for fn in (op.join(dirname, 'bootstraps'),
                   op.join(dirname, 'chains')):
            if op.exists(fn):
                return fn
        return False

    def map_records(fn, dtype, cast, nmodels, record_shape):
        # Map ``nmodels`` records of ``record_shape`` elements each, after
        # skipping the first ``nmodels_skip`` records, and cast without a
        # copy when the dtype already matches.
        dtype = num.dtype(dtype)
        record_nbytes = dtype.itemsize
        for n in record_shape:
            record_nbytes *= n

        return num.memmap(
            fn, dtype=dtype, mode='c',
            offset=nmodels_skip * record_nbytes,
            shape=(nmodels,) + tuple(record_shape))\
            .astype(cast, copy=False)

    try:
        nmodels = get_nmodels(dirname, problem) - nmodels_skip

        models = map_records(
            op.join(dirname, 'models'), '<f8', float, nmodels,
            (problem.nparameters,))

        misfits = map_records(
            op.join(dirname, 'misfits'), '<f8', float, nmodels,
            (problem.nmisfits, 2))

        chains = None
        fn = get_chains_fn()
        if fn and nchains is not None:
            chains = map_records(fn, '<f8', float, nmodels, (nchains,))

        sampler_contexts = None
        fn = op.join(dirname, 'choices')
        if op.exists(fn):
            sampler_contexts = map_records(fn, '<i8', int, nmodels, (4,))

        fn = op.join(dirname, 'rstate')
        if op.exists(fn):
            problem.get_rstate_manager().load_state(fn)

    except OSError as e:
        logger.debug(str(e))
        raise ProblemDataNotAvailable(
            'No problem data available (%s).' % dirname)

    return models, misfits, chains, sampler_contexts
# Public names exported by ``from grond.problems.base import *``.
__all__ = [
    'ProblemConfig',
    'Problem',
    'ModelHistory',
    'ProblemInfoNotAvailable',
    'ProblemDataNotAvailable',
    'load_problem_info',
    'load_problem_info_and_data',
    'InvalidAttributeName',
    'NoSuchAttribute',
]