Coverage for /usr/local/lib/python3.11/dist-packages/grond/optimisers/highscore/plot.py: 18%

221 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-11-27 15:15 +0000

1# https://pyrocko.org/grond - GPLv3 

2# 

3# The Grond Developers, 21st Century 

4import logging 

5import numpy as num 

6 

7from matplotlib import pyplot as plt 

8from matplotlib.ticker import FuncFormatter 

9 

10from pyrocko.plot import mpl_init, mpl_margins, mpl_color 

11from pyrocko.guts import Tuple, Float 

12from pyrocko import trace 

13 

14from grond.plot.config import PlotConfig 

15from grond.plot.collection import PlotItem 

16 

# Type-name prefix used by pyrocko.guts for the config classes defined in
# this module (e.g. HighScoreAcceptancePlot).
guts_prefix = 'grond'

# Module logger; frame rendering progress and plotting problems go here.
logger = logging.getLogger('grond.optimiser.highscore.plot')

20 

21 

22def _pcolormesh_same_dim(ax, x, y, v, **kwargs): 

23 # x, y, v must have the same dimension 

24 try: 

25 return ax.pcolormesh(x, y, v, shading='nearest', **kwargs) 

26 except TypeError: 

27 # matplotlib versions < 3.3 

28 return ax.pcolormesh(x, y, v[:-1, :-1], **kwargs) 

29 

30 

class HighScoreOptimiserPlot(object):
    '''
    Replay an optimisation history as an animated 2D scatter plot.

    Every frame shows, projected onto two selected parameters:

    * all models tested so far (small black dots),
    * the current members of each chain (one colour per chain),
    * a cloud of candidate models drawn from the active sampler phase
      (green dots),
    * a label with the iteration number and sampler phase name.

    Frames are optionally written to a movie file with matplotlib's
    ``FFMpegWriter``.

    :param optimiser: optimiser instance; used for ``chains()``,
        ``nbootstrap`` and ``get_sampler_phase()``.
    :param problem: problem instance providing parameters, bounds and the
        reference model.
    :param history: model history to replay.
    :param xpar_name: name of the parameter shown on the x-axis.
    :param ypar_name: name of the parameter shown on the y-axis.
    :param movie_filename: output movie path; if false, no movie is written.
    '''

    # Number of tested models drawn per scatter collection. Completed blocks
    # stay on the axes permanently; only the trailing, incomplete block is
    # redrawn on every frame.
    _block_size = 100

    # Number of candidate models drawn from the sampler phase per frame.
    _nsamples_phase = 1000

    def __init__(
            self, optimiser, problem, history, xpar_name, ypar_name,
            movie_filename):

        self.optimiser = optimiser
        self.problem = problem
        self.chains = optimiser.chains(problem, history)
        self.history = history
        self.xpar_name = xpar_name
        self.ypar_name = ypar_name
        self.fontsize = 10.
        self.movie_filename = movie_filename
        # When True, frames are also displayed interactively.
        self.show = False
        # Index of the iteration drawn by the next call to draw_frame().
        self.iiter = 0
        self.iiter_last_draw = 0
        # Artists that are removed and redrawn on every frame.
        self._volatile = []
        # Indices of completely filled model blocks already on the axes.
        self._blocks_complete = set()

    def start(self):
        '''
        Create figure and axes, fix the axis limits and, if a movie filename
        was given, set up the movie writer.

        Must be called before :py:meth:`draw_frame`.
        '''
        nfx = 1
        nfy = 1

        problem = self.problem

        ixpar = problem.name_to_index(self.xpar_name)
        iypar = problem.name_to_index(self.ypar_name)

        mpl_init(fontsize=self.fontsize)
        fig = plt.figure(figsize=(9.6, 5.4))
        labelpos = mpl_margins(fig, nw=nfx, nh=nfy, w=7., h=5., wspace=7.,
                               hspace=2., units=self.fontsize)

        xpar = problem.parameters[ixpar]
        ypar = problem.parameters[iypar]

        # Equal aspect ratio only when both axes share the same unit.
        if xpar.unit == ypar.unit:
            axes = fig.add_subplot(nfy, nfx, 1, aspect=1.0)
        else:
            axes = fig.add_subplot(nfy, nfx, 1)

        labelpos(axes, 2.5, 2.0)

        axes.set_xlabel(xpar.get_label())
        axes.set_ylabel(ypar.get_label())

        axes.get_xaxis().set_major_locator(plt.MaxNLocator(4))
        axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

        # Cross hair at the reference model.
        xref = problem.get_reference_model()
        axes.axvline(xpar.scaled(xref[ixpar]), color='black', alpha=0.3)
        axes.axhline(ypar.scaled(xref[iypar]), color='black', alpha=0.3)

        self.fig = fig
        self.xpar = xpar
        self.ypar = ypar
        self.axes = axes
        self.ixpar = ixpar
        self.iypar = iypar

        from matplotlib import colors

        # Random colours, presumably one per chain (nbootstrap + 1 of them);
        # the first one is forced to black.
        n = self.optimiser.nbootstrap + 1
        hsv = num.vstack((
            num.random.uniform(0., 1., n),
            num.random.uniform(0.5, 0.9, n),
            num.repeat(0.7, n))).T

        self.bcolors = colors.hsv_to_rgb(hsv[num.newaxis, :, :])[0, :, :]
        self.bcolors[0, :] = [0., 0., 0.]

        bounds = self.problem.get_combined_bounds()

        from grond import plot
        self.xlim = plot.fixlim(*xpar.scaled(bounds[ixpar]))
        self.ylim = plot.fixlim(*ypar.scaled(bounds[iypar]))

        self.set_limits()

        from matplotlib.colors import LinearSegmentedColormap

        self.cmap = LinearSegmentedColormap.from_list('probability', [
            (1.0, 1.0, 1.0),
            (0.5, 0.9, 0.6)])

        self.writer = None
        if self.movie_filename:
            from matplotlib.animation import FFMpegWriter

            metadata = dict(title=problem.name, artist='Grond')

            self.writer = FFMpegWriter(
                fps=30,
                metadata=metadata,
                codec='libx264',
                bitrate=200000,
                extra_args=[
                    '-pix_fmt', 'yuv420p',
                    '-profile:v', 'baseline',
                    '-level', '3',
                    '-an'])

            self.writer.setup(self.fig, self.movie_filename, dpi=200)

        if self.show:
            plt.ion()
            plt.show()

    def set_limits(self):
        '''Fix the axis limits to the combined parameter bounds.'''
        self.axes.autoscale(False)
        self.axes.set_xlim(*self.xlim)
        self.axes.set_ylim(*self.ylim)

    def draw_frame(self):
        '''
        Draw the state after iteration ``self.iiter`` and grab a movie frame.

        Artists that change between frames are collected in
        ``self._volatile`` and removed at the start of the next call.
        '''
        self.chains.goto(self.iiter+1)
        msize = 15.

        for artist in self._volatile:
            artist.remove()

        self._volatile[:] = []

        nmodels = self.iiter + 1
        bsize = self._block_size
        nblocks = self.iiter // bsize + 1

        models = self.history.models[:nmodels]

        # All tested models so far, drawn in blocks of ``bsize`` models. A
        # full block never changes again and is kept on the axes.
        for iblock in range(nblocks):
            if iblock in self._blocks_complete:
                continue

            models_add = models[iblock*bsize:(iblock+1)*bsize]

            fx = self.problem.extract(models_add, self.ixpar)
            fy = self.problem.extract(models_add, self.iypar)
            collection = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color='black',
                s=msize * 0.15, alpha=0.2, edgecolors='none')

            if models_add.shape[0] != bsize:
                self._volatile.append(collection)
            else:
                self._blocks_complete.add(iblock)

        # Current chain members. Marker sizes of models added within the
        # last ``nfade`` iterations are rescaled so recent additions stand
        # out.
        for ichain in range(self.chains.nchains):

            iiters = self.chains.indices(ichain)
            fx = self.problem.extract(models[iiters, :], self.ixpar)
            fy = self.problem.extract(models[iiters, :], self.iypar)

            nfade = 20
            t1 = num.maximum(0.0, iiters - (models.shape[0] - nfade)) / nfade
            factors = num.sqrt(1.0 - t1) * (1.0 + 15. * t1**2)

            msizes = msize * factors

            paths = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color=self.bcolors[ichain],
                s=msizes, alpha=0.5, edgecolors='none')

            self._volatile.append(paths)

        # Candidate models drawn from the currently active sampler phase.
        _, phase, iiter_phase = self.optimiser.get_sampler_phase(self.iiter)

        nsamples = self._nsamples_phase
        models_prob = num.zeros((nsamples, self.problem.nparameters))
        for isample in range(nsamples):
            models_prob[isample, :] = phase.get_sample(
                self.problem, iiter_phase, self.chains)

        fx = self.problem.extract(models_prob, self.ixpar)
        fy = self.problem.extract(models_prob, self.iypar)

        collection = self.axes.scatter(
            self.xpar.scaled(fx),
            self.ypar.scaled(fy),
            color='green',
            s=msize * 0.15, alpha=0.2, edgecolors='none')

        self._volatile.append(collection)

        artist = self.axes.annotate(
            '%i (%s)' % (self.iiter+1, phase.__class__.__name__),
            xy=(0., 1.),
            xycoords='axes fraction',
            xytext=(self.fontsize/2., -self.fontsize/2.),
            textcoords='offset points',
            ha='left',
            va='top',
            fontsize=self.fontsize,
            fontstyle='normal')

        self._volatile.append(artist)

        # Grab only after the label has been added: it is volatile and gets
        # removed at the start of the next frame, so grabbing earlier would
        # leave it out of the movie entirely.
        if self.writer:
            self.writer.grab_frame()

        if self.show:
            plt.draw()

        self.iiter_last_draw = self.iiter + 1

    def finish(self):
        '''Finalize the movie file and, if showing, display the last frame.'''
        if self.writer:
            self.writer.finish()

        if self.show:
            # Leave interactive mode first so that show() blocks and the
            # final figure stays visible.
            plt.ioff()
            plt.show()

    def render(self):
        '''
        Render all iterations of the history, one frame per model.

        :py:meth:`finish` is always called, also when drawing fails or is
        interrupted, so the movie writer is properly closed.
        '''
        self.start()

        try:
            while self.iiter < self.history.nmodels:
                logger.info('Rendering frame %i/%i.'
                            % (self.iiter+1, self.history.nmodels))
                self.draw_frame()
                self.iiter += 1
        finally:
            self.finish()

278 

279 

def rolling_window(a, window):
    '''
    Return a strided view of all length-``window`` windows along the last
    axis of ``a``.

    For input of shape ``(..., n)`` the result has shape
    ``(..., max(n - window + 1, 0), window)`` with
    ``result[..., i, j] == a[..., i + j]``. If ``window`` is longer than the
    data, the result holds zero windows (like ``mode='valid'`` moving sums).

    No data is copied: the result shares memory with ``a`` (for ndarray
    input), so writing to it modifies ``a`` and overlapping windows alias
    each other.

    :param a: array-like, at least 1-D; converted with ``numpy.asarray``
        (no copy for ndarrays).
    :param window: window length, must be ``>= 1``.
    :raises ValueError: if ``a`` is 0-dimensional or ``window < 1``.
    '''
    a = num.asarray(a)
    if a.ndim == 0:
        raise ValueError('rolling_window: input must be at least 1-D')

    if window < 1:
        raise ValueError(
            'rolling_window: window must be >= 1, got %r' % (window,))

    n = a.shape[-1]
    # Clamp at zero windows instead of producing a negative dimension.
    nwindows = max(n - window + 1, 0)

    shape = a.shape[:-1] + (nwindows, window)
    # Stepping to the next window or to the next element inside a window
    # both advance by one element of the last axis.
    strides = a.strides + (a.strides[-1],)
    return num.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

284 

285 

class HighScoreAcceptancePlot(PlotConfig):
    '''
    Model acceptance plot.

    Produces two figures from the chains' acceptance history: acceptance
    rate and popularity of accepted models over the iterations, and an image
    of the acceptance per bootstrap chain.
    '''
    name = 'acceptance'
    size_cm = Tuple.T(2, Float.T(), default=(21., 14.9))

    def make(self, environ):
        '''Register the acceptance figures with the plot collection.'''
        cm = environ.get_plot_collection_manager()
        cm.create_group_mpl(
            self,
            self.draw_figures(environ),
            title=u'Acceptance',
            section='optimiser',
            description=u'''
Model acceptance and accepted model popularities.

The plots in this section can be used to investigate performance and
characteristics of the optimisation algorithm.
''',
            feather_icon='check')

    def draw_figures(self, environ):
        '''
        Generate the acceptance figures.

        :returns: generator of ``(PlotItem, figure)`` tuples. Nothing is
            generated (a warning is logged instead) if the history holds
            fewer models than the moving-average window.
        '''
        nwindow = 200
        optimiser = environ.get_optimiser()
        problem = environ.get_problem()
        history = environ.get_history()
        chains = optimiser.chains(problem, history)
        chains.load()

        acceptance = chains.acceptance_history

        if history.nmodels - (nwindow - 1) < 1:
            logger.warning(
                'Cannot create plot acceptance: insufficient number of tested '
                'models.')

            return

        # Each figure is only built when the consumer asks for it.
        yield self._draw_acceptance(optimiser, history, acceptance, nwindow)
        yield self._draw_acceptance_img(history, acceptance)

    def _draw_acceptance(
            self, optimiser, history, acceptance, nwindow,
            show_raw_acceptance_rates=False):
        '''
        Draw acceptance rate and popularity of accepted models over the
        iterations.

        :param optimiser: optimiser; its ``sampler_phases`` define the
            background shading.
        :param history: model history (``nmodels``, ``nchains``).
        :param acceptance: acceptance history indexed ``[ichain, imodel]``,
            presumably 0/1 flags -- TODO confirm.
        :param nwindow: length of the moving-average window.
        :param show_raw_acceptance_rates: also draw the moving acceptance
            rate of every single chain.
        :returns: ``(PlotItem, figure)`` tuple.
        '''
        nchains = history.nchains

        # Number of chains accepting each model.
        acceptance_n = num.sum(acceptance, axis=0)

        acceptance_any = num.minimum(acceptance_n, 1)

        acceptance_any_rate = trace.moving_sum(
            acceptance_any, nwindow, mode='valid') / float(nwindow)

        acceptance_p = acceptance_n / float(nchains)

        # Mean popularity of the accepted models within each window. Windows
        # without any accepted model give 0/0 = NaN (left out of the line by
        # matplotlib); suppress the RuntimeWarning numpy would emit.
        with num.errstate(divide='ignore', invalid='ignore'):
            popularity = trace.moving_sum(
                acceptance_p, nwindow, mode='valid') \
                / float(nwindow) / acceptance_any_rate

        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        imodels = num.arange(history.nmodels)

        # Moving-window values are attributed to the last model of each
        # window.
        imodels_rate = imodels[nwindow-1:]

        axes.plot(
            acceptance_n/nchains * 100.,
            '.',
            ms=2.0,
            color=mpl_color('skyblue2'),
            label='Popularity of Accepted Models',
            alpha=0.3)

        if show_raw_acceptance_rates:
            for ichain in range(nchains):
                acceptance_rate = trace.moving_sum(
                    acceptance[ichain, :], nwindow, mode='valid') \
                    / float(nwindow)
                axes.plot(imodels_rate, acceptance_rate*100.,
                          color=mpl_color('scarletred2'), alpha=0.2)

        axes.plot(
            imodels_rate,
            popularity * 100.,
            color=mpl_color('skyblue2'),
            label='Popularity (moving average)')
        axes.plot(
            imodels_rate,
            acceptance_any_rate*100.,
            color='black',
            label='Acceptance Rate (any chain)')

        axes.legend()

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Acceptance Rate, Model Popularity')

        axes.set_ylim(0., 100.)
        axes.set_xlim(0., history.nmodels - 1)
        axes.grid(alpha=.2)
        axes.yaxis.set_major_formatter(FuncFormatter(lambda v, p: '%d%%' % v))

        # Alternating background colours mark the sampler phases.
        iiter = 0
        bgcolors = [mpl_color('aluminium1'), mpl_color('aluminium2')]
        for iphase, phase in enumerate(optimiser.sampler_phases):
            axes.axvspan(
                iiter, iiter+phase.niterations,
                color=bgcolors[iphase % len(bgcolors)])

            iiter += phase.niterations

        return (
            PlotItem(
                name='acceptance',
                description=u'''
Acceptance rate (black line) within a moving window of %d iterations.

A model is considered accepted, if it is accepted in at least one chain. The
popularity of accepted models is shown as blue dots. Popularity is defined as
the percentage of chains accepting the model (100%% meaning acceptance in all
chains). A moving average of the popularities is shown as blue line (same
averaging interval as for the acceptance rate). Different background colors
represent different sampler phases.
''' % nwindow),
            fig)

    def _draw_acceptance_img(self, history, acceptance):
        '''
        Draw the moving acceptance rate of every chain as an image.

        :param history: model history (``nmodels``, ``nchains``,
            ``sampler_contexts``).
        :param acceptance: acceptance history indexed ``[ichain, imodel]``.
        :returns: ``(PlotItem, figure)`` tuple.
        '''
        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        # Window length: number of models / (figure size in inches * 100),
        # presumably about one window per pixel at 100 dpi.
        # NOTE(review): this uses the figure height (size_inch[1]) although
        # the iterations run along the x-axis; the width may have been
        # intended -- confirm.
        nwindow2 = max(1, int(history.nmodels / (self.size_inch[1] * 100)))
        nmodels_rate2 = history.nmodels - (nwindow2 - 1)
        acceptance_rate2 = num.zeros((history.nchains, nmodels_rate2))
        for ichain in range(history.nchains):
            acceptance_rate2[ichain, :] = trace.moving_sum(
                acceptance[ichain, :], nwindow2, mode='valid') \
                / float(nwindow2)

        imodels = num.arange(history.nmodels)
        imodels_rate2 = imodels[nwindow2-1:]

        # Logarithmic colour scale; the 0.01 offset keeps zero rates finite.
        _pcolormesh_same_dim(
            axes,
            imodels_rate2,
            num.arange(history.nchains),
            num.log(0.01+acceptance_rate2),
            cmap='GnBu')

        if history.sampler_contexts is not None:
            axes.plot(
                imodels,
                history.sampler_contexts[:, 1],
                '.',
                ms=2.0,
                color='black',
                label='Breeding Chain',
                alpha=0.3)

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Bootstrap Chain')
        axes.set_xlim(0, history.nmodels - 1)
        axes.set_ylim(0, history.nchains - 1)

        axes.xaxis.grid(alpha=.4)

        return (
            PlotItem(
                name='acceptance_img',
                description=u'''
Model acceptance per bootstrap chain averaged over %d models (background color,
low to high acceptance as light to dark colors).

Black dots mark the base chains used when sampling new models (directed sampler
phases only).
''' % nwindow2),
            fig)

464 

465 

# Names exported by ``from ... import *``: the movie renderer and the
# acceptance plot configuration.
__all__ = [
    'HighScoreOptimiserPlot',
    'HighScoreAcceptancePlot',
]