Coverage for /usr/local/lib/python3.11/dist-packages/grond/environment.py: 75%

216 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 16:25 +0000

1import time 

2import logging 

3import shutil 

4import os 

5 

6from grond.config import read_config, write_config 

7from grond import meta, run_info 

8from grond.problems.base import load_optimiser_info, load_problem_info, \ 

9 ModelHistory 

10 

# Short alias for os.path, used by the path handling below.
op = os.path
# Logger for this module, with its name spelled out explicitly.
logger = logging.getLogger(name='grond.environment')

14 

15 

class GrondEnvironmentError(meta.GrondError):
    '''Common base class of the errors raised by the environment module.'''

18 

19 

class EventSelectionFailed(GrondEnvironmentError):
    '''Raised when the requested event names cannot be turned into a
    valid event selection.'''

22 

23 

class NoCurrentEventAvailable(GrondEnvironmentError):
    '''Raised when the current event of an environment cannot be
    determined.'''

    def __init__(self, message='no current event available'):
        super().__init__(message)

27 

28 

class NoEventSelectionAvailable(GrondEnvironmentError):
    '''Raised when an environment has no event selection.'''

    def __init__(self, message='no event selection available'):
        super().__init__(message)

32 

33 

class NoRundirAvailable(GrondEnvironmentError):
    '''Raised when an environment is not attached to a run directory.'''

    def __init__(self, message='no rundir available'):
        super().__init__(message)

37 

38 

class NoPlotCollectionManagerAvailable(GrondEnvironmentError):
    '''Raised when no plot collection manager has been set.'''

    def __init__(self, message='no plot collection manager available'):
        super().__init__(message)

42 

43 

class Environment(object):
    '''
    Access point to a grond setup or run: configuration, event selection,
    dataset, problem, optimiser and model history.

    An environment is created from one of:

    * a run directory (``args[0]`` is a directory): configuration, problem
      and optimiser are read from files inside it;
    * a configuration file path (``args[0]``), optionally followed by the
      event names to select (``args[1:]``);
    * an already loaded ``config`` object, optionally with ``event_names``.

    Derived objects (dataset, optimiser, problem, histories) are loaded
    lazily, cached, and dropped again by :py:meth:`reset`.
    '''

    def __init__(self, args=None, config=None, event_names=None):
        '''
        :param args: run directory path, or config file path followed by
            event names; a single string is treated as a one-element list
        :param config: already loaded configuration object; used when
            ``event_names`` are given or when no ``args`` are given
        :param event_names: event name or list of event names to select
            together with ``config``
        :raises GrondEnvironmentError: if neither ``args`` nor ``config`` is
            given
        :raises EventSelectionFailed: if the event selection cannot be
            resolved
        '''

        self._current_event_name = None
        self._selected_event_names = None
        self._config = None
        self._plot_collection_manager = None

        if isinstance(args, str):
            args = [args]

        if not args and not config:
            raise GrondEnvironmentError('missing arguments')

        # This branch is also taken when only ``config`` is given. There is
        # no ``args[0]`` to fall back on in that case, so the selection is
        # resolved from the config's events, as for an empty selection.
        if config and (event_names or not args):
            self._config_path = None
            self._rundir_path = None
            self._config = config

            if isinstance(event_names, str):
                event_names = [event_names]
            self.set_selected_event_names(event_names or [])

        elif op.isdir(args[0]):
            self._rundir_path = args[0]
            self._config_path = op.join(self._rundir_path, 'config.yaml')

        else:
            self._rundir_path = None
            self._config_path = args[0]
            self.set_selected_event_names(args[1:])

        self.reset()

    @classmethod
    def discover(cls, rundir):
        '''
        Wait for a running optimisation to produce a consistent run directory
        and return an environment for it.

        Polls every 0.25 s while the ``.running`` marker file exists in
        ``rundir``.

        :raises GrondEnvironmentError: if the marker file is absent, or
            disappears before the run directory became consistent
        '''
        running_fn = op.join(rundir, '.running')
        while op.exists(running_fn):
            try:
                cls.verify_rundir(rundir)
                return cls([rundir])
            except GrondEnvironmentError:
                time.sleep(.25)
        raise GrondEnvironmentError('could not discover rundir')

    @staticmethod
    def verify_rundir(rundir_path):
        '''
        Check that ``rundir_path`` contains the files of a run directory.

        :raises GrondEnvironmentError: if ``config.yaml``, ``problem.yaml``,
            ``optimiser.yaml`` or ``misfits`` is missing
        '''
        files = [
            'config.yaml',
            'problem.yaml',
            'optimiser.yaml',
            'misfits'
        ]
        for fn in files:
            if not op.exists(op.join(rundir_path, fn)):
                raise GrondEnvironmentError('inconsistent rundir')

    def copy(self, destination, force=False):
        '''
        Copy the run directory files to ``destination`` and return a new
        environment for the copy.

        Only regular files among ``config.yaml``, ``problem.yaml``,
        ``optimiser.yaml``, ``misfits``, ``models``, ``choices`` and
        ``chains`` are copied. Missing ones are skipped with a debug log
        message.

        :param destination: target directory, created if necessary
        :param force: allow copying into an already existing ``destination``
        :raises OSError: if ``destination`` exists and ``force`` is false
        :raises NoRundirAvailable: if this environment has no run directory
        '''
        files = [
            'config.yaml',
            'problem.yaml',
            'optimiser.yaml',
            'misfits',
            'models',
            'choices',
            'chains'
        ]

        if op.exists(destination) and not force:
            raise OSError('Directory %s already exists' % destination)

        # Resolve the source before creating anything, so an environment
        # without a run directory fails cleanly with NoRundirAvailable
        # instead of a TypeError from os.path.join(None, ...).
        rundir_path = self.get_rundir_path()

        destination = op.abspath(destination)
        os.makedirs(destination, exist_ok=True)

        for fn in files:
            src = op.join(rundir_path, fn)
            dest = op.join(destination, fn)

            if not op.isfile(src):
                logger.debug('Cannot find file %s', src)
                continue
            logger.debug('Copying %s to %s', src, dest)

            shutil.copy(src, dest)

        return self.__class__(destination)

    def reset(self):
        '''Drop the cached histories, dataset, optimiser and problem.'''
        self._histories = {}
        self._dataset = None
        self._optimiser = None
        self._problem = None

    def get_config(self):
        '''Return the configuration, reading it from the config path on
        first access.'''
        if self._config is None:
            self._config = read_config(self._config_path)

        return self._config

    def write_config(self):
        '''Write the configuration to :py:meth:`get_config_path`.'''
        write_config(self.get_config(), self.get_config_path())

    def get_available_event_names(self):
        '''Return the names of all events known to the configuration.'''
        return self.get_config().get_event_names()

    def set_current_event_name(self, event_name):
        '''Set the current event and drop objects cached for the previous
        one.'''
        self._current_event_name = event_name
        self.reset()

    def get_current_event_name(self):
        '''
        Return the name of the current event.

        If none has been set, it is taken from the problem's base source when
        a run directory is available. Otherwise it is the single selected
        event.

        :raises NoCurrentEventAvailable: if there is no run directory and the
            selection is missing or holds more than one event
        '''
        if self._current_event_name is None:
            try:
                self.get_rundir_path()
                self._current_event_name = self.get_problem().base_source.name
            except NoRundirAvailable:
                try:
                    event_names = self.get_selected_event_names()
                    if len(event_names) == 1:
                        self._current_event_name = event_names[0]
                    else:
                        raise NoCurrentEventAvailable()

                except NoEventSelectionAvailable:
                    raise NoCurrentEventAvailable()

        return self._current_event_name

    def set_selected_event_names(self, args):
        '''
        Select events by name.

        :param args: sequence of event names. If empty, the only available
            event is selected. ``['all']`` selects every available event.
        :raises EventSelectionFailed: if ``args`` is empty and there is not
            exactly one available event, or if a name is unknown (the
            selection is then cleared)
        '''
        event_names = self.get_available_event_names()
        if len(args) == 0:
            if len(event_names) == 1:
                self._selected_event_names = event_names
            else:
                if not event_names:
                    raise EventSelectionFailed(
                        'No event file found, check your config!')
                raise EventSelectionFailed(
                    'Ambiguous event selection. Select from available events:'
                    '\n    %s\n  or \'all\' to use all available events'
                    % '\n    '.join(event_names))

        elif len(args) == 1 and args[0] == 'all':
            self._selected_event_names = event_names

        else:
            self._selected_event_names = []
            for event_name in args:
                if event_name not in event_names:
                    self._selected_event_names = None
                    raise EventSelectionFailed(
                        'No such event: %s' % event_name)

                self._selected_event_names.append(event_name)

    @property
    def nevents_selected(self):
        '''Number of selected events (raises NoEventSelectionAvailable if
        there is no selection).'''
        return len(self.get_selected_event_names())

    def get_selected_event_names(self):
        '''
        Return the list of selected event names.

        :raises NoEventSelectionAvailable: if no selection has been made
        '''
        if self._selected_event_names is None:
            raise NoEventSelectionAvailable()

        return self._selected_event_names

    def get_dataset(self):
        '''Return the dataset of the current event, creating it from the
        configuration on first access.'''
        if self._dataset is None:
            event_name = self.get_current_event_name()
            self._dataset = self.get_config().get_dataset(event_name)

        return self._dataset

    def set_rundir_path(self, path):
        '''Attach the environment to the run directory ``path``.'''
        self._rundir_path = path

    def get_rundir_path(self):
        '''
        Return the run directory path.

        :raises NoRundirAvailable: if no run directory is set
        '''
        if self._rundir_path is None:
            raise NoRundirAvailable()

        return self._rundir_path

    def have_rundir(self):
        '''Return whether a run directory is set.'''
        return self._rundir_path is not None

    def get_run_info_path(self):
        '''Return the path of ``run_info.yaml`` inside the run directory.'''
        return op.join(self.get_rundir_path(), 'run_info.yaml')

    def get_run_info(self):
        '''Return the stored run info, or a fresh ``RunInfo`` if the run
        info file does not exist yet.'''
        run_info_path = self.get_run_info_path()
        if not op.exists(run_info_path):
            info = run_info.RunInfo()
            return info
        else:
            return run_info.read_info(run_info_path)

    def set_run_info(self, info):
        '''Write ``info`` to the run info file.'''
        run_info_path = self.get_run_info_path()
        run_info.write_info(info, run_info_path)

    def get_optimiser(self):
        '''Return the optimiser, loaded from the run directory if there is
        one, otherwise built from the configuration. The result is cached.'''
        if self._optimiser is None:
            try:
                self._optimiser = load_optimiser_info(self.get_rundir_path())
            except NoRundirAvailable:
                self._optimiser = \
                    self.get_config().optimiser_config.get_optimiser()

        return self._optimiser

    def get_problem(self):
        '''Return the problem, loaded from the run directory if there is
        one, otherwise built from the configuration for the current event's
        dataset. The result is cached.'''
        if self._problem is None:
            try:
                self._problem = load_problem_info(self.get_rundir_path())
            except NoRundirAvailable:
                self._problem = \
                    self.get_config().get_problem(
                        self.get_dataset().get_event())

        return self._problem

    def get_history(self, subset=None):
        '''
        Return the model history of the run directory, cached per
        ``subset``.

        The history path is ``meta.xjoin(rundir, subset)``. NOTE(review):
        this presumably leaves the rundir unchanged for ``subset=None``;
        verify in ``meta.xjoin``. Bootstrap misfits are ensured once, when
        the history is first created.
        '''
        if subset not in self._histories:
            history = \
                ModelHistory(
                    self.get_problem(),
                    nchains=self.get_optimiser().nchains,
                    path=meta.xjoin(self.get_rundir_path(), subset))

            history.ensure_bootstrap_misfits(self.get_optimiser())
            self._histories[subset] = history

        return self._histories[subset]

    def set_plot_collection_manager(self, pcm):
        '''Set the plot collection manager.'''
        self._plot_collection_manager = pcm

    def get_plot_collection_manager(self):
        '''
        Return the plot collection manager.

        :raises NoPlotCollectionManagerAvailable: if none has been set
        '''
        if self._plot_collection_manager is None:
            raise NoPlotCollectionManagerAvailable()

        return self._plot_collection_manager

    def setup_modelling(self):
        '''Must be called before any modelling can be done.

        Sets up the modelling environment for the problem and attaches the
        current dataset to every target of the problem.'''
        logger.debug('Setting up modelling...')
        self.get_config().setup_modelling_environment(self.get_problem())
        ds = self.get_dataset()
        for target in self.get_problem().targets:
            target.set_dataset(ds)

    def get_plot_classes(self):
        '''Discover all plot classes relevant for the setup.

        Collects the plot classes of the problem, the optimiser and all
        targets, de-duplicated and sorted by ``name``. Parts that cannot be
        resolved in this environment (e.g. no current event) are skipped.'''

        plots = set()
        try:
            plots.update(self.get_problem().get_plot_classes())
        except GrondEnvironmentError:
            pass

        try:
            plots.update(self.get_optimiser().get_plot_classes())
        except GrondEnvironmentError:
            pass

        try:
            for target in self.get_problem().targets:
                plots.update(target.get_plot_classes())
        except GrondEnvironmentError:
            pass

        return sorted(plots, key=lambda plot: plot.name)

    def get_plots_path(self):
        '''Return the plot output directory: ``plots`` inside the run
        directory, or the relative path ``plots`` if there is none.'''
        try:
            return op.join(self.get_rundir_path(), 'plots')
        except NoRundirAvailable:
            return 'plots'

    def get_config_path(self):
        '''Return the config file path (None for environments built from a
        config object).'''
        return self._config_path

    def is_running(self):
        '''
        Return whether the run directory contains the ``.running`` marker
        file.

        :raises NoRundirAvailable: if no run directory is set
        '''
        return op.exists(op.join(self.get_rundir_path(), '.running'))

329 

330 

# Public API of this module: the Environment class and every exception
# class it can raise (NoEventSelectionAvailable was previously missing).
__all__ = [
    'GrondEnvironmentError',
    'EventSelectionFailed',
    'NoCurrentEventAvailable',
    'NoEventSelectionAvailable',
    'NoRundirAvailable',
    'NoPlotCollectionManagerAvailable',
    'Environment',
]