Coverage for /usr/local/lib/python3.11/dist-packages/grond/config.py: 66%
186 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
1import logging
2import re
3import os.path as op
4from pyrocko import guts, gf
5from pyrocko.guts import Bool, List, String, Int, StringChoice
7from .meta import Path, HasPaths, GrondError
8from .dataset import DatasetConfig
9from .analysers.base import AnalyserConfig
10from .analysers.target_balancing import TargetBalancingAnalyserConfig
11from .problems.base import ProblemConfig
12from .optimisers.base import OptimiserConfig
13from .targets.base import TargetGroup
14from .version import __version__
# Module logger; everything logged here goes to the 'grond.config' logger.
logger = logging.getLogger(name='grond.config')

# Tag prefix for the guts classes of this module (presumably used by
# pyrocko.guts when (de)serialising them to YAML -- confirm).
guts_prefix = 'grond'
def color_diff(diff):
    """Wrap the lines of a unified diff in ANSI colour escape sequences.

    Lines starting with '+' become green, '-' red, '^' blue and '@' dim.
    Every other line is yielded unchanged.

    :param diff: iterable of diff lines (e.g. from ``difflib.unified_diff``)
    :returns: generator over the (possibly colourised) lines
    """
    reset = '\x1b[0m'
    # Checked in order; the first matching prefix wins.
    styles = (
        ('+', '\x1b[32m'),  # green
        ('-', '\x1b[31m'),  # red
        ('^', '\x1b[34m'),  # blue
        ('@', '\x1b[2m'),   # dim
    )

    for line in diff:
        for prefix, code in styles:
            if line.startswith(prefix):
                yield code + line + reset
                break
        else:
            yield line
# Accepted values for GlobalConfig.loglevel.
_log_level_choices = ['critical', 'error', 'warning', 'info', 'debug']


class GlobalConfig(HasPaths):
    """Process-wide runtime settings: logging, status output, parallelism.

    Values can be overridden from parsed command-line arguments with
    :meth:`override_with_cli_arguments`.
    """

    loglevel = StringChoice.T(
        choices=_log_level_choices,
        default='info',
        help='Log level. Choices: %s. Default: %s' % (
            ', '.join(_log_level_choices),
            'info'))

    status = StringChoice.T(
        choices=['state', 'quiet'],
        default='state',
        help='Status output selection. Choices: state, quiet. Default: state')

    nthreads = Int.T(
        default=1,
        help='Number of threads to utilize. Default: 1')

    nparallel = Int.T(
        default=1,
        help='Number of parallel processes to utilize. Default: 1')

    def override_with_cli_arguments(self, args):
        """Overwrite settings with the values given on the command line.

        :param args: object with ``loglevel``, ``status``, ``nthreads`` and
            ``nparallel`` attributes; an attribute that is ``None`` leaves
            the corresponding setting untouched.

        Whenever ``nparallel`` ends up different from 1, ``status`` is
        forced to ``'quiet'``.
        """
        if args.loglevel is not None:
            self.loglevel = args.loglevel

        if args.status is not None:
            self.status = args.status

        if args.nthreads is not None:
            self.nthreads = args.nthreads

        if args.nparallel is not None:
            self.nparallel = args.nparallel

        # NOTE(review): presumably the interactive status display cannot be
        # shared between several worker processes, hence forced quiet.
        if self.nparallel != 1:
            self.status = 'quiet'
# Shared module-level GlobalConfig instance, created at import time.
g_global_config = GlobalConfig()


def get_global_config():
    """Return the shared module-level :class:`GlobalConfig` instance.

    Every call returns the same object, so modifications made through it
    are visible to all callers.
    """
    shared = g_global_config
    return shared
class EngineConfig(HasPaths):
    """Settings for the :class:`pyrocko.gf.LocalEngine` used for modelling.

    The engine is built lazily by :meth:`get_engine` and cached on the
    instance.
    """

    gf_stores_from_pyrocko_config = Bool.T(
        default=True,
        help='Load the GF stores from ~/.pyrocko/config')
    gf_store_superdirs = List.T(
        Path.T(),
        help='List of path hosting collection of Green\'s function stores.')
    gf_store_dirs = List.T(
        Path.T(),
        help='List of Green\'s function stores')

    def __init__(self, *args, **kwargs):
        HasPaths.__init__(self, *args, **kwargs)
        # Cached engine; created on the first get_engine() call.
        self._engine = None

    def get_engine(self):
        """Return the cached LocalEngine, creating it on first use.

        Store (super)directories are passed through ``self.expand_path``.
        If an ``nthreads`` attribute has been assigned to this object (it
        is not a declared property of EngineConfig) and is not ``None``,
        it is forwarded to ``gf.LocalEngine`` as ``nthreads``.

        :returns: the :class:`pyrocko.gf.LocalEngine` instance
        """
        if self._engine is None:
            fp = self.expand_path

            # ``nthreads`` is not a declared property, so its absence is the
            # normal case. Probe for it explicitly instead of interpreting a
            # missing attribute as an outdated Pyrocko installation.
            extra_args = {}
            nthreads = getattr(self, 'nthreads', None)
            if nthreads is not None:
                extra_args['nthreads'] = nthreads
            else:
                logger.debug(
                    'No `nthreads` set on the engine configuration; '
                    'creating `pyrocko.gf.LocalEngine` without it.')

            self._engine = gf.LocalEngine(
                use_config=self.gf_stores_from_pyrocko_config,
                store_superdirs=fp(self.gf_store_superdirs),
                store_dirs=fp(self.gf_store_dirs),
                **extra_args)

        return self._engine
class Config(HasPaths):
    """Top-level Grond configuration.

    Combines dataset, target group, problem, analyser, optimiser and GF
    engine settings with an optional selection of the events to process.
    """

    rundir_template = Path.T(
        help='Rundir for the optimisation, supports templating'
             ' (eg. ${event_name})')
    dataset_config = DatasetConfig.T(
        help='Dataset configuration object')
    target_groups = List.T(
        TargetGroup.T(),
        help='List of ``TargetGroup``s')
    problem_config = ProblemConfig.T(
        help='Problem config')
    analyser_configs = List.T(
        AnalyserConfig.T(),
        default=[TargetBalancingAnalyserConfig.D()],
        help='List of problem analysers')
    optimiser_config = OptimiserConfig.T(
        help='The optimisers configuration')
    engine_config = EngineConfig.T(
        default=EngineConfig.D(),
        help=':class:`pyrocko.gf.LocalEngine` configuration')
    event_names = List.T(
        String.T(),
        help='Restrict application to given event names. If empty, all events '
             'found through the dataset configuration are considered.')
    event_names_exclude = List.T(
        String.T(),
        help='Event names to be excluded')

    def __init__(self, *args, **kwargs):
        HasPaths.__init__(self, *args, **kwargs)

    def get_event_names(self):
        """Return the names of the events to process, in order.

        Uses ``event_names`` when it is non-empty, otherwise every event name
        reported by the dataset configuration; entries that appear in
        ``event_names_exclude`` are dropped.
        """
        candidates = (
            self.event_names or self.dataset_config.get_event_names())
        excluded = self.event_names_exclude
        return [name for name in candidates if name not in excluded]

    @property
    def nevents(self):
        """Number of event names reported by the dataset configuration.

        Note that ``event_names`` / ``event_names_exclude`` are not applied.
        """
        all_names = self.dataset_config.get_event_names()
        return len(all_names)

    def get_dataset(self, event_name):
        """Return the dataset for *event_name* from the dataset config."""
        return self.dataset_config.get_dataset(event_name)

    def get_targets(self, event):
        """Collect the targets of all target groups for *event*.

        Group number ``i`` is given the path ``'target.<i>'``.
        """
        dataset = self.get_dataset(event.name)
        return [
            target
            for igroup, group in enumerate(self.target_groups)
            for target in group.get_targets(
                dataset, event, 'target.%i' % igroup)]

    def setup_modelling_environment(self, problem):
        """Attach the GF engine to *problem* and apply a synthetic test.

        When the dataset of the problem's base source defines a synthetic
        test, the problem's base source is replaced by the source for the
        synthetic test's model vector.
        """
        engine = self.engine_config.get_engine()
        problem.set_engine(engine)

        dataset = self.get_dataset(problem.base_source.name)
        synt = dataset.synthetic_test
        if not synt:
            return

        synt.set_problem(problem)
        problem.base_source = problem.get_source(synt.get_x())

    def get_problem(self, event):
        """Build the problem for *event* and prepare it for modelling."""
        targets = self.get_targets(event)
        problem = self.problem_config.get_problem(
            event, self.target_groups, targets)
        self.setup_modelling_environment(problem)
        return problem

    def get_elements(self, ypath):
        """Return, as a list, the elements selected by *ypath*."""
        return list(guts.iter_elements(self, ypath))

    def set_elements(self, ypath, value):
        """Set the elements selected by *ypath* to *value* (regularized)."""
        guts.set_elements(self, ypath, value, regularize=True)

    def clone(self):
        """Return a copy of this configuration made with ``guts.clone``."""
        return guts.clone(self)
def read_config(path):
    """Load a Grond :class:`Config` from the file at *path*.

    The base path of the loaded config is set to the directory of *path*
    (``'.'`` when *path* has no directory part).

    :param path: configuration file name
    :returns: the loaded :class:`Config`
    :raises GrondError: if the file cannot be read, or if its content is not
        a :class:`Config` object
    """
    try:
        config = guts.load(filename=path)
    except OSError as e:
        # Chain the original error so the underlying cause stays visible.
        raise GrondError(
            'Cannot read Grond configuration file: %s' % path) from e

    if not isinstance(config, Config):
        raise GrondError('Invalid Grond configuration in file "%s".' % path)

    config.set_basepath(op.dirname(path) or '.')
    return config
def write_config(config, path):
    """Write *config* to the file at *path*.

    Before dumping, the config's base path is temporarily changed to the
    directory of *path* (``'.'`` when it has none). The original base path
    is restored afterwards, also when dumping fails.

    :param config: :class:`Config` to write
    :param path: output file name
    :raises GrondError: if an ``OSError`` occurs while writing
    """
    try:
        basepath = config.get_basepath()
        dirname = op.dirname(path) or '.'
        config.change_basepath(dirname)
        try:
            guts.dump(
                config,
                filename=path,
                header='Grond configuration file, version %s' % __version__)
        finally:
            # Never leave the caller's config with the temporary base path.
            config.change_basepath(basepath)

    except OSError as e:
        raise GrondError(
            'Cannot write Grond configuration file: %s' % path) from e
def diff_configs(path1, path2):
    """Write a unified diff of two configuration files to stdout.

    Both files are loaded with ``pyrocko.guts_agnostic`` and dumped again
    before comparison, so the diff is between the re-dumped texts, not the
    raw file contents. Output is colourised when stdout is a terminal.
    """
    import sys
    import difflib
    from pyrocko import guts_agnostic as aguts

    trees = [aguts.load(filename=p) for p in (path1, path2)]
    left, right = [aguts.dump(t).splitlines(keepends=True) for t in trees]

    diff_lines = list(difflib.unified_diff(left, right, 'left', 'right'))

    out = sys.stdout
    out.writelines(color_diff(diff_lines) if out.isatty() else diff_lines)
class YPathError(GrondError):
    """Raised for a malformed ypath component (see ``parse_yname``)."""
def parse_yname(yname):
    """Parse one ypath component into a dict.

    Accepted forms and results::

        name            -> {'name': 'name'}
        name[i]         -> {'name': 'name', 'index': i}
        name[start:end] -> {'name': 'name', 'slice': (start, end)}

    ``start``/``end`` may each be omitted (giving ``None``); all integers
    may be negative.

    :raises YPathError: if *yname* does not match any accepted form
    """
    ident = r'[a-zA-Z][a-zA-Z0-9_]*'
    rint = r'-?[0-9]+'
    pattern = r'^(%s)(\[((%s)?(:)(%s)?|(%s))\])?$' % (
        ident, rint, rint, rint)

    match = re.match(pattern, yname)
    if match is None:
        raise YPathError('Syntax error in component: "%s"' % yname)

    name, bracket, _inner, start, colon, end, index = match.groups()
    result = {'name': name}
    if not bracket:
        return result

    if not colon:
        result['index'] = int(index)
        return result

    result['slice'] = (
        int(start) if start else None,
        int(end) if end else None)
    return result
293def _decend(obj, ynames):
294 if ynames:
295 for sobj in iter_get_obj(obj, ynames):
296 yield sobj
297 else:
298 yield obj
def iter_get_obj(obj, ynames):
    """Resolve split ypath components against *obj*, yielding every match.

    The first component (``name``, ``name[i]`` or ``name[start:end]``) must
    name a guts property of *obj*. The remaining components are applied to
    each selected value in turn. The caller's *ynames* is left unmodified.

    :param obj: guts object to start from
    :param ynames: non-empty sequence of ypath components
    :raises AttributeError: if a component is not a property of the object
    :raises YPathError: if a component is syntactically invalid
    :raises IndexError: if *ynames* is empty or an index is out of range
    """
    # Slice instead of ``pop(0)``: the remainder is shared by every element
    # of a slice. Consuming it in place made all but the first element stop
    # descending, and also mutated the caller's list.
    yname = ynames[0]
    rest = ynames[1:]

    d = parse_yname(yname)
    if d['name'] not in obj.T.propnames:
        raise AttributeError(d['name'])

    obj = getattr(obj, d['name'])

    if 'index' in d:
        yield from _decend(obj[d['index']], rest)

    elif 'slice' in d:
        for i in range(*slice(*d['slice']).indices(len(obj))):
            yield from _decend(obj[i], rest)

    else:
        yield from _decend(obj, rest)
# Public names exported by ``from grond.config import *``.
__all__ = [
    'EngineConfig',
    'Config',
    'read_config',
    'write_config',
    'diff_configs',
]