Coverage for /usr/local/lib/python3.11/dist-packages/grond/config.py: 65%

186 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-11-01 12:39 +0000

1import logging 

2import re 

3import os.path as op 

4from pyrocko import guts, gf 

5from pyrocko.guts import Bool, List, String, Int, StringChoice 

6 

7from .meta import Path, HasPaths, GrondError 

8from .dataset import DatasetConfig 

9from .analysers.base import AnalyserConfig 

10from .analysers.target_balancing import TargetBalancingAnalyserConfig 

11from .problems.base import ProblemConfig 

12from .optimisers.base import OptimiserConfig 

13from .targets.base import TargetGroup 

14from .version import __version__ 

15 

# Module logger; the name is spelled out explicitly rather than via __name__.
logger = logging.getLogger('grond.config')

# Module-level prefix read by pyrocko.guts for the objects defined here.
guts_prefix = 'grond'

19 

20 

21def color_diff(diff): 

22 green = '\x1b[32m' 

23 red = '\x1b[31m' 

24 blue = '\x1b[34m' 

25 dim = '\x1b[2m' 

26 reset = '\x1b[0m' 

27 

28 for line in diff: 

29 if line.startswith('+'): 

30 yield green + line + reset 

31 elif line.startswith('-'): 

32 yield red + line + reset 

33 elif line.startswith('^'): 

34 yield blue + line + reset 

35 elif line.startswith('@'): 

36 yield dim + line + reset 

37 else: 

38 yield line 

39 

40 

_log_level_choices = ['critical', 'error', 'warning', 'info', 'debug']


class GlobalConfig(HasPaths):
    """Process-wide Grond settings: log level, status output, parallelism."""

    loglevel = StringChoice.T(
        choices=_log_level_choices,
        default='info',
        help='Log level. Choices: %s. Default: %s' % (
            ', '.join(_log_level_choices),
            'info'))

    status = StringChoice.T(
        choices=['state', 'quiet'],
        default='state',
        # Fixed: the help text previously ended with a stray ')'.
        help='Status output selection. Choices: state, quiet. Default: state')

    nthreads = Int.T(
        default=1,
        help='Number of threads to utilize. Default: 1')

    nparallel = Int.T(
        default=1,
        help='Number of parallel processes to utilize. Default: 1')

    def override_with_cli_arguments(self, args):
        """Override settings with command-line values that were given.

        For each of ``loglevel``, ``status``, ``nthreads`` and ``nparallel``,
        the attribute of *args* replaces the current value unless it is
        ``None``. An attribute missing from *args* is treated like ``None``,
        so parsers that do not define every option are accepted. Whenever
        ``nparallel`` ends up different from 1, ``status`` is forced to
        ``'quiet'``.

        :param args: namespace-like object (e.g. from argparse)
        """
        for name in ('loglevel', 'status', 'nthreads', 'nparallel'):
            value = getattr(args, name, None)
            if value is not None:
                setattr(self, name, value)

        # NOTE(review): presumably forced quiet so that several processes do
        # not interleave status output -- confirm intent.
        if self.nparallel != 1:
            self.status = 'quiet'

81 

82 

# Single shared GlobalConfig instance for this module.
g_global_config = GlobalConfig()


def get_global_config():
    """Return the module's shared GlobalConfig instance (not a copy)."""
    shared = g_global_config
    return shared

88 

89 

class EngineConfig(HasPaths):
    """Settings used to build a ``pyrocko.gf.LocalEngine``.

    The engine is created on first use by :meth:`get_engine` and then
    cached on the instance.
    """

    gf_stores_from_pyrocko_config = Bool.T(
        default=True,
        help='Load the GF stores from ~/.pyrocko/config')
    gf_store_superdirs = List.T(
        Path.T(),
        help='List of path hosting collection of Green\'s function stores.')
    gf_store_dirs = List.T(
        Path.T(),
        help='List of Green\'s function stores')

    def __init__(self, *args, **kwargs):
        HasPaths.__init__(self, *args, **kwargs)
        # Lazily populated by get_engine().
        self._engine = None

    def get_engine(self):
        """Return the cached LocalEngine, creating it on the first call.

        Store directories are passed through ``self.expand_path``. The
        global ``nthreads`` setting is forwarded only when
        ``gf.LocalEngine`` has an ``nthreads`` attribute; otherwise a
        warning is logged and the argument is omitted.
        """
        if self._engine is not None:
            return self._engine

        if hasattr(gf.LocalEngine, 'nthreads'):
            engine_kwargs = {'nthreads': get_global_config().nthreads}
        else:
            engine_kwargs = {}
            logger.warning(
                'The installed version of Pyrocko does not allow setting '
                '`nthreads` through `pyrocko.gf.LocalEngine`, therefore '
                'using single thread. Please upgrade Pyrocko to the '
                'latest version.')

        expand = self.expand_path
        self._engine = gf.LocalEngine(
            use_config=self.gf_stores_from_pyrocko_config,
            store_superdirs=expand(self.gf_store_superdirs),
            store_dirs=expand(self.gf_store_dirs),
            **engine_kwargs)

        return self._engine

126 

127 

class Config(HasPaths):
    """Top-level Grond configuration.

    Bundles the dataset, target-group, problem, analyser, optimiser and
    engine configurations, plus optional event-name include/exclude lists.
    """

    rundir_template = Path.T(
        help='Rundir for the optimisation, supports templating'
             ' (eg. ${event_name})')
    dataset_config = DatasetConfig.T(
        help='Dataset configuration object')
    target_groups = List.T(
        TargetGroup.T(),
        help='List of ``TargetGroup``s')
    problem_config = ProblemConfig.T(
        help='Problem config')
    analyser_configs = List.T(
        AnalyserConfig.T(),
        default=[TargetBalancingAnalyserConfig.D()],
        help='List of problem analysers')
    optimiser_config = OptimiserConfig.T(
        help='The optimisers configuration')
    engine_config = EngineConfig.T(
        default=EngineConfig.D(),
        help=':class:`pyrocko.gf.LocalEngine` configuration')
    event_names = List.T(
        String.T(),
        help='Restrict application to given event names. If empty, all events '
             'found through the dataset configuration are considered.')
    event_names_exclude = List.T(
        String.T(),
        help='Event names to be excluded')

    def __init__(self, *args, **kwargs):
        HasPaths.__init__(self, *args, **kwargs)

    def get_event_names(self):
        """Return the event names to process.

        Uses ``event_names`` when it is non-empty, otherwise the dataset's
        event names. In both cases, names listed in ``event_names_exclude``
        are dropped.
        """
        candidates = self.event_names or self.dataset_config.get_event_names()
        excluded = self.event_names_exclude
        return [name for name in candidates if name not in excluded]

    @property
    def nevents(self):
        """Number of events returned by ``dataset_config.get_events()``.

        NOTE(review): this does not apply ``event_names`` /
        ``event_names_exclude``, unlike :meth:`get_event_names`; confirm
        that this is intended.
        """
        events = self.dataset_config.get_events()
        return len(events)

    def get_dataset(self, event_name):
        """Return the dataset for *event_name* from the dataset config."""
        return self.dataset_config.get_dataset(event_name)

    def get_targets(self, event):
        """Collect targets from all target groups for *event*.

        Group number ``i`` is given the path string ``'target.<i>'``.
        """
        dataset = self.get_dataset(event.name)
        return [
            target
            for igroup, group in enumerate(self.target_groups)
            for target in group.get_targets(dataset, event, 'target.%i' % igroup)]

    def setup_modelling_environment(self, problem):
        """Attach the engine to *problem* and apply any synthetic test.

        If the problem's dataset defines a ``synthetic_test``, that test is
        bound to the problem and the problem's ``base_source`` is replaced
        by the source built from the test's model vector.
        """
        problem.set_engine(self.engine_config.get_engine())
        synthetic = self.get_dataset(problem.base_source.name).synthetic_test
        if not synthetic:
            return

        synthetic.set_problem(problem)
        problem.base_source = problem.get_source(synthetic.get_x())

    def get_problem(self, event):
        """Build the problem for *event* and set up its modelling environment."""
        event_targets = self.get_targets(event)
        problem = self.problem_config.get_problem(
            event, self.target_groups, event_targets)
        self.setup_modelling_environment(problem)
        return problem

    def get_elements(self, ypath):
        """Return a list of the elements that *ypath* selects in this config."""
        return list(guts.iter_elements(self, ypath))

    def set_elements(self, ypath, value):
        """Set the elements selected by *ypath* to *value* (regularized)."""
        guts.set_elements(self, ypath, value, regularize=True)

    def clone(self):
        """Return a copy made with ``guts.clone``."""
        return guts.clone(self)

207 

208 

def read_config(path):
    """Load a Grond :class:`Config` from the file at *path*.

    The loaded config's basepath is set to the directory containing
    *path*, or ``'.'`` if *path* has no directory part.

    :raises GrondError: if reading fails with OSError, or if the file does
        not contain a ``Config`` object.
    """
    try:
        loaded = guts.load(filename=path)
    except OSError:
        raise GrondError(
            'Cannot read Grond configuration file: %s' % path)

    if isinstance(loaded, Config):
        loaded.set_basepath(op.dirname(path) or '.')
        return loaded

    raise GrondError('Invalid Grond configuration in file "%s".' % path)

221 

222 

def write_config(config, path):
    """Write *config* to *path* using ``guts.dump`` with a version header.

    For the duration of the dump, the config's basepath is switched to the
    directory of *path* (``'.'`` if it has none). The original basepath is
    restored afterwards, even when the dump fails; previously a failing
    dump left the caller's config with the changed basepath.

    :raises GrondError: if an OSError occurs while writing.
    """
    try:
        basepath = config.get_basepath()
        config.change_basepath(op.dirname(path) or '.')
        try:
            guts.dump(
                config,
                filename=path,
                header='Grond configuration file, version %s' % __version__)
        finally:
            # Always undo the temporary basepath change.
            config.change_basepath(basepath)

    except OSError as err:
        raise GrondError(
            'Cannot write Grond configuration file: %s' % path) from err

238 

239 

def diff_configs(path1, path2):
    """Print a unified diff of two configuration files to stdout.

    Both files are loaded and re-dumped with ``pyrocko.guts_agnostic``, so
    the comparison is done on normalized output. The diff headers are
    labelled ``left`` and ``right``. Output is colorized with
    :func:`color_diff` only when stdout is a TTY.
    """
    import sys
    import difflib
    from pyrocko import guts_agnostic as aguts

    trees = [aguts.load(filename=p) for p in (path1, path2)]
    left, right = [aguts.dump(tree).splitlines(True) for tree in trees]

    diff_lines = list(difflib.unified_diff(
        left, right, fromfile='left', tofile='right'))

    write = sys.stdout.writelines
    write(color_diff(diff_lines) if sys.stdout.isatty() else diff_lines)

259 

260 

class YPathError(GrondError):
    """Raised by parse_yname when a ypath component has invalid syntax."""

263 

264 

def parse_yname(yname):
    """Parse a single ypath component.

    Accepted forms are ``name``, ``name[i]`` and ``name[start:stop]``.
    The names must start with a letter, the integers may be negative, and
    either slice bound may be omitted.

    :returns: dict with key ``'name'``. If a subscript is present, it also
        has ``'index'`` (int) or ``'slice'`` (``(start, stop)``, with
        ``None`` for an omitted bound).
    :raises YPathError: if *yname* does not match the syntax above.
    """
    ident = r'[a-zA-Z][a-zA-Z0-9_]*'
    rint = r'-?[0-9]+'
    pattern = r'^(%s)(\[((%s)?(:)(%s)?|(%s))\])?$' % (ident, rint, rint, rint)

    match = re.match(pattern, yname)
    if match is None:
        raise YPathError('Syntax error in component: "%s"' % yname)

    # Groups: name, whole [..] subscript, subscript body, slice start,
    # colon, slice stop, plain index.
    name, subscript, _, start, colon, stop, index = match.groups()

    parsed = {'name': name}
    if not subscript:
        return parsed

    if colon:
        parsed['slice'] = (
            None if start is None else int(start),
            None if stop is None else int(stop))
    else:
        parsed['index'] = int(index)

    return parsed

291 

292 

293def _decend(obj, ynames): 

294 if ynames: 

295 for sobj in iter_get_obj(obj, ynames): 

296 yield sobj 

297 else: 

298 yield obj 

299 

300 

def iter_get_obj(obj, ynames):
    """Yield every object addressed by the ypath components *ynames*.

    The first component (parsed with :func:`parse_yname`) selects a
    property of *obj*, optionally subscripted by an index or a slice. The
    remaining components are then resolved on each selected child.

    The caller's *ynames* sequence is no longer mutated. Previously,
    ``ynames.pop(0)`` consumed one list shared by all elements of a slice,
    so only the first element was descended into. Later elements were
    yielded as-is instead of their sub-attributes.

    :raises AttributeError: if the component names a property not listed in
        ``obj.T.propnames``.
    :raises IndexError: if *ynames* is empty, or an index is out of range.
    """
    yname, rest = ynames[0], ynames[1:]
    d = parse_yname(yname)
    name = d['name']
    if name not in obj.T.propnames:
        raise AttributeError(name)

    value = getattr(obj, name)

    if 'index' in d:
        children = (value[d['index']],)
    elif 'slice' in d:
        indices = range(*slice(*d['slice']).indices(len(value)))
        children = (value[i] for i in indices)
    else:
        children = (value,)

    for child in children:
        # Each child gets its own view of the remaining components.
        yield from _decend(child, rest)

322 

323 

# Public names exported by ``from grond.config import *``.
__all__ = [
    'EngineConfig',
    'Config',
    'read_config',
    'write_config',
    'diff_configs',
]