Coverage for /usr/local/lib/python3.11/dist-packages/grond/plot/collection.py: 66%

127 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2025-04-03 09:31 +0000

1# https://pyrocko.org/grond - GPLv3 

2# 

3# The Grond Developers, 21st Century 

4import os 

5import os.path as op 

6import logging 

7 

8from pyrocko import guts, util 

9from pyrocko.guts import Dict, List, Tuple, Float, Unicode, Object, String 

10 

11from grond.meta import StringID 

12from grond.plot.config import PlotFormat 

13 

14 

# Module-level prefix read by pyrocko.guts for the YAML tags of the guts
# classes defined below (presumably yields tags such as ``!grond.PlotGroup``
# — confirm against pyrocko.guts).
guts_prefix = 'grond'

# Module logger; used to report every image file written.
logger = logging.getLogger('grond.plot.collection')

18 

19 

class PlotItem(Object):
    '''
    A single plot (one figure) belonging to a plot group.

    The item's ``name`` becomes part of its image file names, which are built
    by ``PlotGroup.filename_image``.
    '''

    # Identifier of the item within its group.
    name = StringID.T()
    # Free-form attributes: keys map to lists of strings (semantics are
    # defined by the code that creates the items — TODO confirm).
    attributes = Dict.T(
        StringID.T(), List.T(String.T()))
    # Fix: this help text previously duplicated the one of ``description``.
    title = Unicode.T(
        optional=True,
        help='item\'s title')
    description = Unicode.T(
        optional=True,
        help='item\'s description')

30 

31 

class PlotGroup(Object):
    '''
    A set of :py:class:`PlotItem` plots rendered with common output formats
    and a common figure size.

    A group is identified by its ``(name, variant)`` pair.
    '''

    name = StringID.T(help='group name')
    section = StringID.T(
        optional=True,
        help='group\'s section path, e.g. results.waveforms')
    title = Unicode.T(optional=True, help='group\'s title')
    description = Unicode.T(optional=True, help='group description')
    formats = List.T(PlotFormat.T(), help='plot format')
    variant = StringID.T(help='variant of the group')
    feather_icon = String.T(
        default='bar-chart-2',
        help='Feather icon for the HTML report.')
    # Figure size as a (width, height) pair in centimetres.
    size_cm = Tuple.T(2, Float.T())
    items = List.T(PlotItem.T())
    attributes = Dict.T(StringID.T(), List.T(String.T()))

    def filename_image(self, item, format):
        '''
        Build the image file name for one item in one format.

        :returns: ``'<group name>.<variant>.<item name>.<extension>'``
        '''
        parts = (self.name, self.variant, item.name, format.extension)
        return '.'.join(str(part) for part in parts)

62 

63 

class PlotCollection(Object):
    '''
    Index of plot groups, persisted by :py:class:`PlotCollectionManager` as
    ``plot_collection.yaml``.
    '''

    # One entry per registered group: a (group name, group variant) pair.
    group_refs = List.T(
        Tuple.T(2, StringID.T()))

66 

67 

class PlotCollectionManager(object):
    '''
    Write plot groups below a directory and maintain its collection index.

    Layout below ``path``:

    * ``plot_collection.yaml`` -- the :py:class:`PlotCollection` index,
    * ``<name>/<variant>/<name>.<variant>.plot_group.yaml`` -- one file per
      :py:class:`PlotGroup`,
    * ``<name>/<variant>/<image file>`` -- the rendered images.
    '''

    def __init__(self, path, show=False):
        '''
        :param path: root directory of the plot collection
        :param show: if ``True``, matplotlib figures are kept open and
            ``plt.show()`` is called after a matplotlib group was written
        '''
        self._path = path
        self._show = show
        self.load_collection()

    def load_collection(self):
        '''Load the collection index, or start an empty one if absent.'''
        path = self.path_collection()
        if op.exists(path):
            self._collection = guts.load(filename=path)
        else:
            self._collection = PlotCollection()

    def dump_collection(self):
        '''Write the collection index, creating parent directories.'''
        path = self.path_collection()
        util.ensuredirs(path)
        guts.dump(self._collection, filename=path)

    def path_collection(self):
        '''Return the path of the collection index file.'''
        return op.join(self._path, 'plot_collection.yaml')

    def path_image(self, group, item, format):
        '''Return the path of the image of *item* in *format*.'''
        return op.join(
            self._path, group.name, group.variant,
            group.filename_image(item, format))

    @staticmethod
    def _name_variant(group_ref=None, group=None):
        '''Return ``(name, variant)`` from *group_ref* or else *group*.'''
        if group_ref is not None:
            name, variant = group_ref
        else:
            name, variant = group.name, group.variant

        return name, variant

    def path_group(self, group_ref=None, group=None):
        '''
        Return the path of a group file.

        :param group_ref: ``(name, variant)`` pair; takes precedence
        :param group: :py:class:`PlotGroup`, used if *group_ref* is ``None``
        '''
        name, variant = self._name_variant(group_ref, group)
        return op.join(
            self._path, name, variant, '%s.%s.%s' % (
                name, variant, 'plot_group.yaml'))

    def paths_group_dirs(self, group_ref=None, group=None):
        '''
        Return a group's directories, innermost first.

        Arguments as for :py:meth:`path_group`.
        '''
        name, variant = self._name_variant(group_ref, group)
        return [
            op.join(self._path, name, variant),
            op.join(self._path, name)]

    def _forget_group_ref(self, group_ref):
        '''Remove every entry equal to *group_ref* from the index.'''
        refs = self._collection.group_refs
        # Compare as tuples: entries read back from YAML might be lists
        # (NOTE(review): depends on guts regularisation on load — confirm).
        # Filtering also drops duplicate entries, not just the first one.
        refs[:] = [ref for ref in refs if tuple(ref) != group_ref]

    def _new_group(self, config, kwargs):
        '''
        Create a group from *config* and remove any previous version of it.

        :returns: ``(group, path_group, group_ref)``
        '''
        group = PlotGroup(
            formats=guts.clone(config.formats),
            size_cm=config.size_cm,
            name=config.name,
            variant=config.variant,
            **kwargs)

        path_group = self.path_group(group=group)
        if op.exists(path_group):
            self.remove_group_files(path_group)

        group_ref = (group.name, group.variant)
        self._forget_group_ref(group_ref)

        # Written now, so if rendering fails below the index on disk no
        # longer lists this group.
        self.dump_collection()
        return group, path_group, group_ref

    def _register_group(self, group, path_group, group_ref):
        '''Validate and write *group*, then add it to the index.'''
        util.ensuredirs(path_group)
        group.validate()
        group.dump(filename=path_group)
        self._collection.group_refs.append(group_ref)
        self.dump_collection()

    def create_group_mpl(self, config, iter_item_figure, **kwargs):
        '''
        Render a group of matplotlib figures in all configured formats.

        :param config: object providing ``formats``, ``size_cm``, ``name``
            and ``variant``
        :param iter_item_figure: iterable of ``(PlotItem, figure)`` pairs
        :param kwargs: further :py:class:`PlotGroup` attributes
        '''
        from matplotlib import pyplot as plt

        group, path_group, group_ref = self._new_group(config, kwargs)

        figs_to_close = []
        for item, fig in iter_item_figure:
            group.items.append(item)
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                util.ensuredirs(path)
                fmt.render_mpl(
                    fig,
                    path=path,
                    dpi=fmt.get_dpi(group.size_cm))

                # Log inside the loop: one message per written file (this
                # also avoids an unbound ``path`` when formats is empty).
                logger.info('Figure saved: %s', path)

            if not self._show:
                plt.close(fig)
            else:
                figs_to_close.append(fig)

        self._register_group(group, path_group, group_ref)

        if self._show:
            plt.show()

        for fig in figs_to_close:
            plt.close(fig)

    def create_group_automap(self, config, iter_item_figure, **kwargs):
        '''
        Render a group of automap maps in all configured formats.

        Arguments as for :py:meth:`create_group_mpl`, except that the
        iterable yields ``(PlotItem, automap)`` pairs.
        '''
        group, path_group, group_ref = self._new_group(config, kwargs)

        for item, automap in iter_item_figure:
            group.items.append(item)
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                util.ensuredirs(path)
                fmt.render_automap(
                    automap,
                    path=path,
                    resolution=fmt.get_dpi(group.size_cm))

                logger.info('Figure saved: %s', path)

        self._register_group(group, path_group, group_ref)

    def create_group_gmtpy(self, config, iter_item_figure):
        '''Not implemented; does nothing.'''
        pass

    def remove_group_files(self, path_group):
        '''
        Delete the images listed in a group file, the file itself and the
        group's directories if they are empty.

        Failures to delete individual images or directories are ignored.
        '''
        group = guts.load(filename=path_group)
        for item in group.items:
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                try:
                    os.unlink(path)
                except OSError:
                    pass

        os.unlink(path_group)
        # Innermost directory first; rmdir fails (and is skipped) on
        # directories that still contain other files.
        for path in self.paths_group_dirs(group=group):
            try:
                os.rmdir(path)
            except OSError:
                pass

220 

221 

# Public API of this module.
__all__ = [
    'PlotItem', 'PlotGroup', 'PlotCollection', 'PlotCollectionManager',
]