1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6# The container format: 

7# * A file consists of records. 

# * A record consists of a record header, a record payload, and possibly
#   padding.

10# * A record header consists of a label, a version, the record size, the 

11# payload size, a hash, and a record type. 

12# * A record payload consists of a sequence of record entries. 

13# * A record entry consists of a key, a type, and a value. 

14from __future__ import absolute_import, division 

15 

16from struct import unpack, pack 

17from io import BytesIO 

18import numpy as num 

19try: 

20 from hashlib import sha1 

21except ImportError: 

22 from sha import new as sha1 

23 

24try: 

25 from os import SEEK_CUR 

26except ImportError: 

27 SEEK_CUR = 1 

28 

29from . import util 

30 

31try: 

32 range = xrange 

33except NameError: 

34 pass 

35 

36 

# Size in bytes of the fixed-length record header. This matches the struct
# layout '>4s4sQQ20s20s' used to read and write headers (4+4+8+8+20+20).
size_record_header = 64

# Placeholder stored in the 20-byte hash field of a header when the record
# carries no hash. It has to be a byte string: it is handed to struct's '20s'
# format when writing and compared against raw header bytes when reading. A
# text string would fail to pack and never compare equal on Python 3.
no_hash = b'\0' * 20

# Array type label -> (numpy scalar type used in memory,
#                      big-endian dtype string used in the file).
numtypes = {
    '@i2': (num.int16, '>i2'),
    '@i4': (num.int32, '>i4'),
    '@i8': (num.int64, '>i8'),
    '@u2': (num.uint16, '>u2'),
    '@u4': (num.uint32, '>u4'),
    '@u8': (num.uint64, '>u8'),
    '@f4': (num.float32, '>f4'),
    '@f8': (num.float64, '>f8'),
}

# Reverse lookup: numpy scalar type -> array type label.
numtype2type = dict((v[0], k) for (k, v) in numtypes.items())

52 

53 

def packer(fmt):
    """
    Build a ``(pack, unpack)`` pair for a single big-endian struct value.

    :param fmt: struct format character(s) without byte order prefix,
        e.g. ``'h'`` or ``'d'``.
    :returns: tuple ``(encode, decode)``; ``encode(x)`` gives the packed
        bytes, ``decode(data)`` gives back the first unpacked value.
    """

    def encode(x):
        return pack('>' + fmt, x)

    def decode(x):
        values = unpack('>' + fmt, x)
        return values[0]

    return (encode, decode)

56 

57 

def unpack_array(fmt, data):
    """
    Decode raw big-endian bytes into a numpy array.

    :param fmt: array type label, a key of ``numtypes`` (e.g. ``'@f4'``).
    :param data: bytes-like object holding the packed array.
    :returns: new array with the in-memory type registered for ``fmt``.
    """
    in_memory_type, on_disk_dtype = numtypes[fmt]

    # frombuffer replaces the deprecated binary mode of fromstring. It yields
    # a read-only view on ``data``; astype() then makes an independent,
    # writable copy in the in-memory type, as before.
    return num.frombuffer(data, dtype=on_disk_dtype).astype(in_memory_type)

61 

62 

def pack_array(fmt, data):
    """
    Encode a numpy array as raw big-endian bytes.

    :param fmt: array type label, a key of ``numtypes`` (e.g. ``'@f4'``).
    :param data: numpy array to be stored.
    :returns: bytes of ``data`` converted to the on-disk dtype of ``fmt``.
    """
    on_disk_dtype = numtypes[fmt][1]

    # tobytes() is the current name of the deprecated tostring() alias.
    return data.astype(on_disk_dtype).tobytes()

65 

66 

def array_packer(fmt):
    """
    Build a ``(pack, unpack)`` pair for arrays of type label ``fmt``.

    The returned callables delegate to :py:func:`pack_array` and
    :py:func:`unpack_array` with ``fmt`` bound.
    """

    def encode(x):
        return pack_array(fmt, x)

    def decode(x):
        return unpack_array(fmt, x)

    return (encode, decode)

69 

70 

def encoding_packer(enc):
    """
    Build a ``(pack, unpack)`` pair converting text to and from bytes.

    :param enc: name of the text encoding, e.g. ``'utf-8'``.
    :returns: tuple ``(encode, decode)``; ``encode`` calls ``.encode(enc)``
        on its argument, ``decode`` calls ``.decode(enc)`` and wraps the
        result in ``str``.
    """

    def encode(x):
        return x.encode(enc)

    def decode(x):
        text = x.decode(enc)
        return str(text)

    return (encode, decode)

73 

74 

def noop(x):
    """Identity function: return ``x`` unchanged."""
    return x

77 

78 

def time_to_str_ns(x):
    """
    Convert time value ``x`` to an UTF-8 encoded time string.

    Formatting is delegated to ``util.time_to_str`` with ``format=9``.
    NOTE(review): presumably this selects nanosecond precision, as the
    function name suggests -- confirm against ``util.time_to_str``.
    """
    text = util.time_to_str(x, format=9)
    return text.encode('utf-8')

81 

82 

def str_to_time(x):
    """
    Convert an UTF-8 encoded time string ``x`` (bytes) to a time value.

    The bytes are decoded and parsing is delegated to ``util.str_to_time``.
    """
    text = str(x.decode('utf8'))
    return util.str_to_time(text)

85 

86 

# Type label -> (pack, unpack) pair of callables used to convert a record
# entry value to bytes and back. Scalar labels map to big-endian struct
# formats, '@'-prefixed labels to numpy arrays.
castings = dict(
    (label, packer(code)) for (label, code) in (
        ('i2', 'h'),
        ('i4', 'i'),
        ('i8', 'q'),
        ('u2', 'H'),
        ('u4', 'I'),
        ('u8', 'Q'),
        ('f4', 'f'),
        ('f8', 'd'),
    ))

castings['string'] = encoding_packer('utf-8')
castings['time_string'] = (time_to_str_ns, str_to_time)

castings.update(
    (label, array_packer(label)) for label in (
        '@i2', '@i4', '@i8', '@u2', '@u4', '@u8', '@f4', '@f8'))

107 

108 

def pack_value(type, value):
    """
    Convert ``value`` to bytes using the packer registered for ``type``.

    :param type: type label, a key of ``castings``.
    :param value: value to be packed.
    :returns: packed bytes.
    :raises FileError: if anything goes wrong, including an unknown type
        label. The message contains at most the first 500 characters of
        ``str(value)``.
    """
    try:
        encode = castings[type][0]
        return encode(value)
    except Exception as e:
        shown = str(value)[:500]
        raise FileError(
            'Packing value failed (type=%s, value=%s, error=%s).' %
            (type, shown, e))

116 

117 

def unpack_value(type, value):
    """
    Convert bytes ``value`` back using the unpacker registered for ``type``.

    :param type: type label, a key of ``castings``.
    :param value: packed bytes.
    :returns: unpacked value.
    :raises FileError: if anything goes wrong, including an unknown type
        label.
    """
    try:
        decode = castings[type][1]
        return decode(value)
    except Exception as e:
        message = 'Unpacking value failed (type=%s, error=%s).' % (type, e)
        raise FileError(message)

124 

125 

class FileError(Exception):
    """
    Raised on problems with the file content or with value conversion.

    Used in this module for short reads, header mismatches, hash mismatches,
    oversized records and failed packing/unpacking of values.
    """

128 

129 

class NoDataAvailable(Exception):
    """
    Raised by ``File.read_record_header`` when the read returns no bytes at
    all, i.e. there is no further record to read.
    """

132 

133 

class WrongRecordType(Exception):
    """
    Exception for a record of an unexpected type.

    NOTE(review): not raised anywhere in this module; presumably used by
    code building on it -- confirm against callers.
    """

136 

137 

class MissingRecordValue(Exception):
    """
    Raised by ``Record.pack`` when the given dict lacks a key required by
    the record format.
    """

140 

141 

class Record(object):
    """
    One record of a :py:class:`File`, opened either for reading or writing.

    In write mode (``mode == 'w'``) the payload is collected in an in-memory
    buffer; header, payload and padding are written to the file when the
    record is closed. In read mode (``mode == 'r'``) data is read directly
    from the parent's file object, which is expected to be positioned right
    after the record header.

    :param parent: owning file object; must provide the file object as
        ``_f`` and a ``write_record_header`` method, plus ``get_type`` when
        a format entry is a tuple of allowed types.
    :param mode: ``'r'`` or ``'w'``.
    :param size_record: total record size in bytes including the header; in
        write mode ``None`` means "as large as needed".
    :param size_payload: payload size in bytes (ignored in write mode).
    :param hash: digest from the record header or ``None`` (ignored in write
        mode).
    :param type: record type name.
    :param format: dict mapping entry keys to a type label or to a tuple of
        allowed type labels.
    :param do_hash: compute (write mode) or verify (read mode, only if the
        header carries a hash) the SHA1 digest of the payload.
    """

    def __init__(
            self, parent, mode, size_record, size_payload, hash, type, format,
            do_hash):

        self.mode = mode
        self.size_record = size_record
        self.size_payload = size_payload
        self.hash = hash
        self.type = type
        if mode == 'w':
            self.size_payload = 0
            self.hash = None
            self._out = BytesIO()
        else:
            # Bytes of this record still ahead of the file position
            # (unread payload plus padding).
            self.size_remaining = self.size_record - size_record_header
            self.size_padding = self.size_record - size_record_header - \
                self.size_payload

        self._f = parent._f
        self._parent = parent
        self._hasher = None
        self.format = format
        if do_hash and (self.mode == 'w' or self.hash):
            self._hasher = sha1()
        self._closed = False

    def read(self, n=None):
        """
        Read up to ``n`` bytes of payload (default: ``size_payload``).

        The request is clipped to the unread part of the payload, so padding
        is never returned.

        :raises FileError: if the file delivers fewer bytes than requested.
        """

        assert not self._closed
        assert self.mode == 'r'

        if n is None:
            n = self.size_payload

        n = min(n, self.size_remaining - self.size_padding)
        data = self._f.read(n)
        self.size_remaining -= len(data)

        if len(data) != n:
            raise FileError('Read returned less data than expected.')

        if self._hasher:
            self._hasher.update(data)

        return data

    def write(self, data):
        """Append ``data`` (bytes) to the buffered payload."""

        assert not self._closed
        assert self.mode == 'w'
        self._out.write(data)
        if self._hasher:
            self._hasher.update(data)

        self.size_payload += len(data)

    def seek(self, n, whence):
        """
        Skip ``n`` bytes of payload. Only forward, relative seeks are
        supported (``whence`` must be ``SEEK_CUR``, ``n >= 0``).

        Skipped data cannot be hashed, so hash verification is switched off
        for this record.
        """

        assert not self._closed
        assert self.mode == 'r'
        assert whence == SEEK_CUR
        assert n >= 0

        n = min(n, self.size_remaining - self.size_padding)
        self._f.seek(n, whence)
        self._hasher = None
        self.size_remaining -= n

    def skip(self, n):
        """Skip ``n`` bytes of payload, see :py:meth:`seek`."""

        self.seek(n, SEEK_CUR)

    def close(self):
        """
        Finish the record. Calling it on a closed record does nothing.

        Read mode: the unread payload and the padding are consumed so that
        the file is positioned at the next record. If hash verification is
        active, the digest over the complete payload is compared to the
        header value.

        Write mode: header, buffered payload and padding are written to the
        file.

        :raises FileError: on hash mismatch (read mode) or when the payload
            does not fit into a size-limited record (write mode).
        """

        if self._closed:
            return

        if self.mode == 'r':
            if self._hasher is not None:
                # Feed the unread rest of the payload to the hasher before
                # comparing, otherwise a partially read record would always
                # be reported as corrupt.
                self.read(self.size_remaining)
                if self._hasher.digest() != self.hash:
                    raise FileError(
                        'Hash computed from record data does not match value '
                        'given in header.')
            else:
                self.seek(self.size_remaining, SEEK_CUR)

            if self.size_padding:
                self._f.seek(self.size_padding, SEEK_CUR)
        else:
            if self.size_record is not None and \
                    self.size_payload > self.size_record - size_record_header:

                raise FileError(
                    'Too much data to fit into size-limited record.')

            if self.size_record is None:
                self.size_record = self.size_payload + size_record_header

            self.size_padding = self.size_record - self.size_payload - \
                size_record_header

            if self._hasher is not None:
                self.hash = self._hasher.digest()

            self._parent.write_record_header(
                self.size_record, self.size_payload, self.hash, self.type)

            self._f.write(self._out.getvalue())
            self._out.close()
            self._f.write(b'\0' * self.size_padding)

        self._closed = True
        self._parent = None
        self._f = None

    def entries(self):
        """
        Read the entry table of the payload.

        The payload starts with one big-endian 64-bit size per item; the
        items are all keys, then all types, then all values. Keys and types
        are read here, the values are left in the file for the caller.

        :returns: iterator yielding ``(key, type, size_of_value)`` tuples in
            file order.
        :raises FileError: if the size table is inconsistent with the
            payload size.
        """

        sizes = []
        total = 0
        while total < self.size_payload:
            size = unpack('>Q', self.read(8))[0]
            total += size + 8
            sizes.append(size)

        if total != self.size_payload or len(sizes) % 3 != 0:
            raise FileError('Record payload has an inconsistent entry table.')

        n = len(sizes) // 3
        keys = [str(self.read(size).decode('ascii'))
                for size in sizes[:n]]
        types = [str(self.read(size).decode('ascii'))
                 for size in sizes[n:2*n]]
        for key, type_label, size in zip(keys, types, sizes[2*n:3*n]):
            yield key, type_label, size

    def unpack(self, exclude=None):
        """
        Read all entries of the record into a dict.

        :param exclude: optional container of keys whose values are skipped
            instead of decoded; these appear with value ``None`` in the
            result. Skipping disables hash verification for the record.
        :returns: dict mapping entry key to decoded value.
        :raises FileError: if an entry has a type not allowed by the record
            format, if decoding fails or if an entry required by the format
            is missing.
        """

        d = {}
        for key, type_label, size in self.entries():
            expected = self.format[key]
            # A format entry is either one type label or a tuple of allowed
            # labels (the writer picks one of them via parent.get_type).
            if isinstance(expected, tuple):
                type_ok = type_label in expected
            else:
                type_ok = expected == type_label

            if not type_ok:
                raise FileError('Record value in unexpected format.')

            if not exclude or key not in exclude:
                d[key] = unpack_value(type_label, self.read(size))
            else:
                self.skip(size)
                d[key] = None

        for key in self.format:
            if key not in d:
                raise FileError('Missing record entry: %s.' % key)

        return d

    def pack(self, d):
        """
        Write the entries given in dict ``d`` to the record.

        Only keys known to the record format are written; other keys in
        ``d`` are ignored.

        :raises MissingRecordValue: if ``d`` lacks a key of the format.
        :raises FileError: if a value cannot be packed.
        """

        for key in self.format:
            if key not in d:
                raise MissingRecordValue(
                    'Missing record value: %s.' % key)

        keys = []
        types = []
        values = []
        for key, value in d.items():
            if key not in self.format:
                continue

            type_label = self.format[key]
            if isinstance(type_label, tuple):
                # Several types allowed: let the parent choose based on the
                # actual value.
                type_label = self._parent.get_type(key, value)

            keys.append(key.encode('ascii'))
            types.append(type_label.encode('ascii'))
            values.append(pack_value(type_label, value))

        items = keys + types + values
        sizes = [len(x) for x in items]

        self.write(pack('>%iQ' % len(sizes), *sizes))
        for x in items:
            self.write(x)

314 

315 

class File(object):
    """
    Reader/writer for a sequence of records in a file object.

    :param f: file object opened in binary mode; it is used for reading
        and/or writing and is not closed by this class.
    :param type_label: 4-character label stored in every record header.
    :param version: 4-character version string stored in every record
        header.
    :param record_formats: dict mapping record type names to record formats
        (see :py:class:`Record`); defaults to an empty dict.

    Subclasses using tuple-valued format entries have to provide a
    ``get_type(key, value)`` method, which :py:meth:`Record.pack` calls to
    select the type label.
    """

    def __init__(
            self, f,
            type_label='TEST',
            version='0000',
            record_formats=None):

        assert len(type_label) == 4
        assert len(version) == 4

        # A fresh dict per instance; a dict literal as default argument
        # would be shared between all instances.
        if record_formats is None:
            record_formats = {}

        self._file_type_label = type_label
        self._file_version = version
        self._record_formats = record_formats
        self._current_record = None
        self._f = f

    def read_record_header(self):
        """
        Read and check the header of the next record.

        :returns: tuple ``(size_record, size_payload, hash, type)``; ``hash``
            is ``None`` if the header field equals the ``no_hash``
            placeholder.
        :raises NoDataAvailable: if no bytes could be read.
        :raises FileError: on an incomplete header or if label or version do
            not match those of this file object.
        """

        data = self._f.read(size_record_header)

        if len(data) == 0:
            raise NoDataAvailable()

        if len(data) != size_record_header:
            raise FileError('Read returned less data than expected.')

        label, version, size_record, size_payload, hash, type = unpack(
            '>4s4sQQ20s20s', data)

        label = str(label.decode('ascii'))
        version = str(version.decode('ascii'))
        # The type field is padded with whitespace to 20 bytes.
        type = str(type.rstrip().decode('ascii'))

        if label != self._file_type_label:
            raise FileError('record file type label missing.')

        if version != self._file_version:
            raise FileError('file version %s not supported.' % version)

        if hash == no_hash:
            hash = None

        return size_record, size_payload, hash, type

    def write_record_header(self, size_record, size_payload, hash, type):
        """
        Write a record header to the file.

        :param size_record: total record size in bytes including header.
        :param size_payload: payload size in bytes.
        :param hash: 20-byte digest or ``None`` to store the ``no_hash``
            placeholder.
        :param type: record type name; space-padded or truncated to 20
            bytes.
        """

        if hash is None:
            hash = no_hash

        data = pack(
            '>4s4sQQ20s20s',
            self._file_type_label.encode('ascii'),
            self._file_version.encode('ascii'),
            size_record,
            size_payload,
            hash,
            type.encode('ascii').ljust(20)[:20])

        self._f.write(data)

    def next_record(self, check_hash=False):
        """
        Close the current record, if any, and open the next one for reading.

        :param check_hash: verify the payload hash if the header has one.
        :returns: the new :py:class:`Record` in read mode.
        :raises NoDataAvailable: if there is no further record.
        """

        if self._current_record:
            self._current_record.close()

        size_record, size_payload, hash, type = self.read_record_header()
        format = self._record_formats[type]
        self._current_record = Record(
            self, 'r', size_record, size_payload, hash, type, format,
            check_hash)

        return self._current_record

    def add_record(self, type, size_record=None, make_hash=False):
        """
        Close the current record, if any, and start a new one for writing.

        :param type: record type name, a key of the record formats.
        :param size_record: fixed total record size or ``None`` to size the
            record by its payload.
        :param make_hash: store a SHA1 digest of the payload in the header.
        :returns: the new :py:class:`Record` in write mode.
        """

        if self._current_record:
            self._current_record.close()

        format = self._record_formats[type]
        self._current_record = Record(
            self, 'w', size_record, 0, None, type, format, make_hash)
        return self._current_record

    def close(self):
        """Close the current record, if any. The file object stays open."""

        if self._current_record:
            self._current_record.close()