1from __future__ import print_function 

2import numpy as num 

3from matplotlib.axes import Axes 

4from matplotlib.ticker import MultipleLocator 

5 

6from pyrocko.guts import Tuple, Float 

7from pyrocko import plot 

8 

9from .config import PlotConfig 

10 

# Module-level prefix read by pyrocko.guts for the guts Object subclasses
# defined here (e.g. naming them ``grond.SectionPlotConfig``) -- presumably
# for YAML (de)serialisation tags; confirm against pyrocko.guts.
guts_prefix = 'grond'

12 

13 

def limits(points):
    """Return the per-column bounds of a point array.

    :param points: numpy array indexed as ``points[i, k]`` with three
        columns (k = 0, 1, 2); presumably x, y, z -- confirm with callers.
    :returns: ``(3, 2)`` float array whose row ``k`` holds
        ``[min, max]`` of column ``k``. All zeros when *points* is empty.
    """
    bounds = num.zeros((3, 2))
    if points.size == 0:
        # Nothing to measure: keep the all-zero bounds.
        return bounds

    bounds[:, 0] = num.min(points, axis=0)
    bounds[:, 1] = num.max(points, axis=0)
    return bounds

21 

22 

class NotEnoughSpace(Exception):
    """Raised when margins and separators leave no room for the panels."""

25 

26 

class SectionPlotConfig(PlotConfig):
    """Layout settings for the three-panel section plot.

    Values suffixed ``_em`` are multiples of the configured font size.
    """

    # Figure size; presumably (width, height) in centimetres, turned into
    # ``size_inch`` by PlotConfig -- confirm.
    size_cm = Tuple.T(2, Float.T(), default=(20., 20.))

    # Outer figure margins in em, consumed in the order
    # (left, top, right, bottom).
    margins_em = Tuple.T(4, Float.T(), default=(7., 5., 7., 5.))

    # Gap between neighbouring panels, in em.
    separator_em = Float.T(default=1.0)

36 

37 

class SectionPlot(object):
    """Three-panel orthographic section plot.

    The figure holds three matplotlib axes:

    * ``axes_xy`` -- x horizontal, y vertical (top-left panel),
    * ``axes_xz`` -- x horizontal, z vertical with z increasing downward
      (bottom-left panel, below ``axes_xy``),
    * ``axes_zy`` -- z horizontal, y vertical (top-right panel, right of
      ``axes_xy``).

    Whenever data limits, view limits, canvas size or dpi change, the
    panels are re-laid out so that all three share one common
    data-to-pixel scale in both directions.
    """

    def __init__(self, config=None):
        """Create the figure and its three panels.

        :param config: layout settings; a default ``SectionPlotConfig`` is
            used when ``None``.
        :raises NotEnoughSpace: if the figure is too small for the
            configured margins. The half-built figure is closed first.
        """
        if config is None:
            config = SectionPlotConfig()

        self.config = config
        self._disconnect_data = []
        self._cid_resize = None
        # Re-entrancy counter for lim_changed_handler. It is initialised
        # before any callback is connected so the handler never sees it unset.
        self._lim_changed_depth = 0
        self._width = self._height = self._pixels = None
        self._plt = plot.mpl_init(self.config.font_size)
        self._fig = fig = self._plt.figure(figsize=self.config.size_inch)

        rect = [0., 0., 1., 1.]
        self._axes_xy = Axes(fig, rect)
        self._axes_xz = Axes(fig, rect)
        self._axes_zy = Axes(fig, rect)

        # NaN entries mean "no explicit view limit; autoscale from data".
        self._view_limits = num.full((3, 2), num.nan)

        try:
            self._update_geometry()
        except Exception:
            # Do not leave an orphaned figure registered with pyplot.
            self._plt.close(fig)
            raise

        for axes in self.axes_list:
            fig.add_axes(axes)
            self._connect(axes, 'xlim_changed', self.lim_changed_handler)
            self._connect(axes, 'ylim_changed', self.lim_changed_handler)

        self._cid_resize = fig.canvas.mpl_connect(
            'resize_event', self.resize_handler)

        self._connect(fig, 'dpi_changed', self.dpi_changed_handler)

    def _connect(self, obj, sig, handler):
        """Connect *handler* to signal *sig* of ``obj.callbacks``.

        The connection is remembered so that _disconnect_all can undo it.
        """
        cid = obj.callbacks.connect(sig, handler)
        self._disconnect_data.append((obj, cid))

    def _disconnect_all(self):
        """Undo all connections made by _connect, plus the resize hook.

        The bookkeeping is cleared afterwards, so repeated calls do not try
        to disconnect the same ids again.
        """
        for obj, cid in self._disconnect_data:
            obj.callbacks.disconnect(cid)

        del self._disconnect_data[:]

        if self._cid_resize is not None:
            self._fig.canvas.mpl_disconnect(self._cid_resize)
            self._cid_resize = None

    def dpi_changed_handler(self, fig):
        """Figure ``dpi_changed`` callback: re-layout if geometry changed."""
        self._update_geometry()

    def resize_handler(self, event):
        """Canvas ``resize_event`` callback: re-layout if geometry changed."""
        self._update_geometry()

    def lim_changed_handler(self, axes):
        """Axes ``xlim_changed``/``ylim_changed`` callback.

        Only the outermost notification triggers a re-layout. The
        set_xlim/set_ylim calls made by _update_layout re-fire this handler
        and would otherwise recurse. The counter is restored even if the
        re-layout raises (e.g. NotEnoughSpace); otherwise every later
        notification would be ignored.
        """
        self._lim_changed_depth += 1
        try:
            if self._lim_changed_depth < 2:
                self._update_layout()
        finally:
            self._lim_changed_depth -= 1

    def _update_geometry(self):
        """Cache canvas size and points-per-pixel factor.

        The layout is recomputed only when one of them differs from the
        cached values.
        """
        w, h = self._fig.canvas.get_width_height()
        p = self.get_pixels_factor()

        if (self._width, self._height, self._pixels) != (w, h, p):
            self._width = w
            self._height = h
            self._pixels = p
            self._update_layout()

    @property
    def margins(self):
        """Figure margins ``(left, top, right, bottom)`` in canvas pixels."""
        return tuple(
            x * self.config.font_size / self._pixels
            for x in self.config.margins_em)

    @property
    def separator(self):
        """Gap between neighbouring panels in canvas pixels."""
        return self.config.separator_em * self.config.font_size / self._pixels

    def rect_to_figure_coords(self, rect):
        """Convert a pixel rectangle to figure-fraction coordinates.

        :param rect: ``(left, bottom, width, height)`` in canvas pixels.
        :returns: the same rectangle as fractions of the canvas size.
        """
        left, bottom, width, height = rect
        return (
            left / self._width,
            bottom / self._height,
            width / self._width,
            height / self._height)

    def point_to_axes_coords(self, axes, point):
        """Convert a canvas-pixel point ``(x, y)`` to *axes* coordinates.

        Axes coordinates are 0..1 across the axes' current position box.
        """
        x, y = point
        aleft, abottom, awidth, aheight = axes.get_position().bounds

        x_fig = x / self._width
        y_fig = y / self._height

        x_axes = (x_fig - aleft) / awidth
        y_axes = (y_fig - abottom) / aheight

        return (x_axes, y_axes)

    def get_pixels_factor(self):
        """Return points per pixel (``1 / pixels-per-point``).

        The factor comes from the canvas renderer. If the canvas has no
        ``get_renderer``, 1.0 is returned.
        """
        try:
            r = self._fig.canvas.get_renderer()
            return 1.0 / r.points_to_pixels(1.0)
        except AttributeError:
            return 1.0

    def make_limits(self, lims):
        """Return ``(min, max)`` nicely rounded around *lims* (5% padding)."""
        a = plot.AutoScaler(space=0.05)
        return a.make_scale(lims)[:2]

    def get_data_limits(self):
        """Return ``(3, 2)`` data bounds ``[min, max]`` for x, y and z.

        Each coordinate's bounds are gathered from the two panels that show
        it. Non-finite results (e.g. no data yet) are replaced by 0.
        """
        xs = []
        ys = []
        zs = []
        # (panel, list receiving its horizontal range, list for vertical)
        for axes, hor, ver in (
                (self._axes_xy, xs, ys),
                (self._axes_xz, xs, zs),
                (self._axes_zy, zs, ys)):
            hor.extend(axes.get_xaxis().get_data_interval())
            ver.extend(axes.get_yaxis().get_data_interval())

        lims = num.zeros((3, 2))
        for i, values in enumerate((xs, ys, zs)):
            lims[i, :] = num.nanmin(values), num.nanmax(values)

        lims[num.logical_not(num.isfinite(lims))] = 0.0
        return lims

    def set_xlim(self, xmin, xmax):
        """Fix the x view range (NaN entries fall back to autoscaling)."""
        self._view_limits[0, :] = xmin, xmax
        self._update_layout()

    def set_ylim(self, ymin, ymax):
        """Fix the y view range (NaN entries fall back to autoscaling)."""
        self._view_limits[1, :] = ymin, ymax
        self._update_layout()

    def set_zlim(self, zmin, zmax):
        """Fix the z view range (NaN entries fall back to autoscaling)."""
        self._view_limits[2, :] = zmin, zmax
        self._update_layout()

    def _update_layout(self):
        """Recompute panel positions, limits, label positions and ticks.

        Limits are autoscaled from the plotted data unless overridden via
        set_xlim/set_ylim/set_zlim. The x and y ranges are then widened
        symmetrically so that the panels fill the available area with one
        common data-to-pixel scale.

        :raises NotEnoughSpace: if margins and separator leave no room for
            the panels.
        """
        data_limits = self.get_data_limits()

        lims = num.zeros((3, 2))
        for i in range(3):
            lims[i, :] = self.make_limits(data_limits[i, :])

        # Finite view limits set by the user override autoscaling.
        mask = num.isfinite(self._view_limits)
        lims[mask] = self._view_limits[mask]

        deltas = lims[:, 1] - lims[:, 0]

        # Panels are tiled as [xy | zy] above [xz]. The horizontal extent is
        # therefore x + z and the vertical extent is y + z (data units).
        data_w = deltas[0] + deltas[2]
        data_h = deltas[1] + deltas[2]

        ml, mt, mr, mb = self.margins
        ms = self.separator

        data_r = data_h / data_w
        # One em in canvas pixels, converted the same way as margins and
        # separator. w, h and the label coordinates below are all pixels.
        em = self.config.font_size / self._pixels
        w = self._width
        h = self._height
        fig_w_avail = w - mr - ml - ms
        fig_h_avail = h - mt - mb - ms

        if fig_w_avail <= 0.0 or fig_h_avail <= 0.0:
            raise NotEnoughSpace()

        fig_r = fig_h_avail / fig_w_avail

        # Stretch the data box along the too-short direction so its aspect
        # ratio matches the available pixel area (equal scaling).
        if data_r < fig_r:
            data_expanded_h = data_w * fig_r
            data_expanded_w = data_w
        else:
            data_expanded_h = data_h
            data_expanded_w = data_h / fig_r

        # The extra range goes symmetrically onto x and y; z is unchanged.
        lims[0, 0] -= 0.5 * (data_expanded_w - data_w)
        lims[0, 1] += 0.5 * (data_expanded_w - data_w)
        lims[1, 0] -= 0.5 * (data_expanded_h - data_h)
        lims[1, 1] += 0.5 * (data_expanded_h - data_h)

        deltas = lims[:, 1] - lims[:, 0]

        # Panel sizes in pixels.
        w1 = fig_w_avail * deltas[0] / data_expanded_w
        w2 = fig_w_avail * deltas[2] / data_expanded_w

        h1 = fig_h_avail * deltas[1] / data_expanded_h
        h2 = fig_h_avail * deltas[2] / data_expanded_h

        # Rectangles as [left, bottom, width, height] in pixels.
        rect_xy = [ml, mb+h2+ms, w1, h1]
        rect_xz = [ml, mb, w1, h2]
        rect_zy = [ml+w1+ms, mb+h2+ms, w2, h1]

        axes_xy, axes_xz, axes_zy = self.axes_list

        axes_xy.set_position(
            self.rect_to_figure_coords(rect_xy), which='both')
        axes_xz.set_position(
            self.rect_to_figure_coords(rect_xz), which='both')
        axes_zy.set_position(
            self.rect_to_figure_coords(rect_zy), which='both')

        def wcenter(rect):
            return rect[0] + rect[2]*0.5

        def hcenter(rect):
            return rect[1] + rect[3]*0.5

        # x labels sit 1 em below the top edge, y labels 2 em from the left.
        self.set_label_coords(
            axes_xy, 'x', [wcenter(rect_xy), h - 1.0*em])
        self.set_label_coords(
            axes_xy, 'y', [2.0*em, hcenter(rect_xy)])
        self.set_label_coords(
            axes_zy, 'x', [wcenter(rect_zy), h - 1.0*em])
        self.set_label_coords(
            axes_xz, 'y', [2.0*em, hcenter(rect_xz)])

        # One tick increment shared by all panels keeps grids comparable.
        scaler = plot.AutoScaler()
        inc = scaler.make_scale(
            [0, min(data_expanded_w, data_expanded_h)], override_mode='off')[2]

        axes_xy.set_xlim(*lims[0, :])
        axes_xy.set_ylim(*lims[1, :])
        axes_xy.get_xaxis().set_tick_params(
            bottom=False, top=True, labelbottom=False, labeltop=True)
        axes_xy.get_yaxis().set_tick_params(
            left=True, labelleft=True, right=False, labelright=False)

        axes_xz.set_xlim(*lims[0, :])
        # Reversed so that z increases downward.
        axes_xz.set_ylim(*lims[2, ::-1])
        axes_xz.get_xaxis().set_tick_params(
            bottom=True, top=False, labelbottom=False, labeltop=False)
        axes_xz.get_yaxis().set_tick_params(
            left=True, labelleft=True, right=True, labelright=False)

        axes_zy.set_xlim(*lims[2, :])
        axes_zy.set_ylim(*lims[1, :])
        axes_zy.get_xaxis().set_tick_params(
            bottom=True, top=True, labelbottom=False, labeltop=True)
        axes_zy.get_yaxis().set_tick_params(
            left=False, labelleft=False, right=True, labelright=False)

        for axes in self.axes_list:
            # A fresh locator per axis: locators must not be shared.
            axes.get_xaxis().set_major_locator(MultipleLocator(inc))
            axes.get_yaxis().set_major_locator(MultipleLocator(inc))

    def set_label_coords(self, axes, which, point):
        """Place the *which* ('x' or other -> 'y') axis label of *axes* at a
        canvas-pixel *point*."""
        axis = axes.get_xaxis() if which == 'x' else axes.get_yaxis()
        axis.set_label_coords(*self.point_to_axes_coords(axes, point))

    @property
    def fig(self):
        """The underlying matplotlib figure."""
        return self._fig

    @property
    def axes_xy(self):
        """Panel showing x (horizontal) against y (vertical)."""
        return self._axes_xy

    @property
    def axes_xz(self):
        """Panel showing x (horizontal) against z (vertical, downward)."""
        return self._axes_xz

    @property
    def axes_zy(self):
        """Panel showing z (horizontal) against y (vertical)."""
        return self._axes_zy

    @property
    def axes_list(self):
        """The three panels in the order ``[xy, xz, zy]``."""
        return [
            self._axes_xy, self._axes_xz, self._axes_zy]

    def plot(self, points, *args, **kwargs):
        """Plot *points* into all three panels.

        *points* is indexed as ``points[:, k]`` for k = 0, 1, 2 (x, y, z).
        Extra arguments are passed on to each ``Axes.plot`` call.
        """
        self._axes_xy.plot(points[:, 0], points[:, 1], *args, **kwargs)
        self._axes_xz.plot(points[:, 0], points[:, 2], *args, **kwargs)
        self._axes_zy.plot(points[:, 2], points[:, 1], *args, **kwargs)

    def close(self):
        """Disconnect all callbacks and close the figure."""
        self._disconnect_all()
        self._plt.close(self._fig)

    def show(self):
        """Show all open figures via the plotting module."""
        self._plt.show()

    def set_xlabel(self, s):
        """Set the x label (drawn on the xy panel)."""
        self._axes_xy.set_xlabel(s)

    def set_ylabel(self, s):
        """Set the y label (drawn on the xy panel)."""
        self._axes_xy.set_ylabel(s)

    def set_zlabel(self, s):
        """Set the z label on both the xz and the zy panels."""
        self._axes_xz.set_ylabel(s)
        self._axes_zy.set_xlabel(s)