Coverage for /usr/local/lib/python3.11/dist-packages/grond/environment.py: 75%
216 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
1import time
2import logging
3import shutil
4import os
6from grond.config import read_config, write_config
7from grond import meta, run_info
8from grond.problems.base import load_optimiser_info, load_problem_info, \
9 ModelHistory
# Short alias for the os.path module, used for all path handling below.
import os.path as op

# Module logger (explicit name rather than __name__).
logger = logging.getLogger(name='grond.environment')
class GrondEnvironmentError(meta.GrondError):
    '''Base class of all errors raised by the Grond environment module.'''
class EventSelectionFailed(GrondEnvironmentError):
    '''Raised when requested event names cannot be resolved against the
    events available in the configuration.'''
class NoCurrentEventAvailable(GrondEnvironmentError):
    '''Raised when no single current event can be determined.'''

    def __init__(self, message='no current event available'):
        super().__init__(message)
class NoEventSelectionAvailable(GrondEnvironmentError):
    '''Raised when the event selection is queried but was never set.'''

    def __init__(self, message='no event selection available'):
        super().__init__(message)
class NoRundirAvailable(GrondEnvironmentError):
    '''Raised when a run directory is required but none is set.'''

    def __init__(self, message='no rundir available'):
        super().__init__(message)
class NoPlotCollectionManagerAvailable(GrondEnvironmentError):
    '''Raised when a plot collection manager is requested but none was
    set.'''

    def __init__(self, message='no plot collection manager available'):
        super().__init__(message)
class Environment(object):
    '''
    Access point to a Grond configuration, its event selection and,
    optionally, a run directory.

    An environment is built from a run directory, from a config file path
    (optionally followed by event names), or from an already loaded config
    object. Config, dataset, problem, optimiser and model histories are
    loaded lazily on first access and cached. :py:meth:`reset` drops the
    cached dataset, problem, optimiser and histories.
    '''

    def __init__(self, args=None, config=None, event_names=None):
        '''
        :param args: a run directory path, or a config file path optionally
            followed by event names (``'all'`` selects every event). A single
            string is treated as a one-element list.
        :param config: an already loaded config object. It is used when
            ``event_names`` are given or when ``args`` is empty. Otherwise
            ``args`` takes precedence and ``config`` is ignored.
        :param event_names: event name or list of event names to select
            when ``config`` is used.
        :raises GrondEnvironmentError: if neither ``args`` nor ``config`` is
            given.
        :raises EventSelectionFailed: if the event selection is invalid.
        '''
        self._current_event_name = None
        self._selected_event_names = None
        self._config = None
        self._plot_collection_manager = None

        if isinstance(args, str):
            args = [args]

        if not args and not config:
            raise GrondEnvironmentError('missing arguments')

        if config and (event_names or not args):
            # Config object given directly: no config file, no rundir.
            # This branch also covers config-only construction; it
            # previously crashed on ``args[0]`` with args=None.
            self._config_path = None
            self._rundir_path = None
            self._config = config

            if isinstance(event_names, str):
                event_names = [event_names]
            # An empty selection auto-selects the event if there is only one.
            self.set_selected_event_names(event_names or [])

        elif op.isdir(args[0]):
            self._rundir_path = args[0]
            self._config_path = op.join(self._rundir_path, 'config.yaml')

        else:
            self._rundir_path = None
            self._config_path = args[0]
            self.set_selected_event_names(args[1:])

        self.reset()

    @classmethod
    def discover(cls, rundir):
        '''
        Wait for a run directory marked with a ``.running`` file to become
        consistent, then return an environment for it.

        While the ``.running`` marker exists, the rundir is re-checked every
        0.25 s.

        :raises GrondEnvironmentError: if the marker is absent or vanishes
            before the rundir can be opened.
        '''
        running_fn = op.join(rundir, '.running')
        while op.exists(running_fn):
            try:
                cls.verify_rundir(rundir)
                return cls([rundir])
            except GrondEnvironmentError:
                time.sleep(.25)

        raise GrondEnvironmentError('could not discover rundir')

    @staticmethod
    def verify_rundir(rundir_path):
        '''
        Check that the essential files of a run directory exist.

        :raises GrondEnvironmentError: if any of them is missing.
        '''
        files = [
            'config.yaml',
            'problem.yaml',
            'optimiser.yaml',
            'misfits'
        ]
        for fn in files:
            if not op.exists(op.join(rundir_path, fn)):
                raise GrondEnvironmentError('inconsistent rundir')

    def copy(self, destination, force=False):
        '''
        Copy the run directory files to ``destination`` and return a new
        environment for the copy.

        Only regular files are copied. Missing files are skipped with a
        debug log message.

        :param destination: target directory; it is created if needed.
        :param force: allow copying into an already existing path.
        :raises NoRundirAvailable: if this environment has no run directory.
        :raises OSError: if ``destination`` exists and ``force`` is false.
        '''
        files = [
            'config.yaml',
            'problem.yaml',
            'optimiser.yaml',
            'misfits',
            'models',
            'choices',
            'chains'
        ]

        # Resolve the source first, so a missing rundir is reported before
        # any directory is created at the destination.
        rundir_path = self.get_rundir_path()

        if op.exists(destination) and not force:
            raise OSError('Directory %s already exists' % destination)

        destination = op.abspath(destination)
        os.makedirs(destination, exist_ok=True)

        for fn in files:
            src = op.join(rundir_path, fn)
            dest = op.join(destination, fn)

            if not op.isfile(src):
                logger.debug('Cannot find file %s', src)
                continue
            logger.debug('Copying %s to %s', src, dest)

            shutil.copy(src, dest)

        return self.__class__(destination)

    def reset(self):
        '''Drop the cached histories, dataset, optimiser and problem.'''
        self._histories = {}
        self._dataset = None
        self._optimiser = None
        self._problem = None

    def get_config(self):
        '''Return the config, reading it from the config path on first
        access.'''
        if self._config is None:
            self._config = read_config(self._config_path)

        return self._config

    def write_config(self):
        '''Write the current config to the config path.'''
        write_config(self.get_config(), self.get_config_path())

    def get_available_event_names(self):
        '''Return the event names known to the config.'''
        return self.get_config().get_event_names()

    def set_current_event_name(self, event_name):
        '''Set the current event and drop event-dependent cached data.'''
        self._current_event_name = event_name
        self.reset()

    def get_current_event_name(self):
        '''
        Return the current event name.

        If no event name was set explicitly, it is taken from the problem's
        base source when a rundir is available. Otherwise it is taken from
        the selection, if that selection holds exactly one event.

        :raises NoCurrentEventAvailable: if no unique event can be
            determined.
        '''
        if self._current_event_name is None:
            try:
                self.get_rundir_path()
                self._current_event_name = self.get_problem().base_source.name
            except NoRundirAvailable:
                try:
                    event_names = self.get_selected_event_names()
                    if len(event_names) == 1:
                        self._current_event_name = event_names[0]
                    else:
                        raise NoCurrentEventAvailable()

                except NoEventSelectionAvailable:
                    raise NoCurrentEventAvailable()

        return self._current_event_name

    def set_selected_event_names(self, args):
        '''
        Select events by name.

        An empty ``args`` selects the only available event (if there is
        exactly one). ``['all']`` selects every available event.

        :raises EventSelectionFailed: on an ambiguous or empty selection, or
            on an unknown event name (the selection is then cleared).
        '''
        event_names = self.get_available_event_names()
        if len(args) == 0:
            if len(event_names) == 1:
                self._selected_event_names = event_names
            else:
                if not event_names:
                    raise EventSelectionFailed(
                        'No event file found, check your config!')
                raise EventSelectionFailed(
                    'Ambiguous event selection. Select from available events:'
                    '\n    %s\n  or \'all\' to use all available events'
                    % '\n    '.join(event_names))

        elif len(args) == 1 and args[0] == 'all':
            self._selected_event_names = event_names

        else:
            self._selected_event_names = []
            for event_name in args:
                if event_name not in event_names:
                    self._selected_event_names = None
                    raise EventSelectionFailed(
                        'No such event: %s' % event_name)

                self._selected_event_names.append(event_name)

    @property
    def nevents_selected(self):
        '''Number of selected events.'''
        return len(self.get_selected_event_names())

    def get_selected_event_names(self):
        '''
        Return the list of selected event names.

        :raises NoEventSelectionAvailable: if no selection has been made.
        '''
        if self._selected_event_names is None:
            raise NoEventSelectionAvailable()

        return self._selected_event_names

    def get_dataset(self):
        '''Return the dataset for the current event (cached).'''
        if self._dataset is None:
            event_name = self.get_current_event_name()
            self._dataset = self.get_config().get_dataset(event_name)

        return self._dataset

    def set_rundir_path(self, path):
        self._rundir_path = path

    def get_rundir_path(self):
        '''
        Return the run directory path.

        :raises NoRundirAvailable: if no run directory is set.
        '''
        if self._rundir_path is None:
            raise NoRundirAvailable()

        return self._rundir_path

    def have_rundir(self):
        return self._rundir_path is not None

    def get_run_info_path(self):
        return op.join(self.get_rundir_path(), 'run_info.yaml')

    def get_run_info(self):
        '''Read the run info file; return a fresh ``RunInfo`` if the file
        does not exist.'''
        run_info_path = self.get_run_info_path()
        if not op.exists(run_info_path):
            return run_info.RunInfo()

        return run_info.read_info(run_info_path)

    def set_run_info(self, info):
        run_info_path = self.get_run_info_path()
        run_info.write_info(info, run_info_path)

    def get_optimiser(self):
        '''Return the optimiser stored in the rundir, or else the one built
        from the config (cached).'''
        if self._optimiser is None:
            try:
                self._optimiser = load_optimiser_info(self.get_rundir_path())
            except NoRundirAvailable:
                self._optimiser = \
                    self.get_config().optimiser_config.get_optimiser()

        return self._optimiser

    def get_problem(self):
        '''Return the problem stored in the rundir, or else the one built
        from the config for the current event (cached).'''
        if self._problem is None:
            try:
                self._problem = load_problem_info(self.get_rundir_path())
            except NoRundirAvailable:
                self._problem = \
                    self.get_config().get_problem(
                        self.get_dataset().get_event())

        return self._problem

    def get_history(self, subset=None):
        '''
        Return the model history for ``subset`` (cached per subset).

        The path is built with ``meta.xjoin(rundir, subset)``. Bootstrap
        misfits are ensured with the optimiser before caching.
        '''
        if subset not in self._histories:
            history = \
                ModelHistory(
                    self.get_problem(),
                    nchains=self.get_optimiser().nchains,
                    path=meta.xjoin(self.get_rundir_path(), subset))

            history.ensure_bootstrap_misfits(self.get_optimiser())
            self._histories[subset] = history

        return self._histories[subset]

    def set_plot_collection_manager(self, pcm):
        self._plot_collection_manager = pcm

    def get_plot_collection_manager(self):
        '''
        :raises NoPlotCollectionManagerAvailable: if none was set.
        '''
        if self._plot_collection_manager is None:
            raise NoPlotCollectionManagerAvailable()

        return self._plot_collection_manager

    def setup_modelling(self):
        '''Must be called before any modelling can be done.'''
        logger.debug('Setting up modelling...')
        self.get_config().setup_modelling_environment(self.get_problem())
        ds = self.get_dataset()
        for target in self.get_problem().targets:
            target.set_dataset(ds)

    def get_plot_classes(self):
        '''
        Discover all plot classes relevant for the setup.

        Each source (problem, optimiser, targets) is best-effort. A source
        that raises ``GrondEnvironmentError`` (e.g. missing event or rundir)
        contributes nothing. The result is sorted by plot name.
        '''
        plots = set()
        try:
            plots.update(self.get_problem().get_plot_classes())
        except GrondEnvironmentError:
            pass

        try:
            plots.update(self.get_optimiser().get_plot_classes())
        except GrondEnvironmentError:
            pass

        try:
            for target in self.get_problem().targets:
                plots.update(target.get_plot_classes())
        except GrondEnvironmentError:
            pass

        return sorted(plots, key=lambda plot: plot.name)

    def get_plots_path(self):
        '''Return ``<rundir>/plots``, or ``'plots'`` when there is no
        rundir.'''
        try:
            return op.join(self.get_rundir_path(), 'plots')
        except NoRundirAvailable:
            return 'plots'

    def get_config_path(self):
        return self._config_path

    def is_running(self):
        '''
        Return whether the run directory contains a ``.running`` marker.

        :raises NoRundirAvailable: if no run directory is set.
        '''
        # Previously called op.exists(self.get_rundir_path, '.running'),
        # which passed the bound method plus a stray argument (TypeError).
        return op.exists(op.join(self.get_rundir_path(), '.running'))
# Public API. NoEventSelectionAvailable is raised by the public
# Environment.get_selected_event_names, so it is exported as well; that
# lets star-importing callers catch it.
__all__ = [
    'GrondEnvironmentError',
    'EventSelectionFailed',
    'NoCurrentEventAvailable',
    'NoEventSelectionAvailable',
    'NoRundirAvailable',
    'NoPlotCollectionManagerAvailable',
    'Environment',
]