1from __future__ import print_function
3import sys
4import logging
5import time
6import copy
7import shutil
8import glob
9import math
10import os
11import numpy as num
12from contextlib import contextmanager
14from pyrocko.guts import Object, String, Float, List
15from pyrocko import gf, trace, guts, util, weeding
16from pyrocko import parimap, model, marker as pmarker
18from .dataset import NotFound, InvalidObject
19from .problems.base import Problem, load_problem_info_and_data, \
20 load_problem_data, ProblemDataNotAvailable
22from .optimisers.base import BadProblem
23from .targets.waveform.target import WaveformMisfitResult
24from .targets.base import dump_misfit_result_collection, \
25 MisfitResultCollection, MisfitResult, MisfitResultError
26from .meta import expand_template, GrondError, selected
27from .environment import Environment
28from .monitor import GrondMonitor
logger = logging.getLogger('grond.core')  # module-wide logger
guts_prefix = 'grond'  # namespace prefix for guts-serialized types (pyrocko.guts convention)
op = os.path  # shorthand alias used throughout this module
class RingBuffer(num.ndarray):
    '''Fixed-size numpy-backed ring buffer.

    Values are appended with :py:meth:`put`; once the buffer is full,
    the oldest entries are overwritten. The buffer is zero-initialized.
    '''

    def __new__(cls, *args, **kwargs):
        # Create and zero the underlying ndarray. The original code
        # confusingly rebound ``cls`` to the new instance; use a
        # separate name for clarity.
        obj = num.ndarray.__new__(cls, *args, **kwargs)
        obj.fill(0.)
        return obj

    def __init__(self, *args, **kwargs):
        self.pos = 0  # index of the next slot to be written

    def put(self, value):
        '''Store *value* at the current position and advance, wrapping
        around at the end of the buffer.'''
        self[self.pos] = value
        self.pos += 1
        self.pos %= self.size
def mahalanobis_distance(xs, mx, cov):
    '''Mahalanobis distance of each row of *xs* from the mean *mx*.

    Dimensions whose diagonal entry in *cov* is zero are excluded from
    the computation (the covariance is inverted on the active subspace
    only).
    '''
    active = num.diag(cov) != 0.
    cov_inv = num.linalg.inv(cov[active, :][:, active])
    residuals = xs[:, active] - mx[active]
    return num.sqrt(
        num.sum(residuals * num.dot(cov_inv, residuals.T).T, axis=1))
@contextmanager
def lock_rundir(rundir):
    '''Guard *rundir* with a ``.running`` marker file.

    The marker is created atomically on entry and removed on exit.

    :raises EnvironmentError: if the marker already exists, i.e.
        another process appears to be operating on the rundir.
    '''
    statefn = op.join(rundir, '.running')
    try:
        # O_CREAT | O_EXCL creates the file atomically, closing the
        # check-then-create race of the previous implementation.
        fd = os.open(statefn, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except OSError:
        raise EnvironmentError('file %s already exists!' % statefn)

    os.close(fd)
    try:
        yield True
    finally:
        os.remove(statefn)
class DirectoryAlreadyExists(GrondError):
    '''Raised when an output directory unexpectedly already exists.'''
    pass
def weed(origin, targets, limit, neighborhood=3):
    '''Thin out *targets* to at most *limit* using pyrocko's weeding
    algorithm on target azimuths and distances from *origin*.

    Returns a tuple ``(targets_weeded, meandists_kept, deleted)`` where
    *deleted* is the boolean deletion mask produced by the algorithm.
    '''
    ntargets = len(targets)
    azimuths = num.zeros(ntargets)
    dists = num.zeros(ntargets)
    for itarget, target in enumerate(targets):
        _, azimuths[itarget] = target.azibazi_to(origin)
        dists[itarget] = target.distance_to(origin)

    # All targets are considered equally "bad" for weeding purposes.
    badnesses = num.ones(ntargets, dtype=float)
    deleted, meandists_kept = weeding.weed(
        azimuths, dists, badnesses,
        nwanted=limit,
        neighborhood=neighborhood)

    targets_weeded = [
        target for (delete, target) in zip(deleted, targets) if not delete]

    return targets_weeded, meandists_kept, deleted
def sarr(a):
    '''Format the numbers in *a* as one space-separated, fixed-width line.'''
    formatted = ['%15g' % value for value in a]
    return ' '.join(formatted)
def forward(env, show='filtered'):
    '''Forward-model solutions of *env* and display them in the Snuffler.

    If a rundir is available, the best model of the harvest subset is
    used; otherwise the preconstrained reference model of each selected
    event is forward-modelled.

    :param show: ``'filtered'`` or ``'processed'``, selecting which
        variant of the observed/synthetic traces is displayed.
    '''
    payload = []
    if env.have_rundir():
        env.setup_modelling()
        history = env.get_history(subset='harvest')
        xbest = history.get_best_model()
        problem = env.get_problem()
        ds = env.get_dataset()
        payload.append((ds, problem, xbest))

    else:
        # No rundir: forward-model the reference model of every
        # selected event instead.
        for event_name in env.get_selected_event_names():
            env.set_current_event_name(event_name)
            env.setup_modelling()
            problem = env.get_problem()
            ds = env.get_dataset()
            xref = problem.preconstrain(problem.get_reference_model())
            payload.append((ds, problem, xref))

    all_trs = []
    events = []
    stations = {}
    for (ds, problem, x) in payload:
        results = problem.evaluate(x)

        event = problem.get_source(x).pyrocko_event()
        events.append(event)

        for result in results:
            if isinstance(result, WaveformMisfitResult):
                # Location codes 'ob'/'sy' distinguish observed from
                # synthetic traces in the viewer.
                if show == 'filtered':
                    result.filtered_obs.set_codes(location='ob')
                    result.filtered_syn.set_codes(location='sy')
                    all_trs.append(result.filtered_obs)
                    all_trs.append(result.filtered_syn)
                elif show == 'processed':
                    result.processed_obs.set_codes(location='ob')
                    result.processed_syn.set_codes(location='sy')
                    all_trs.append(result.processed_obs)
                    all_trs.append(result.processed_syn)
                else:
                    raise ValueError('Invalid argument for show: %s' % show)

        # Deduplicate stations across events by NSL code.
        for station in ds.get_stations():
            stations[station.nsl()] = station

    markers = []
    for ev in events:
        markers.append(pmarker.EventMarker(ev))

    trace.snuffle(all_trs, markers=markers, stations=list(stations.values()))
def harvest(
        rundir, problem=None, nbest=10, force=False, weed=0,
        export_fits=None):
    '''Select the best models of a run and dump them to ``<rundir>/harvest``.

    :param problem: reuse an already loaded problem; when ``None`` the
        problem is loaded from the rundir.
    :param nbest: number of best models kept per (bootstrap) chain.
    :param force: overwrite an existing harvest directory.
    :param weed: weeding level (0-3) controlling how poorly performing
        bootstrap chains are discarded.
    :param export_fits: iterable of ``'best'``, ``'mean'``,
        ``'ensemble'`` selecting misfit result collections to export.
    '''
    # Fix of mutable default argument (was ``export_fits=[]``).
    if export_fits is None:
        export_fits = []

    env = Environment([rundir])
    optimiser = env.get_optimiser()
    nchains = env.get_optimiser().nchains

    if problem is None:
        problem, xs, misfits, bootstrap_misfits, _ = \
            load_problem_info_and_data(rundir, nchains=nchains)
    else:
        xs, misfits, bootstrap_misfits, _ = \
            load_problem_data(rundir, problem, nchains=nchains)

    logger.info('Harvesting problem "%s"...' % problem.name)

    dumpdir = op.join(rundir, 'harvest')
    if op.exists(dumpdir):
        if force:
            shutil.rmtree(dumpdir)
        else:
            raise DirectoryAlreadyExists(
                'Harvest directory already exists: %s' % dumpdir)

    util.ensuredir(dumpdir)

    ibests_list = []
    ibests = []
    gms = bootstrap_misfits[:, 0]  # global (primary chain) misfits
    isort = num.argsort(gms)

    ibests_list.append(isort[:nbest])

    # NOTE(review): weed == 3 skips the bootstrap-chain harvesting
    # entirely — presumably intentional; confirm against callers.
    if weed != 3:
        for ibootstrap in range(optimiser.nbootstrap):
            bms = bootstrap_misfits[:, ibootstrap]
            isort = num.argsort(bms)
            ibests_list.append(isort[:nbest])
            ibests.append(isort[0])

        if weed:
            # Drop bootstrap chains whose best model performs worse
            # than one standard deviation above the median.
            mean_gm_best = num.median(gms[ibests])
            std_gm_best = num.std(gms[ibests])
            ibad = set()

            for ibootstrap, ibest in enumerate(ibests):
                if gms[ibest] > mean_gm_best + std_gm_best:
                    ibad.add(ibootstrap)

            ibests_list = [
                ibests_ for (ibootstrap, ibests_) in enumerate(ibests_list)
                if ibootstrap not in ibad]

    ibests = num.concatenate(ibests_list)

    if weed == 2:
        ibests = ibests[gms[ibests] < mean_gm_best]

    for i in ibests:
        problem.dump_problem_data(dumpdir, xs[i], misfits[i, :, :])

    if export_fits:
        env.setup_modelling()
        problem = env.get_problem()
        history = env.get_history(subset='harvest')

        for what in export_fits:
            if what == 'best':
                models = [history.get_best_model()]
            elif what == 'mean':
                models = [history.get_mean_model()]
            elif what == 'ensemble':
                models = history.models
            else:
                raise GrondError(
                    'Invalid option for harvest\'s export_fits argument: %s'
                    % what)

            results = []
            for x in models:
                # Wrap evaluation failures so the collection can still
                # be serialized.
                results.append([
                    (result if isinstance(result, MisfitResult)
                     else MisfitResultError(message=str(result))) for
                    result in problem.evaluate(x)])

            emr = MisfitResultCollection(results=results)

            dump_misfit_result_collection(
                emr,
                op.join(dumpdir, 'fits-%s.yaml' % what))

    logger.info('Done harvesting problem "%s".' % problem.name)
def cluster(rundir, clustering, metric):
    '''Cluster the harvested model ensemble of a run and store the labels.

    :param clustering: clustering method object providing ``perform``.
    :param metric: name of a similarity metric known to
        :py:mod:`grond.clustering.metrics`.
    :raises GrondError: if *metric* is unknown.
    '''
    env = Environment([rundir])
    history = env.get_history(subset='harvest')
    problem = history.problem
    models = history.models

    events = [problem.get_source(model).pyrocko_event() for model in models]

    from grond.clustering import metrics

    if metric not in metrics.metrics:
        raise GrondError('Unknown metric: %s' % metric)

    mat = metrics.compute_similarity_matrix(events, metric)

    clusters = clustering.perform(mat)

    labels = num.sort(num.unique(clusters))
    bins = num.concatenate((labels, [labels[-1]+1]))
    ns = num.histogram(clusters, bins)[0]

    history.set_attribute('cluster', clusters)

    for i in range(labels.size):
        # Fix: use the module logger instead of the root ``logging``
        # module, consistent with the rest of this file.
        if labels[i] == -1:
            logger.info(
                'Number of unclustered events: %5i' % ns[i])
        else:
            logger.info(
                'Number of events in cluster %i: %5i' % (labels[i], ns[i]))
def get_event_names(config):
    '''Return the event names configured in *config*.'''
    return config.get_event_names()
def check_problem(problem):
    '''Raise :py:exc:`GrondError` if *problem* has no targets.'''
    if len(problem.targets) != 0:
        return

    raise GrondError('No targets available')
def check(
        config,
        event_names=None,
        target_string_ids=None,
        show_waveforms=False,
        n_random_synthetics=10,
        stations_used_path=None):
    '''Sanity-check the configured problems by forward-modelling synthetics.

    For each event, the problem is set up, optionally restricted to
    *target_string_ids*, and evaluated for *n_random_synthetics* random
    models (or the reference model when 0). Either the raw/restituted/
    projected waveforms are shown in the Snuffler (*show_waveforms*) or
    a per-target usability summary is logged.

    :param stations_used_path: if given, write the stations that
        produced usable results to this file.
    :raises GrondError: if any event's check failed.

    NOTE(review): with the default ``event_names=None`` the
    ``enumerate`` below would fail — callers appear to always pass a
    list; confirm.
    '''
    markers = []
    stations_used = {}
    erroneous = []
    for ievent, event_name in enumerate(event_names):
        ds = config.get_dataset(event_name)
        event = ds.get_event()
        trs_all = []
        try:
            problem = config.get_problem(event)

            _, nfamilies = problem.get_family_mask()
            logger.info('Problem: %s' % problem.name)
            logger.info('Number of target families: %i' % nfamilies)
            logger.info('Number of targets (total): %i' % len(problem.targets))

            # Optionally restrict to targets matching the given
            # NSLC-style string id patterns.
            if target_string_ids:
                problem.targets = [
                    target for target in problem.targets
                    if util.match_nslc(target_string_ids, target.string_id())]

            logger.info(
                'Number of targets (selected): %i' % len(problem.targets))

            check_problem(problem)

            results_list = []
            sources = []
            if n_random_synthetics == 0:
                # Evaluate only the (preconstrained) reference model.
                x = problem.preconstrain(problem.get_reference_model())
                sources.append(problem.base_source)
                results = problem.evaluate(x)
                results_list.append(results)

            else:
                for i in range(n_random_synthetics):
                    x = problem.get_random_model()
                    sources.append(problem.get_source(x))
                    results = problem.evaluate(x)
                    results_list.append(results)

            if show_waveforms:
                engine = config.engine_config.get_engine()
                times = []
                tdata = []
                for target in problem.targets:
                    tobs_shift_group = []
                    tcuts = []
                    for source in sources:
                        tmin_fit, tmax_fit, tfade, tfade_taper = \
                            target.get_taper_params(engine, source)

                        times.extend((tmin_fit-tfade*2., tmax_fit+tfade*2.))

                        # Observed-minus-synthetic pick shift, if picks
                        # are available for this target/source pair.
                        tobs, tsyn = target.get_pick_shift(engine, source)
                        if None not in (tobs, tsyn):
                            tobs_shift = tobs - tsyn
                        else:
                            tobs_shift = 0.0

                        tcuts.append(target.get_cutout_timespan(
                            tmin_fit+tobs_shift, tmax_fit+tobs_shift, tfade))

                        tobs_shift_group.append(tobs_shift)

                    tcuts = num.array(tcuts, dtype=float)

                    tdata.append((
                        tfade,
                        num.mean(tobs_shift_group),
                        (num.min(tcuts[:, 0]), num.max(tcuts[:, 1]))))

                tmin = min(times)
                tmax = max(times)

                tmax += (tmax-tmin)*2

                for (tfade, tobs_shift, tcut), target in zip(
                        tdata, problem.targets):

                    store = engine.get_store(target.store_id)

                    deltat = store.config.deltat

                    # Cap the upper frequency limits below Nyquist.
                    freqlimits = list(target.get_freqlimits())
                    freqlimits[2] = 0.45/deltat
                    freqlimits[3] = 0.5/deltat
                    freqlimits = tuple(freqlimits)

                    try:
                        trs_projected, trs_restituted, trs_raw, _ = \
                            ds.get_waveform(
                                target.codes,
                                tmin=tmin+tobs_shift,
                                tmax=tmax+tobs_shift,
                                tfade=tfade,
                                freqlimits=freqlimits,
                                deltat=deltat,
                                backazimuth=target.
                                get_backazimuth_for_waveform(),
                                debug=True)

                    except NotFound as e:
                        logger.warning(str(e))
                        continue

                    trs_projected = copy.deepcopy(trs_projected)
                    trs_restituted = copy.deepcopy(trs_restituted)
                    trs_raw = copy.deepcopy(trs_raw)

                    # Align traces by removing the pick shift and tag
                    # them so the three processing stages are
                    # distinguishable in the viewer.
                    for trx in trs_projected + trs_restituted + trs_raw:
                        trx.shift(-tobs_shift)
                        trx.set_codes(
                            network='',
                            station=target.string_id(),
                            location='')

                    for trx in trs_projected:
                        trx.set_codes(location=trx.location + '2_proj')

                    for trx in trs_restituted:
                        trx.set_codes(location=trx.location + '1_rest')

                    for trx in trs_raw:
                        trx.set_codes(location=trx.location + '0_raw')

                    trs_all.extend(trs_projected)
                    trs_all.extend(trs_restituted)
                    trs_all.extend(trs_raw)

                    for source in sources:
                        tmin_fit, tmax_fit, tfade, tfade_taper = \
                            target.get_taper_params(engine, source)

                        markers.append(pmarker.Marker(
                            nslc_ids=[('', target.string_id(), '*_proj', '*')],
                            tmin=tmin_fit, tmax=tmax_fit))

                    markers.append(pmarker.Marker(
                        nslc_ids=[('', target.string_id(), '*_raw', '*')],
                        tmin=tcut[0]-tobs_shift, tmax=tcut[1]-tobs_shift,
                        kind=1))

            else:
                # Summarize, per target, in how many evaluations it
                # produced a usable (non-error) result.
                for itarget, target in enumerate(problem.targets):

                    nok = 0
                    for results in results_list:
                        result = results[itarget]
                        if not isinstance(result, gf.SeismosizerError):
                            nok += 1

                    if nok == 0:
                        sok = 'not used'
                    elif nok == len(results_list):
                        sok = 'ok'
                        try:
                            s = ds.get_station(target)
                            stations_used[s.nsl()] = s
                        except (NotFound, InvalidObject):
                            pass
                    else:
                        sok = 'not used (%i/%i ok)' % (nok, len(results_list))

                    logger.info('%-40s %s' % (
                        (target.string_id() + ':', sok)))

        except GrondError as e:
            logger.error('Event %i, "%s": %s' % (
                ievent,
                event.name or util.time_to_str(event.time),
                str(e)))

            erroneous.append(event)

    if show_waveforms:
        trace.snuffle(trs_all, stations=ds.get_stations(), markers=markers)

    if stations_used_path:
        stations = list(stations_used.values())
        stations.sort(key=lambda s: s.nsl())
        model.dump_stations(stations, stations_used_path)

    if erroneous:
        raise GrondError(
            'Check failed for events: %s'
            % ', '.join(ev.name for ev in erroneous))
# Registry mapping id(g_data) -> argument bundle; filled by go() and
# read back by process_event() workers.
g_state = {}
def go(environment,
       force=False, preserve=False,
       nparallel=1, status='state', nthreads=0):
    '''Run the optimisation for all selected events, *nparallel* at a time.

    The argument bundle is registered in the module-level ``g_state``
    registry and handed to :py:func:`process_event` workers by id.
    '''
    g_data = (environment, force, preserve,
              status, nparallel, nthreads)

    data_id = id(g_data)
    g_state[data_id] = g_data

    nevents = environment.nevents_selected
    worker_results = parimap.parimap(
        process_event,
        range(environment.nevents_selected),
        [data_id] * nevents,
        nprocs=nparallel)

    # Drain the iterator; process_event works by side effect only.
    for _ in worker_results:
        pass
def process_event(ievent, g_data_id):
    '''Worker: run the full optimisation for one selected event.

    :param ievent: index into the selected event names.
    :param g_data_id: key into the module-level ``g_state`` registry
        holding the argument bundle created by :py:func:`go`.
    '''
    environment, force, preserve, status, nparallel, nthreads = \
        g_state[g_data_id]

    config = environment.get_config()
    event_name = environment.get_selected_event_names()[ievent]
    nevents = environment.nevents_selected

    ds = config.get_dataset(event_name)
    event = ds.get_event()
    problem = config.get_problem(event)

    tstart = time.time()
    monitor = None
    rundir = None
    try:
        # In synthetic-test mode, replace the base source with the
        # synthetic scenario's source.
        synt = ds.synthetic_test
        if synt:
            problem.base_source = problem.get_source(synt.get_x())

        check_problem(problem)

        rundir = expand_template(
            config.rundir_template,
            dict(problem_name=problem.name))
        environment.set_rundir_path(rundir)

        if op.exists(rundir):
            if preserve:
                # Keep the old rundir by renaming it with a counter.
                nold_rundirs = len(glob.glob(rundir + '*'))
                shutil.move(rundir, rundir+'-old-%d' % (nold_rundirs))
            elif force:
                shutil.rmtree(rundir)
            else:
                logger.warning(
                    'Skipping problem "%s": rundir already exists: %s' % (
                        problem.name, rundir))
                return

        util.ensuredir(rundir)

        logger.info(
            'Starting event %i / %i' % (ievent+1, nevents))

        logger.info('Rundir: %s' % rundir)

        logger.info('Analysing problem "%s".' % problem.name)

        for analyser_conf in config.analyser_configs:
            analyser = analyser_conf.get_analyser()
            analyser.analyse(problem, ds)

        # Dump the config with paths relative to the rundir, then
        # restore the original basepath.
        basepath = config.get_basepath()
        config.change_basepath(rundir)
        guts.dump(config, filename=op.join(rundir, 'config.yaml'))
        config.change_basepath(basepath)

        optimiser = config.optimiser_config.get_optimiser()
        optimiser.set_nthreads(nthreads)

        optimiser.init_bootstraps(problem)
        problem.dump_problem_info(rundir)

        # Optionally inject the known synthetic solution into the
        # optimiser's starting population.
        xs_inject = None
        synt = ds.synthetic_test
        if synt and synt.inject_solution:
            xs_inject = synt.get_x()[num.newaxis, :]

        if xs_inject is not None:
            from .optimisers import highscore
            if not isinstance(optimiser, highscore.HighScoreOptimiser):
                raise GrondError(
                    'Optimiser does not support injections.')

            optimiser.sampler_phases[0:0] = [
                highscore.InjectionSamplerPhase(xs_inject=xs_inject)]

        with lock_rundir(rundir):
            if status == 'state':
                monitor = GrondMonitor.watch(rundir)
            optimiser.optimise(
                problem,
                rundir=rundir)

        harvest(rundir, problem, force=True)

    except BadProblem as e:
        logger.error(str(e))

    except GrondError as e:
        logger.error(str(e))

    finally:
        if monitor:
            monitor.terminate()

    tstop = time.time()
    logger.info(
        'Stop %i / %i (%g min)' % (ievent+1, nevents, (tstop - tstart)/60.))

    if rundir:
        logger.info(
            'Done with problem "%s", rundir is "%s".' % (problem.name, rundir))
class ParameterStats(Object):
    '''Summary statistics of a single problem parameter.'''

    name = String.T()
    mean = Float.T()
    std = Float.T()
    best = Float.T()
    minimum = Float.T()
    percentile5 = Float.T()
    percentile16 = Float.T()
    median = Float.T()
    percentile84 = Float.T()
    percentile95 = Float.T()
    maximum = Float.T()

    def __init__(self, *args, **kwargs):
        '''Accept values positionally in property declaration order or
        as keywords; positionals take precedence.'''
        for prop, value in zip(self.T.propnames, args):
            kwargs[prop] = value

        Object.__init__(self, **kwargs)

    def get_values_dict(self):
        '''Return ``{'<name>.<prop>': value}`` for every property
        except ``name`` itself.'''
        return {
            self.name + '.' + prop: getattr(self, prop)
            for prop in self.T.propnames
            if prop != 'name'}
class ResultStats(Object):
    '''Per-parameter summary statistics for one problem's result ensemble.'''

    problem = Problem.T()
    parameter_stats_list = List.T(ParameterStats.T())

    def get_values_dict(self):
        '''Merge the value dicts of all contained parameter statistics.'''
        merged = {}
        for stats in self.parameter_stats_list:
            merged.update(stats.get_values_dict())

        return merged
def make_stats(problem, models, gms, pnames=None):
    '''Compute per-parameter summary statistics over a model ensemble.

    :param gms: global misfits, used to pick the best model.
    :param pnames: restrict to these parameter names; all problem
        parameters when ``None``.
    :returns: a :py:class:`ResultStats` instance.
    '''
    ibest = num.argmin(gms)
    rs = ResultStats(problem=problem)
    if pnames is None:
        pnames = problem.parameter_names

    quantile_levels = [0., 5., 16., 50., 84., 95., 100.]
    for pname in pnames:
        iparam = problem.name_to_index(pname)
        values = problem.extract(models, iparam)
        mi, p5, p16, median, p84, p95, ma = (
            float(q) for q in num.percentile(values, quantile_levels))

        stats = ParameterStats(
            pname,
            float(num.mean(values)),
            float(num.std(values)),
            float(values[ibest]),
            mi, p5, p16, median, p84, p95, ma)

        rs.parameter_stats_list.append(stats)

    return rs
def try_add_location_uncertainty(data, types):
    '''Add a combined ``location_uncertainty`` entry to *data* if possible.

    When all three of ``north_shift.std``, ``east_shift.std`` and
    ``depth.std`` are present, their Euclidean norm is stored in *data*
    and its type (``float``) in *types*; otherwise nothing happens.
    '''
    keys = ('north_shift.std', 'east_shift.std', 'depth.std')
    components = [data.get(k, None) for k in keys]

    if None in components:
        return

    data['location_uncertainty'] = math.sqrt(
        sum(component**2 for component in components))
    types['location_uncertainty'] = float
def format_stats(rs, fmt):
    '''Format selected statistics of *rs* into one fixed-width line.

    :param rs: a :py:class:`ResultStats` instance.
    :param fmt: list of column selectors, each either ``'problem.name'``
        or ``'<parameter>.<statistic>'`` (e.g. ``'depth.mean'``).
    '''
    pname_to_pindex = dict(
        (p.name, i) for (i, p) in enumerate(rs.parameter_stats_list))

    # The previous version also assembled a 'headers' list which was
    # never used; that dead code has been removed.
    values = []
    for x in fmt:
        if x == 'problem.name':
            values.append('%-16s' % rs.problem.name)
        else:
            pname, qname = x.split('.')
            pindex = pname_to_pindex[pname]
            values.append(
                '%16.7g' % getattr(rs.parameter_stats_list[pindex], qname))

    return ' '.join(values)
def export(
        what, rundirs, type=None, pnames=None, filename=None, selection=None,
        effective_lat_lon=False):
    '''Export harvested results of one or more runs in various formats.

    :param what: ``'best'``, ``'mean'``, ``'ensemble'`` or ``'stats'``.
    :param type: output flavour: ``'vector'``, ``'source'``,
        ``'event'`` or ``'event-yaml'`` (defaults to ``'event'``).
        NOTE(review): parameter shadows the ``type`` builtin.
    :param pnames: parameter/statistic selectors; short form
        ``'<param>.<stat>'`` is only valid with ``what='stats'``.
    :param filename: write to this file instead of stdout.
    :param selection: expression used to filter runs by their
        statistics.
    :raises GrondError: on invalid argument combinations.
    '''
    if pnames is not None:
        pnames_clean = [
            pname.split('.')[0] for pname in pnames
            if not pname.startswith('problem.')]
        shortform = all(len(pname.split('.')) == 2 for pname in pnames)
    else:
        pnames_clean = None
        shortform = False

    if what == 'stats' and type is not None:
        raise GrondError('Invalid argument combination: what=%s, type=%s' % (
            repr(what), repr(type)))

    if what != 'stats' and shortform:
        raise GrondError('Invalid argument combination: what=%s, pnames=%s' % (
            repr(what), repr(pnames)))

    if what != 'stats' and type != 'vector' and pnames is not None:
        raise GrondError(
            'Invalid argument combination: what=%s, type=%s, pnames=%s' % (
                repr(what), repr(type), repr(pnames)))

    # NOTE(review): 'out' is only closed at the very end; an exception
    # below would leak the file handle — consider try/finally.
    if filename is None:
        out = sys.stdout
    else:
        out = open(filename, 'w')

    if type is None:
        type = 'event'

    if shortform:
        print('#', ' '.join(['%16s' % x for x in pnames]), file=out)

    def dump(x, gm, indices):
        # Closure over 'problem', 'type', 'out' and
        # 'effective_lat_lon'; writes one model in the selected format.
        if type == 'vector':
            print(' ', ' '.join(
                '%16.7g' % problem.extract(x, i) for i in indices),
                '%16.7g' % gm, file=out)

        elif type == 'source':
            source = problem.get_source(x)
            if effective_lat_lon:
                source.set_origin(*source.effective_latlon)
            guts.dump(source, stream=out)

        elif type == 'event':
            ev = problem.get_source(x).pyrocko_event()
            if effective_lat_lon:
                ev.set_origin(*ev.effective_latlon)

            model.dump_events([ev], stream=out)

        elif type == 'event-yaml':
            ev = problem.get_source(x).pyrocko_event()
            if effective_lat_lon:
                ev.set_origin(*ev.effective_latlon)
            guts.dump_all([ev], stream=out)

        else:
            raise GrondError('Invalid argument: type=%s' % repr(type))

    header = None
    for rundir in rundirs:
        env = Environment(rundir)
        info = env.get_run_info()

        try:
            history = env.get_history(subset='harvest')
        except ProblemDataNotAvailable as e:
            logger.error(
                'Harvest not available (Did the run succeed?): %s' % str(e))
            continue

        problem = history.problem
        models = history.models
        misfits = history.get_primary_chain_misfits()

        if selection:
            # Build a flat attribute dict of this run's statistics and
            # apply the selection expression to it.
            rs = make_stats(
                problem, models,
                history.get_primary_chain_misfits())

            data = dict(tags=info.tags)
            types = dict(tags=(list, str))

            for k, v in rs.get_values_dict().items():
                data[k] = v
                types[k] = float

            try_add_location_uncertainty(data, types)

            if not selected(selection, data=data, types=types):
                continue

        else:
            rs = None

        if type == 'vector':
            pnames_take = pnames_clean or \
                problem.parameter_names[:problem.nparameters]

            indices = num.array(
                [problem.name_to_index(pname) for pname in pnames_take])

            if type == 'vector' and what in ('best', 'mean', 'ensemble'):
                extra = ['global_misfit']
            else:
                extra = []

            # Re-print the header only when the column set changes
            # between runs.
            new_header = '# ' + ' '.join(
                '%16s' % x for x in pnames_take + extra)

            if type == 'vector' and header != new_header:
                print(new_header, file=out)

            header = new_header
        else:
            indices = None

        if what == 'best':
            x_best = history.get_best_model()
            gm_best = history.get_best_misfit()
            dump(x_best, gm_best, indices)

        elif what == 'mean':
            x_mean = history.get_mean_model()
            gm_mean = history.get_mean_misfit(chain=0)
            dump(x_mean, gm_mean, indices)

        elif what == 'ensemble':
            isort = num.argsort(misfits)
            for i in isort:
                dump(models[i], misfits[i], indices)

        elif what == 'stats':
            if not rs:
                rs = make_stats(problem, models,
                                history.get_primary_chain_misfits(),
                                pnames_clean)

            if shortform:
                print(' ', format_stats(rs, pnames), file=out)
            else:
                print(rs, file=out)

        else:
            raise GrondError('Invalid argument: what=%s' % repr(what))

    if out is not sys.stdout:
        out.close()
# Public command-level API of this module.
__all__ = '''
    DirectoryAlreadyExists
    forward
    harvest
    cluster
    go
    get_event_names
    check
    export
'''.split()