1import logging 

2import hashlib 

3import time 

4import numpy as num 

5from scipy import signal 

6 

7from matplotlib.cm import get_cmap 

8from matplotlib.colors import Normalize 

9 

10from pyrocko.trace import t2ind 

11from .qt_compat import qg 

12 

13logger = logging.getLogger(__name__) 

14 

15 

16DEFAULT_CMAP = 'viridis' 

17 

18 

19class TraceWaterfall: 

20 

21 def __init__(self): 

22 self.tmin = 0. 

23 self.tmax = 0. 

24 self.traces = [] 

25 

26 self._current_cmap = None 

27 self.cmap = None 

28 self.norm = Normalize() 

29 

30 self._data_cache = None 

31 

32 self._show_absolute = False 

33 self._integrate = False 

34 self._clip_min = 0. 

35 self._clip_max = 1. 

36 self._common_scale = True 

37 

38 self.set_cmap(DEFAULT_CMAP) 

39 

40 def set_traces(self, traces): 

41 self.traces = traces 

42 

43 def set_time_range(self, tmin, tmax): 

44 self.tmin = tmin 

45 self.tmax = tmax 

46 

47 def set_clip(self, clip_min, clip_max): 

48 assert 0. <= clip_min < clip_max <= 1. 

49 self._clip_min = clip_min 

50 self._clip_max = clip_max 

51 

52 def set_integrate(self, integrate): 

53 self._integrate = integrate 

54 

55 def show_absolute_values(self, show_absolute): 

56 self._show_absolute = show_absolute 

57 

58 def set_cmap(self, cmap): 

59 if cmap == self._current_cmap: 

60 return 

61 logger.debug('setting colormap to %s', cmap) 

62 self.cmap = get_cmap(cmap) 

63 self._current_cmap = cmap 

64 

65 def set_common_scale(self, _common_scale): 

66 self._common_scale = _common_scale 

67 

68 def get_state_hash(self): 

69 sha1 = hashlib.sha1() 

70 sha1.update(self.tmin.hex().encode()) 

71 sha1.update(self.tmax.hex().encode()) 

72 sha1.update(self._clip_min.hex().encode()) 

73 sha1.update(self._clip_max.hex().encode()) 

74 sha1.update(self.cmap.name.encode()) 

75 sha1.update(bytes(self._show_absolute)) 

76 sha1.update(bytes(self._integrate)) 

77 sha1.update(bytes(len(self.traces))) 

78 for tr in self.traces: 

79 sha1.update(tr.hash(unsafe=True).encode()) 

80 

81 return sha1 

82 

83 def get_image(self, px_x, px_y): 

84 hash = self.get_state_hash() 

85 hash.update(bytes(px_x)) 

86 hash.update(bytes(px_y)) 

87 

88 data_hash = hash.hexdigest() 

89 

90 if self._data_cache and self._data_cache[-1] == data_hash: 

91 logger.debug('using cached image') 

92 return self._data_cache 

93 

94 # Undersample in space 

95 traces_step = int(len(self.traces) // px_y) + 1 

96 traces = self.traces[::traces_step] 

97 img_rows = len(traces) 

98 

99 # Undersample in time 

100 raw_deltat = min(tr.deltat for tr in traces) 

101 raw_nsamples = int(round((self.tmax - self.tmin) / raw_deltat)) + 1 

102 

103 img_undersample = max(1, int(raw_nsamples // (2*px_x))) 

104 img_deltat = raw_deltat * img_undersample 

105 img_nsamples = int(round((self.tmax - self.tmin) / img_deltat)) + 1 

106 

107 dtypes = set(tr.ydata.dtype for tr in traces) 

108 dtype = num.float64 if num.float64 in dtypes else num.float32 

109 

110 data = num.zeros((img_rows, img_nsamples), dtype=dtype) 

111 empty_data = num.ones_like(data, dtype=num.bool) 

112 

113 deltats = num.zeros(img_rows) if self._integrate else None 

114 

115 logger.debug( 

116 'image render: using [::%d] traces at %d time undersampling' 

117 ' - rect (%d, %d), data: (%d, %d)', 

118 traces_step, img_undersample, px_y, px_x, *data.shape) 

119 

120 for itr, tr in enumerate(traces): 

121 tr_data = tr.ydata 

122 

123 if tr.deltat != img_deltat: 

124 time_vec = tr.tmin \ 

125 + num.arange((tr.tmax - tr.tmin) // img_deltat) \ 

126 * img_deltat 

127 tr_data = num.interp(time_vec, tr.get_xdata(), tr.ydata) 

128 

129 ibeg = max(0, t2ind(self.tmin - tr.tmin, img_deltat, round)) 

130 iend = min( 

131 tr_data.size, 

132 t2ind(self.tmax - tr.tmin, img_deltat, round)) 

133 tr_tmin = tr.tmin + ibeg * img_deltat 

134 

135 img_ibeg = max(0, t2ind(tr_tmin - self.tmin, img_deltat, round)) 

136 img_iend = img_ibeg + (iend - ibeg) 

137 

138 data[itr, img_ibeg:img_iend] = tr_data[ibeg:iend] 

139 empty_data[itr, img_ibeg:img_iend] = False 

140 

141 if self._integrate: 

142 deltats[itr] = tr.deltat 

143 

144 if self._integrate: 

145 data = num.cumsum(data, axis=1) * deltats[:, num.newaxis] 

146 # data -= data.mean(axis=1)[:, num.newaxis] 

147 

148 if self._common_scale: 

149 data /= num.abs(data).max(axis=1)[:, num.newaxis] 

150 

151 if self._show_absolute: 

152 data = num.abs(signal.hilbert(data, axis=1)) 

153 vmax = data.max() 

154 vmin = data.min() 

155 else: 

156 vmax = num.abs(data).max() 

157 vmin = -vmax 

158 

159 vrange = vmax - vmin 

160 

161 self.norm.vmin = vmin + self._clip_min*vrange 

162 self.norm.vmax = vmax - (1. - self._clip_max)*vrange 

163 

164 tstart = time.time() 

165 img_data = self.norm(data) 

166 t_norm = time.time() - tstart 

167 tstart = time.time() 

168 img_data = self.cmap(img_data, alpha=None, bytes=True) 

169 t_cmap = time.time() - tstart 

170 logger.debug('normalizing: %.3f cmap: %.3f', t_norm, t_cmap) 

171 

172 # Mask out empty data 

173 img_data[empty_data, 3] = 0 

174 

175 px_x, px_y = data.shape 

176 img = qg.QImage( 

177 img_data, 

178 px_y, px_x, qg.QImage.Format_RGBA8888) 

179 

180 self._data_cache = (data, img, data_hash) 

181 return self._data_cache 

182 

183 def draw_waterfall(self, p, rect=None): 

184 if not self.traces: 

185 raise AttributeError('No traces to paint.') 

186 

187 rect = rect or p.window() 

188 trace_data, img, *_ = self.get_image( 

189 int(rect.width()), int(rect.height())) 

190 p.drawImage(rect, img)