Coverage for /usr/local/lib/python3.11/dist-packages/grond/plot/collection.py: 66%

127 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-06-12 12:25 +0000

1import os 

2import os.path as op 

3import logging 

4 

5from pyrocko import guts, util 

6from pyrocko.guts import Dict, List, Tuple, Float, Unicode, Object, String 

7 

8from grond.meta import StringID 

9from grond.plot.config import PlotFormat 

10 

11 

# Prefix for this module's guts Object types; presumably used by
# pyrocko.guts when tagging serialized YAML -- confirm against guts docs.
guts_prefix = "grond"

# Module logger, named explicitly rather than via __name__.
logger = logging.getLogger("grond.plot.collection")

15 

16 

class PlotItem(Object):
    '''
    A single plot within a :py:class:`PlotGroup`.

    The item's ``name`` becomes part of the image file name of every format
    the owning group is rendered in.
    '''

    name = StringID.T(
        help='item name')
    # Free-form string attributes; the meaning of keys/values is defined by
    # the code producing the items -- not established in this module.
    attributes = Dict.T(
        StringID.T(), List.T(String.T()))
    # Previously the help text here wrongly read "item's description"
    # (copy-pasted from the field below).
    title = Unicode.T(
        optional=True,
        help='item\'s title')
    description = Unicode.T(
        optional=True,
        help='item\'s description')

27 

28 

class PlotGroup(Object):
    '''
    A named, variant-qualified collection of plot items.

    Every item is rendered once per entry in ``formats``; the resulting image
    file names are built by :py:meth:`filename_image`.
    '''

    name = StringID.T(help='group name')
    section = StringID.T(
        optional=True,
        help='group\'s section path, e.g. results.waveforms')
    title = Unicode.T(optional=True, help='group\'s title')
    description = Unicode.T(optional=True, help='group description')
    formats = List.T(PlotFormat.T(), help='plot format')
    variant = StringID.T(help='variant of the group')
    feather_icon = String.T(
        default='bar-chart-2',
        help='Feather icon for the HTML report.')
    # Two floats; presumably (width, height) in centimetres -- confirm.
    size_cm = Tuple.T(2, Float.T())
    items = List.T(PlotItem.T())
    attributes = Dict.T(StringID.T(), List.T(String.T()))

    def filename_image(self, item, format):
        '''
        Build the image file name for one item in one output format.

        :param item: the :py:class:`PlotItem` being rendered
        :param format: output format; its ``extension`` ends the name
        :returns: ``'<group>.<variant>.<item>.<extension>'``
        '''
        pieces = (self.name, self.variant, item.name, format.extension)
        return '.'.join(str(piece) for piece in pieces)

59 

60 

class PlotCollection(Object):
    '''
    Index of all plot groups written so far.

    Each entry of ``group_refs`` is a two-element ``(name, variant)``
    reference identifying one :py:class:`PlotGroup`.
    '''

    group_refs = List.T(
        Tuple.T(2, StringID.T()))

63 

64 

class PlotCollectionManager(object):
    '''
    Manage a directory of rendered plot groups and its collection index.

    Layout below the base directory ``path``:

    * ``plot_collection.yaml`` -- the :py:class:`PlotCollection` index, which
      lists the ``(name, variant)`` reference of every group written.
    * ``<name>/<variant>/<name>.<variant>.plot_group.yaml`` -- the
      :py:class:`PlotGroup` description.
    * ``<name>/<variant>/<name>.<variant>.<item>.<ext>`` -- one image per
      item and format.
    '''

    def __init__(self, path, show=False):
        '''
        :param path: base directory of the plot collection
        :param show: if ``True``, matplotlib figures are kept open and shown
            with ``plt.show()`` after saving, instead of closed right away
        '''
        self._path = path
        self._show = show
        self.load_collection()

    def load_collection(self):
        '''Load the collection index from disk, or start an empty one.'''
        path = self.path_collection()
        if op.exists(path):
            self._collection = guts.load(filename=path)
        else:
            self._collection = PlotCollection()

    def dump_collection(self):
        '''Write the collection index to disk, creating directories.'''
        path = self.path_collection()
        util.ensuredirs(path)
        guts.dump(self._collection, filename=path)

    def path_collection(self):
        '''Return the path of the collection index file.'''
        return op.join(self._path, 'plot_collection.yaml')

    def path_image(self, group, item, format):
        '''Return the image path of ``item`` of ``group`` in ``format``.'''
        return op.join(
            self._path, group.name, group.variant,
            group.filename_image(item, format))

    @staticmethod
    def _name_variant(group_ref, group):
        '''Resolve ``(name, variant)`` from an explicit ref or a group.'''
        if group_ref is not None:
            # Explicit unpacking keeps the "exactly two elements" check.
            group_name, group_variant = group_ref
            return group_name, group_variant

        return group.name, group.variant

    def path_group(self, group_ref=None, group=None):
        '''
        Return the path of a group's description file.

        :param group_ref: ``(name, variant)`` tuple; takes precedence
        :param group: :py:class:`PlotGroup`, used if ``group_ref`` is None
        '''
        group_name, group_variant = self._name_variant(group_ref, group)
        return op.join(
            self._path, group_name, group_variant, '%s.%s.%s' % (
                group_name, group_variant, 'plot_group.yaml'))

    def paths_group_dirs(self, group_ref=None, group=None):
        '''
        Return a group's directories, innermost (variant) first.

        The order allows removing them bottom-up with ``os.rmdir``.
        Arguments are interpreted as in :py:meth:`path_group`.
        '''
        group_name, group_variant = self._name_variant(group_ref, group)
        return [
            op.join(self._path, group_name, group_variant),
            op.join(self._path, group_name)]

    def _begin_group(self, config, group_kwargs):
        '''
        Create a fresh group from ``config`` and purge its previous version.

        Old files of the same group are removed. Its reference is dropped
        from the index, and the index is written before any rendering
        starts. Presumably this is so a failed run leaves no stale
        reference behind.

        :returns: ``(group, path_group, group_ref)``
        '''
        group = PlotGroup(
            formats=guts.clone(config.formats),
            size_cm=config.size_cm,
            name=config.name,
            variant=config.variant,
            **group_kwargs)

        path_group = self.path_group(group=group)
        if op.exists(path_group):
            self.remove_group_files(path_group)

        group_ref = (group.name, group.variant)
        if group_ref in self._collection.group_refs:
            self._collection.group_refs.remove(group_ref)

        self.dump_collection()
        return group, path_group, group_ref

    def _finish_group(self, group, path_group, group_ref):
        '''Validate and write the group file, then register it in the index.'''
        util.ensuredirs(path_group)
        group.validate()
        group.dump(filename=path_group)
        self._collection.group_refs.append(group_ref)
        self.dump_collection()

    def create_group_mpl(self, config, iter_item_figure, **kwargs):
        '''
        Render matplotlib figures into a plot group.

        :param config: provides ``formats``, ``size_cm``, ``name`` and
            ``variant`` of the group
        :param iter_item_figure: iterable of ``(PlotItem, figure)`` pairs
        :param kwargs: extra :py:class:`PlotGroup` attributes
        '''
        from matplotlib import pyplot as plt

        group, path_group, group_ref = self._begin_group(config, kwargs)

        figs_to_close = []
        for item, fig in iter_item_figure:
            group.items.append(item)
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                util.ensuredirs(path)
                fmt.render_mpl(
                    fig,
                    path=path,
                    dpi=fmt.get_dpi(group.size_cm))

                logger.info('Figure saved: %s', path)

            # In show mode figures must stay open until plt.show() below.
            if self._show:
                figs_to_close.append(fig)
            else:
                plt.close(fig)

        self._finish_group(group, path_group, group_ref)

        if self._show:
            plt.show()

        for fig in figs_to_close:
            plt.close(fig)

    def create_group_automap(self, config, iter_item_figure, **kwargs):
        '''
        Render automap maps into a plot group.

        The group is validated before being written, the same as in
        :py:meth:`create_group_mpl` (previously this path skipped validation).

        :param config: provides ``formats``, ``size_cm``, ``name`` and
            ``variant`` of the group
        :param iter_item_figure: iterable of ``(PlotItem, automap)`` pairs
        :param kwargs: extra :py:class:`PlotGroup` attributes
        '''
        group, path_group, group_ref = self._begin_group(config, kwargs)

        for item, automap in iter_item_figure:
            group.items.append(item)
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                util.ensuredirs(path)
                fmt.render_automap(
                    automap,
                    path=path,
                    resolution=fmt.get_dpi(group.size_cm))

                logger.info('Figure saved: %s', path)

        self._finish_group(group, path_group, group_ref)

    def create_group_gmtpy(self, config, iter_item_figure):
        '''Placeholder: not implemented, currently does nothing.'''
        pass

    def remove_group_files(self, path_group):
        '''
        Delete a group's images, its description file and empty directories.

        Image deletion and directory removal are best-effort: missing images
        and non-empty directories are ignored. The description file itself
        must exist.
        '''
        group = guts.load(filename=path_group)
        for item in group.items:
            for fmt in group.formats:
                path = self.path_image(group, item, fmt)
                try:
                    os.unlink(path)
                except OSError:
                    pass  # image may already be gone

        os.unlink(path_group)
        for path in self.paths_group_dirs(group=group):
            try:
                os.rmdir(path)
            except OSError:
                pass  # directory not empty (other variants/groups) or gone

217 

218 

# Names exported by ``from grond.plot.collection import *``.
__all__ = [
    'PlotItem', 'PlotGroup', 'PlotCollection', 'PlotCollectionManager',
]