Coverage for /usr/local/lib/python3.11/dist-packages/grond/optimisers/plot.py: 12%

282 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 16:25 +0000

1import logging 

2import numpy as num 

3from scipy import signal 

4 

5from matplotlib import cm, pyplot as plt 

6 

7from pyrocko.guts import Tuple, Float, Int, StringChoice, Bool 

8from pyrocko.plot import mpl_margins, mpl_graph_color, mpl_init 

9 

10from grond.plot.config import PlotConfig 

11from grond.plot.collection import PlotItem 

12 

# Namespace prefix used by pyrocko.guts for the config classes defined here.
guts_prefix = 'grond'
# Module-level logger (name kept as-is so existing logging config still
# applies).
logger = logging.getLogger('grond.problem.plot')

16 

17 

def fixlim(lo, hi):
    '''
    Return usable axis limits for the interval ``(lo, hi)``.

    A degenerate interval (``lo == hi``) is widened by one unit on each side
    so that matplotlib does not receive identical lower and upper limits; any
    other interval is returned unchanged.
    '''
    if lo != hi:
        return lo, hi

    return lo - 1.0, hi + 1.0

23 

24 

class SequencePlot(PlotConfig):
    '''
    Draws all parameter values evaluated during the optimisation

    The sequence of all the parameter values is either a function of the
    optimisation in progress or of the misfit from high to low. This plot can
    be used to check on convergence or see if model parameters push the given
    bounds. The color always shows the relative misfit. Relatively high misfits
    are in cold blue colors and relatively low misfits in red. The last panel
    gives the corresponding misfit values.
    '''

    name = 'sequence'
    size_cm = Tuple.T(2, Float.T(), default=(14., 6.))
    # Only models with a primary-chain misfit below this value are drawn; all
    # models are drawn when unset.
    misfit_cutoff = Float.T(optional=True)
    # NOTE(review): not referenced anywhere in this class.
    ibootstrap = Int.T(optional=True)
    # x axis of every panel: the model's iteration index ('iteration') or its
    # rank in descending misfit order ('misfit').
    sort_by = StringChoice.T(
        choices=['iteration', 'misfit'],
        default='iteration')
    # Panel grid per figure as (columns, rows).
    subplot_layout = Tuple.T(2, Int.T(), default=(1, 1))
    marker_size = Float.T(default=1.5)
    # Draw a horizontal line at the problem's reference model value.
    show_reference = Bool.T(default=True)

    def make(self, environ):
        # Named so it does not shadow the module-level ``matplotlib.cm``.
        collection_manager = environ.get_plot_collection_manager()
        history = environ.get_history()
        optimiser = environ.get_optimiser()

        mpl_init(fontsize=self.font_size)
        collection_manager.create_group_mpl(
            self,
            self.draw_figures(history, optimiser),
            title=u'Sequence Plots',
            section='optimiser',
            description=u'''
Sequence plots for all parameters of the optimisation.

The sequence of all the parameter values is either a function of the
optimisation in progress or of the misfit from high to low. This plot can be
used to check on convergence or to see if model parameters push the given
bounds.

The color always shows the relative misfit. Relatively high misfits are in
cold blue colors and relatively low misfits in red. The last panel gives the
corresponding misfit values.
''',
            feather_icon='fast-forward')

    def _new_figure(self, ifig, nfx, nfy, fontsize):
        '''
        Create an empty figure prepared for an ``nfx`` x ``nfy`` panel grid.

        :param ifig: number of figures already yielded; the new item is named
            ``fig_<ifig + 1>``.
        :returns: ``((item, fig), labelpos)``. ``item`` is a fresh
            :py:class:`PlotItem` whose ``parameters`` attribute is an empty
            list. ``labelpos`` is the label-positioning callable returned by
            ``mpl_margins``.
        '''
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(
            fig, nw=nfx, nh=nfy,
            left=7.,
            right=2.,
            top=1.,
            bottom=5.,
            wspace=7., hspace=2., units=fontsize)

        item = PlotItem(name='fig_%i' % (ifig + 1))
        item.attributes['parameters'] = []
        return (item, fig), labelpos

    def draw_figures(self, history, optimiser):
        '''
        Yield ``(PlotItem, figure)`` pairs, one per figure.

        There is one panel per model parameter, then one per dependant
        parameter, then a final misfit panel. Panels are distributed over as
        many figures as needed, ``subplot_layout`` panels per figure.
        ``optimiser`` is accepted for interface compatibility but not used.
        '''
        misfit_cutoff = self.misfit_cutoff
        sort_by = self.sort_by

        problem = history.problem
        models = history.models

        npar = problem.nparameters
        ndep = problem.ndependants
        fontsize = self.font_size
        nfx, nfy = self.subplot_layout

        imodels = num.arange(history.nmodels)
        bounds = problem.get_combined_bounds()

        xref = problem.get_reference_model()

        gms = history.get_primary_chain_misfits()
        # num.where evaluates both branches; silence log10 warnings for
        # non-positive misfits whose log result is discarded anyway.
        with num.errstate(divide='ignore', invalid='ignore'):
            gms_softclip = num.where(
                gms > 1.0, 0.2 * num.log10(gms) + 1.0, gms)

        # Descending misfit order: worst model first, best model last.
        isort = num.argsort(gms)[::-1]

        if sort_by == 'iteration':
            imodels = imodels[isort]
            xlabel = 'Iteration'
        elif sort_by == 'misfit':
            imodels = num.arange(imodels.size)
            xlabel = 'Tested model, sorted descending by global misfit value'
        else:
            raise ValueError('Unsupported sort_by value: %r' % sort_by)

        gms = gms[isort]
        gms_softclip = gms_softclip[isort]
        models = models[isort, :]

        # Colour index = rank in descending misfit order, so the colormap runs
        # from the highest misfit (low values) to the lowest misfit.
        iorder = num.arange(isort.size)

        if misfit_cutoff is None:
            ibest = num.ones(gms.size, dtype=bool)
        else:
            ibest = gms < misfit_cutoff

        def config_axes(axes, nfx, nfy, impl, iplot, nplots):
            # Tick/label placement depending on the panel's grid position.
            if (impl - 1) % nfx != nfx - 1:
                axes.get_yaxis().tick_left()

            if (impl - 1) >= (nfx * (nfy - 1)) or iplot >= nplots - nfx:
                axes.set_xlabel(xlabel)
                if not (impl - 1) // nfx == 0:
                    axes.get_xaxis().tick_bottom()
            elif (impl - 1) // nfx == 0:
                axes.get_xaxis().tick_top()
                axes.set_xticklabels([])
            else:
                axes.get_xaxis().set_visible(False)

        cmap = cm.jet
        msize = self.marker_size
        alpha = 0.5
        nplots = npar + ndep + 1
        npanels = nfx * nfy

        item_fig = None
        labelpos = None
        nfigs = 0
        for iplot in range(nplots):
            impl = iplot % npanels + 1

            # First panel of a figure: flush the previous one, start anew.
            if impl == 1:
                if item_fig:
                    yield item_fig
                    nfigs += 1

                item_fig, labelpos = self._new_figure(
                    nfigs, nfx, nfy, fontsize)

            item, fig = item_fig
            axes = fig.add_subplot(nfy, nfx, impl)
            labelpos(axes, 2.5, 2.0)

            if iplot < npar + ndep:
                # Model parameter or dependant parameter panel.
                if iplot < npar:
                    par = problem.parameters[iplot]
                    ys = models[ibest, iplot]
                else:
                    par = problem.dependants[iplot - npar]
                    ys = problem.make_dependant(models[ibest, :], par.name)

                item.attributes['parameters'].append(par.name)

                axes.set_ylabel(par.get_label())
                axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

                config_axes(axes, nfx, nfy, impl, iplot, nplots)

                axes.set_ylim(*fixlim(*par.scaled(bounds[iplot])))
                axes.set_xlim(0, history.nmodels)

                axes.scatter(
                    imodels[ibest], par.scaled(ys), s=msize,
                    c=iorder[ibest], edgecolors='none', cmap=cmap,
                    alpha=alpha, rasterized=True)

                if self.show_reference:
                    if iplot < npar:
                        yref = xref[iplot]
                    else:
                        yref = problem.make_dependant(xref, par.name)

                    axes.axhline(par.scaled(yref), color='black', alpha=0.3)

            else:
                # Final panel: soft-clipped misfit (linear below 1,
                # logarithmic above, 0.2 per decade).
                config_axes(axes, nfx, nfy, impl, iplot, nplots)

                axes.set_ylim(0., 1.5)
                axes.set_yticks([0., 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4])
                axes.set_yticklabels(
                    ['0.0', '0.2', '0.4', '0.6', '0.8', '1', '10', '100'])

                axes.scatter(
                    imodels[ibest], gms_softclip[ibest], c=iorder[ibest],
                    s=msize, edgecolors='none', cmap=cmap, alpha=alpha)

                axes.axhspan(1.0, 1.5, color=(0.8, 0.8, 0.8), alpha=0.2)
                axes.axhline(1.0, color=(0.5, 0.5, 0.5), zorder=2)

                axes.set_xlim(0, history.nmodels)
                axes.set_xlabel(xlabel)

                axes.set_ylabel('Misfit')

        # nplots >= 1, so at least one figure has been started.
        yield item_fig

271 

272 

class ContributionsPlot(PlotConfig):
    ''' Relative contribution of single targets to the global misfit
    '''

    name = 'contributions'
    size_cm = Tuple.T(2, Float.T(), default=(21., 14.9))

    def make(self, environ):
        # Named so it does not shadow the module-level ``matplotlib.cm``.
        collection_manager = environ.get_plot_collection_manager()
        history = environ.get_history()
        optimiser = environ.get_optimiser()
        dataset = environ.get_dataset()

        environ.setup_modelling()

        mpl_init(fontsize=self.font_size)
        collection_manager.create_group_mpl(
            self,
            self.draw_figures(dataset, history, optimiser),
            title=u'Target Contributions',
            section='solution',
            feather_icon='thermometer',
            description=u'''
Contributions of the targets to the total misfit.

The relative contribution that each single target has in the global misfit
result is plotted relative and unscaled as a function of global misfit
(descending).

The target contribution is shown in color-filled curves with the worst-fit
target on the bottom and the best-fit target on top. This plot can be used to
analyse the balance of targets in the optimisations. For ideal configurations,
the target contributions are of similar size. If the contribution of a single
target is much larger than those of all others, the weighting should be
modified.
''')

    def draw_figures(self, dataset, history, optimiser):
        '''
        Draw stacked per-target misfit contributions.

        The top panel shows relative contributions (smoothed). The bottom
        panel shows absolute, soft-clipped contributions with the global misfit
        curve. Models are ordered by descending global misfit.

        :returns: ``[]`` (after logging a warning) when the history has no
            problem or no models; otherwise ``[(PlotItem, figure)]``.
        '''
        fontsize = self.font_size

        problem = history.problem
        if not problem:
            logger.warning('Problem not set.')
            return []

        models = history.models

        if models.size == 0:
            logger.warning('Empty models vector.')
            return []

        # Create the figure only after the early returns so that they do not
        # leave an orphaned figure registered with pyplot.
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, nw=2, nh=2, w=7., h=5., wspace=2.,
                               hspace=5., units=fontsize)

        for target in problem.targets:
            target.set_dataset(dataset)

        nmodels = history.nmodels
        imodels = num.arange(nmodels)

        # Global misfits and the matching model indices, worst model first.
        gms = history.get_sorted_primary_misfits()[::-1]
        isort = history.get_sorted_misfits_idx(chain=0)[::-1]

        # Not in-place: ``gms`` may be a view on data held by ``history``.
        gms = gms ** problem.norm_exponent
        # Linear below 1, logarithmic above (0.1 per decade). num.where
        # evaluates both branches, so silence log10 warnings for values <= 0.
        with num.errstate(divide='ignore', invalid='ignore'):
            gms_softclip = num.where(
                gms > 1.0, 0.1 * num.log10(gms) + 1.0, gms)

        gcms = problem.combine_misfits(
            history.misfits,
            extra_correlated_weights=optimiser.get_correlated_weights(problem),
            get_contributions=True)

        gcms = gcms[isort, :]

        ncontributions = sum([1 if t.plot_misfits_cumulative else t.nmisfits
                              for t in problem.targets])
        cum_gcms = num.zeros((nmodels, ncontributions))

        # Squash matrix: targets flagged ``plot_misfits_cumulative`` (e.g. a
        # SatelliteTarget with many misfits) are summed into one column.
        plot_target_labels = []
        idx = 0
        idx_cum = 0
        for target in problem.targets:
            target_gcms = gcms[:, idx:idx+target.nmisfits]
            if target.plot_misfits_cumulative:
                cum_gcms[:, idx_cum] = target_gcms.sum(axis=1)
                plot_target_labels.append(target.string_id())
                idx_cum += 1
            else:
                cum_gcms[:, idx_cum:idx_cum+target.nmisfits] = target_gcms
                plot_target_labels.extend(target.misfits_string_ids())
                idx_cum += target.nmisfits
            idx += target.nmisfits

        # Stack order: largest contribution in the best (last) model first,
        # i.e. at the bottom of the stack.
        jsort = num.argsort(cum_gcms[-1, :])[::-1]

        axes = fig.add_subplot(2, 2, 1)
        labelpos(axes, 2.5, 2.0)

        axes.set_ylabel('Relative contribution (smoothed)')
        axes.set_ylim(0.0, 1.0)

        axes2 = fig.add_subplot(2, 2, 3, sharex=axes)
        labelpos(axes2, 2.5, 2.0)

        axes2.set_xlabel(
            'Tested model, sorted descending by global misfit value')

        axes2.set_ylabel('Square of misfit')

        axes2.set_ylim(0., 1.5)
        axes2.axhspan(1.0, 1.5, color=(0.8, 0.8, 0.8))
        axes2.set_yticks(
            [0., 0.2, 0.4, 0.6, 0.8, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5])
        axes2.set_yticklabels(
            ['0.0', '0.2', '0.4', '0.6', '0.8', '1', '10', '100', '1000',
             '10000', '100000'])

        axes2.set_xlim(imodels[0], imodels[-1])

        rel_ms_smooth_sum = num.zeros(nmodels)
        ms_smooth_sum = num.zeros(nmodels)

        # Normalised Hanning window for zero-phase smoothing; smoothing is
        # skipped for short histories. Only build the window when it is used:
        # very short windows (e.g. length 2) sum to zero and would divide by 0.
        nsmooth = min(100, nmodels // 5)
        if nsmooth > 5:
            b = num.hanning(nsmooth)
            b /= num.sum(b)
        else:
            b = None
        a = [1]

        # x coordinates of a closed polygon: back along the lower edge, then
        # forward along the upper edge.
        poly_x = num.concatenate([imodels[::-1], imodels])

        ii = 0
        for idx in jsort:
            target_label = plot_target_labels[idx]
            ms = cum_gcms[:, idx]

            ms = num.where(num.isfinite(ms), ms, 0.0)
            if num.all(ms == 0.0):
                continue

            # Zero or non-finite global misfits would give NaN here, and
            # filtfilt would spread a single NaN over the whole curve.
            with num.errstate(divide='ignore', invalid='ignore'):
                rel_ms = num.where(gms > 0.0, ms / gms, 0.0)

            if b is not None:
                rel_ms_smooth = signal.filtfilt(b, a, rel_ms)
            else:
                rel_ms_smooth = rel_ms

            ms_smooth = rel_ms_smooth * gms_softclip

            rel_poly_y = num.concatenate(
                [rel_ms_smooth_sum[::-1], rel_ms_smooth_sum + rel_ms_smooth])

            add_args = {}
            if ii < 20:
                # Legend entry shows the relative contribution in the best
                # (last) model.
                add_args['label'] = '%s (%.2g)' % (target_label, rel_ms[-1])

            color = mpl_graph_color(ii)

            axes.fill(
                poly_x, rel_poly_y,
                alpha=0.5,
                color=color,
                **add_args)

            poly_y = num.concatenate(
                [ms_smooth_sum[::-1], ms_smooth_sum + ms_smooth])

            axes2.fill(poly_x, poly_y, alpha=0.5, color=color)

            ms_smooth_sum += ms_smooth
            rel_ms_smooth_sum += rel_ms_smooth
            ii += 1

        axes.legend(
            title='Contributions (top twenty)',
            bbox_to_anchor=(1.05, 0.0, 1.0, 1.0),
            loc='upper left',
            ncol=1, borderaxespad=0., prop={'size': 9})

        axes2.plot(imodels, gms_softclip, color='black')
        axes2.axhline(1.0, color=(0.5, 0.5, 0.5))

        return [(PlotItem(name='main'), fig)]

461 

462 

class BootstrapPlot(PlotConfig):
    '''
    Sorted misfit (descending) of single bootstrap chains

    For each bootstrap configuration, all models are sorted according to their
    misfit value (red lines) and their global misfit value (black line). (They
    are sorted individually for each line). The best model of every bootstrap
    configuration (right end model of red lines) is marked as a cross in the
    global misfit configuration. The horizontal black lines indicate median and
    +- standard deviation of the y-axis values of these crosses. If the
    bootstrap configurations converge to the same region in model-space, all
    crosses should be close to the right end of the plot. If this is not the
    case, some bootstrap configurations have converged to very different places
    in model-space. This would be an indicator that there might be
    inconsistencies in the observations (maybe due to faulty or noisy or
    misoriented data). Also the shape of the curve in general can give
    information. A well-behaved optimisation run has approximately linear
    functions in this plot. Only at the end they should have a higher downward
    gradient. This would be the place where the objective functions of the
    bootstrap start to disagree.
    '''

    name = 'bootstrap'
    size_cm = Tuple.T(2, Float.T(), default=(21., 14.9))
    # NOTE(review): not referenced anywhere in this class.
    show_ticks = Bool.T(default=False)

    def make(self, environ):
        # Named so it does not shadow the module-level ``matplotlib.cm``.
        collection_manager = environ.get_plot_collection_manager()
        history = environ.get_history()
        optimiser = environ.get_optimiser()
        mpl_init(fontsize=self.font_size)
        collection_manager.create_group_mpl(
            self,
            self.draw_figures(history, optimiser),
            title=u'Bootstrap Misfit',
            section='optimiser',
            feather_icon='trending-down',
            description=u'''
Sorted misfit (descending) of single bootstrap chains.

For each bootstrap configuration, all models are sorted according to their
misfit value (red lines) and their global misfit value (black line). (They are
sorted individually for each line). The best model of every bootstrap
configuration (right end model of red lines) is marked as a cross in the global
misfit configuration. The horizontal black lines indicate median and +- standard
deviation of the y-axis values of these crosses.

If the bootstrap configurations converge to the same region in model-space, all
crosses should be close to the right end of the plot. If this is not the case,
some bootstrap configurations have converged to very different places in
model-space. This would indicate that there might be inconsistencies in the
observations (maybe due to faulty or noisy or misoriented data). Also the shape
of the curve in general can give information. A well-behaved optimisation run
has approximately linear functions in this plot. Only at the end they should
have a higher downward gradient. This would be the place where the objective
functions of the bootstrap start to disagree.
''')

    def draw_figures(self, history, optimiser):
        '''
        Draw per-chain sorted misfit curves and the global misfit curve.

        Chain 0 of ``history.bootstrap_misfits`` is used as the global misfit.
        ``optimiser`` is accepted for interface compatibility but not used.

        :returns: ``[]`` (after logging a warning) for an empty history;
            otherwise ``[(PlotItem, figure)]``.
        '''
        nmodels = history.nmodels
        if nmodels == 0:
            # Nothing to sort or plot; checked before creating a figure so
            # that no orphaned pyplot figure is left behind.
            logger.warning('Empty models vector.')
            return []

        def softclip(misfits):
            # Linear below 1, logarithmic above (0.1 per decade). num.where
            # evaluates both branches, so silence log10 warnings for <= 0.
            with num.errstate(divide='ignore', invalid='ignore'):
                return num.where(
                    misfits > 1.0, 0.1 * num.log10(misfits) + 1.0, misfits)

        # Honour the configured size_cm like the other plots in this module.
        fig = plt.figure(figsize=self.size_inch)
        axes = fig.add_subplot(1, 1, 1)

        imodels = num.arange(nmodels)
        gms = history.bootstrap_misfits[:, 0]
        gms_softclip = softclip(gms)

        # One red curve per chain, each sorted descending by its own misfit;
        # remember the best (lowest-misfit) model of every chain.
        ibests = []
        for ibootstrap in range(history.nchains):
            bms = history.bootstrap_misfits[:, ibootstrap]
            isort_bms = num.argsort(bms)[::-1]
            ibests.append(isort_bms[-1])

            axes.plot(
                imodels, softclip(bms)[isort_bms], color='red', alpha=0.2)

        # iorder[i] = position of model i in the descending global ordering.
        isort = num.argsort(gms)[::-1]
        iorder = num.empty(isort.size)
        iorder[isort] = imodels

        axes.plot(iorder[ibests], gms_softclip[ibests], 'x', color='black')

        m = num.median(gms_softclip[ibests])
        s = num.std(gms_softclip[ibests])
        axes.axhline(m + s, color='black', alpha=0.5)
        axes.axhline(m, color='black')
        axes.axhline(m - s, color='black', alpha=0.5)

        axes.plot(imodels, gms_softclip[isort], color='black')

        axes.set_xlim(imodels[0], imodels[-1])
        axes.set_xlabel(
            'Tested model, sorted descending by global misfit value')

        return [(PlotItem(name='main'), fig)]

569 

570 

def get_plot_classes():
    '''
    List the plot configuration classes provided by this module.
    '''
    plot_classes = [SequencePlot, ContributionsPlot, BootstrapPlot]
    return plot_classes