Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/file.py: 92% (213 statements)


# http://pyrocko.org - GPLv3
#
# The Pyrocko Developers, 21st Century
# ---|P------/S----------~Lg----------

'''
IO library for simple binary files.
'''

# The container format:
# * A file consists of records.
# * A record consists of a record header, a record payload, and possibly
#   padding.
# * A record header consists of a label, a version, the record size, the
#   payload size, a hash, and a record type.
# * A record payload consists of a sequence of record entries.
# * A record entry consists of a key, a type, and a value.
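
# Illustrative note (not part of the module): the 64-byte record header
# written by File.write_record_header below is packed with the big-endian
# struct format '>4s4sQQ20s20s', i.e.
#
#      4 bytes  file type label
#      4 bytes  file version
#      8 bytes  record size (header + payload + padding)
#      8 bytes  payload size
#     20 bytes  SHA1 hash of the payload (all zero bytes if no hash was made)
#     20 bytes  record type name, right-padded with spaces
#
# which adds up to 4 + 4 + 8 + 8 + 20 + 20 = 64 = size_record_header.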

from struct import unpack, pack
from io import BytesIO
import numpy as num
try:
    from hashlib import sha1
except ImportError:
    from sha import new as sha1

try:
    from os import SEEK_CUR
except ImportError:
    SEEK_CUR = 1

from . import util

try:
    range = xrange
except NameError:
    pass


size_record_header = 64

no_hash = b'\0' * 20

numtypes = {
    '@i2': (num.int16, '>i2'),
    '@i4': (num.int32, '>i4'),
    '@i8': (num.int64, '>i8'),
    '@u2': (num.uint16, '>u2'),
    '@u4': (num.uint32, '>u4'),
    '@u8': (num.uint64, '>u8'),
    '@f4': (num.float32, '>f4'),
    '@f8': (num.float64, '>f8'),
}

numtype2type = dict([(v[0], k) for (k, v) in numtypes.items()])
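
# Illustrative note (not part of the module): numtypes maps an array type code
# to a pair (in-memory dtype, big-endian on-disk dtype), and numtype2type is
# the reverse lookup on the first element, e.g.
#
#     numtypes['@f8']            # (num.float64, '>f8')
#     numtype2type[num.float64]  # '@f8'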


def packer(fmt):
    return ((lambda x: pack('>'+fmt, x)), (lambda x: unpack('>'+fmt, x)[0]))


def unpack_array(fmt, data):
    return num.frombuffer(
        data, dtype=numtypes[fmt][1]).astype(numtypes[fmt][0])


def pack_array(fmt, data):
    return data.astype(numtypes[fmt][1]).tobytes()


def array_packer(fmt):
    return ((lambda x: pack_array(fmt, x)), (lambda x: unpack_array(fmt, x)))
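
# Illustrative sketch (not part of the module): the (pack, unpack) pair
# returned by array_packer round-trips a NumPy array through its big-endian
# on-disk representation, e.g.
#
#     pack_f4, unpack_f4 = array_packer('@f4')
#     raw = pack_f4(num.array([1.0, 2.0, 3.0], dtype=num.float32))
#     unpack_f4(raw)  # array([1., 2., 3.], dtype=float32)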


def encoding_packer(enc):
    return ((lambda x: x.encode(enc)), (lambda x: str(x.decode(enc))))


def noop(x):
    return x


def time_to_str_ns(x):
    return util.time_to_str(x, format=9).encode('utf-8')


def str_to_time(x):
    return util.str_to_time(str(x.decode('utf8')))


castings = {
    'i2': packer('h'),
    'i4': packer('i'),
    'i8': packer('q'),
    'u2': packer('H'),
    'u4': packer('I'),
    'u8': packer('Q'),
    'f4': packer('f'),
    'f8': packer('d'),
    'string': encoding_packer('utf-8'),
    'time_string': (time_to_str_ns, str_to_time),
    '@i2': array_packer('@i2'),
    '@i4': array_packer('@i4'),
    '@i8': array_packer('@i8'),
    '@u2': array_packer('@u2'),
    '@u4': array_packer('@u4'),
    '@u8': array_packer('@u8'),
    '@f4': array_packer('@f4'),
    '@f8': array_packer('@f8'),
}


def pack_value(type, value):
    try:
        return castings[type][0](value)
    except Exception as e:
        raise FileError(
            'Packing value failed (type=%s, value=%s, error=%s).' %
            (type, str(value)[:500], e))


def unpack_value(type, value):
    try:
        return castings[type][1](value)
    except Exception as e:
        raise FileError(
            'Unpacking value failed (type=%s, error=%s).' % (type, e))
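
# Illustrative sketch (not part of the module): pack_value and unpack_value
# dispatch to the matching (pack, unpack) pair in castings, e.g.
#
#     pack_value('f8', 1.5)    # b'?\xf8\x00\x00\x00\x00\x00\x00'
#     unpack_value('f8', pack_value('f8', 1.5))              # 1.5
#     unpack_value('string', pack_value('string', 'hello'))  # 'hello'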


class FileError(Exception):
    pass


class NoDataAvailable(Exception):
    pass


class WrongRecordType(Exception):
    pass


class MissingRecordValue(Exception):
    pass


class Record(object):
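    # A Record is used in one of two modes: in read mode ('r') it wraps the
    # current record region of the parent File's stream and keeps track of
    # how much payload and padding is still unread; in write mode ('w') it
    # buffers the payload in a BytesIO object and only emits header, payload
    # and padding to the underlying file when close() is called.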

    def __init__(
            self, parent, mode, size_record, size_payload, hash, type, format,
            do_hash):

        self.mode = mode
        self.size_record = size_record
        self.size_payload = size_payload
        self.hash = hash
        self.type = type
        if mode == 'w':
            self.size_payload = 0
            self.hash = None
            self._out = BytesIO()
        else:
            self.size_remaining = self.size_record - size_record_header
            self.size_padding = self.size_record - size_record_header - \
                self.size_payload

        self._f = parent._f
        self._parent = parent
        self._hasher = None
        self.format = format
        if do_hash and (self.mode == 'w' or self.hash):
            self._hasher = sha1()
        self._closed = False

    def read(self, n=None):

        assert not self._closed
        assert self.mode == 'r'

        if n is None:
            n = self.size_payload

        n = min(n, self.size_remaining - self.size_padding)
        data = self._f.read(n)
        self.size_remaining -= len(data)

        if len(data) != n:
            raise FileError('Read returned less data than expected.')

        if self._hasher:
            self._hasher.update(data)

        return data

    def write(self, data):
        assert not self._closed
        assert self.mode == 'w'
        self._out.write(data)
        if self._hasher:
            self._hasher.update(data)

        self.size_payload += len(data)

    def seek(self, n, whence):
        assert not self._closed
        assert self.mode == 'r'
        assert whence == SEEK_CUR
        assert n >= 0

        n = min(n, self.size_remaining - self.size_padding)
        self._f.seek(n, whence)
        self._hasher = None
        self.size_remaining -= n

    def skip(self, n):
        self.seek(n, SEEK_CUR)

    def close(self):
        if self._closed:
            return

        if self.mode == 'r':
            if self._hasher and self._hasher.digest() != self.hash:
                self.read(self.size_remaining)
                raise FileError(
                    'Hash computed from record data does not match value '
                    'given in header.')
            else:
                self.seek(self.size_remaining, SEEK_CUR)

            if self.size_padding:
                self._f.seek(self.size_padding, SEEK_CUR)
        else:
            if self.size_record is not None and \
                    self.size_payload > self.size_record - size_record_header:

                raise FileError(
                    'Too much data to fit into size-limited record.')

            if self.size_record is None:
                self.size_record = self.size_payload + size_record_header

            self.size_padding = self.size_record - self.size_payload - \
                size_record_header

            if self._hasher is not None:
                self.hash = self._hasher.digest()

            self._parent.write_record_header(
                self.size_record, self.size_payload, self.hash, self.type)

            self._f.write(self._out.getvalue())
            self._out.close()
            self._f.write(b'\0' * self.size_padding)

        self._closed = True
        self._parent = None
        self._f = None
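
    # Payload layout, as produced by pack() and consumed by entries(): for a
    # record with n entries, the payload starts with 3*n sizes (8-byte
    # big-endian unsigned integers), followed by the n keys, then the n type
    # names, then the n packed values, in that order.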

    def entries(self):

        sizes = []
        sum = 0
        while sum < self.size_payload:
            size = unpack('>Q', self.read(8))[0]
            sum += size + 8
            sizes.append(size)

        n = len(sizes) // 3
        keys = [str(self.read(sizes[j]).decode('ascii'))
                for j in range(n)]
        types = [str(self.read(sizes[j]).decode('ascii'))
                 for j in range(n, 2*n)]
        for key, type, j in zip(keys, types, range(2*n, 3*n)):
            yield key, type, sizes[j]

    def unpack(self, exclude=None):

        d = {}
        for key, type, size in self.entries():
            if self.format[key] != type:
                raise FileError('Record value in unexpected format.')

            if not exclude or key not in exclude:
                d[key] = unpack_value(type, self.read(size))
            else:
                self.skip(size)
                d[key] = None

        for key in self.format:
            if key not in d:
                raise FileError('Missing record entry: %s.' % key)

        return d

    def pack(self, d):
        for key in self.format:
            if key not in d:
                raise MissingRecordValue()

        keys = []
        types = []
        values = []
        for key in d.keys():
            if key in self.format:
                type = self.format[key]
                if isinstance(type, tuple):
                    type = self._parent.get_type(key, d[key])

                keys.append(key.encode('ascii'))
                types.append(type.encode('ascii'))
                values.append(pack_value(type, d[key]))

        sizes = [len(x) for x in keys+types+values]

        self.write(pack('>%iQ' % len(sizes), *sizes))
        for x in keys+types+values:
            self.write(x)


class File(object):
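    # A File wraps an already opened binary stream and reads or writes
    # records of the container format described at the top of this module.
    # The record_formats argument maps each record type name to a dict of
    # entry keys and their value type names from castings; a format entry
    # given as a tuple is resolved per value via get_type(), which is not
    # defined here and is expected to be provided by a subclass.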

320 

321 def __init__( 

322 self, f, 

323 type_label='TEST', 

324 version='0000', 

325 record_formats={}): 

326 

327 assert len(type_label) == 4 

328 assert len(version) == 4 

329 

330 self._file_type_label = type_label 

331 self._file_version = version 

332 self._record_formats = record_formats 

333 self._current_record = None 

334 self._f = f 

335 

336 def read_record_header(self): 

337 data = self._f.read(size_record_header) 

338 

339 if len(data) == 0: 

340 raise NoDataAvailable() 

341 

342 if len(data) != size_record_header: 

343 raise FileError('Read returned less data than expected.') 

344 

345 label, version, size_record, size_payload, hash, type = unpack( 

346 '>4s4sQQ20s20s', data) 

347 

348 label = str(label.decode('ascii')) 

349 version = str(version.decode('ascii')) 

350 type = str(type.rstrip().decode('ascii')) 

351 

352 if label != self._file_type_label: 

353 raise FileError('record file type label missing.') 

354 

355 if version != self._file_version: 

356 raise FileError('file version %s not supported.' % version) 

357 

358 type = type.rstrip() 

359 

360 if hash == no_hash: 

361 hash = None 

362 

363 return size_record, size_payload, hash, type 

364 

365 def write_record_header(self, size_record, size_payload, hash, type): 

366 if hash is None: 

367 hash = no_hash 

368 data = pack( 

369 '>4s4sQQ20s20s', 

370 self._file_type_label.encode('ascii'), 

371 self._file_version.encode('ascii'), 

372 size_record, 

373 size_payload, 

374 hash, 

375 type.encode('ascii').ljust(20)[:20]) 

376 

377 self._f.write(data) 

378 

    def next_record(self, check_hash=False):
        if self._current_record:
            self._current_record.close()

        size_record, size_payload, hash, type = self.read_record_header()
        format = self._record_formats[type]
        self._current_record = Record(
            self, 'r', size_record, size_payload, hash, type, format,
            check_hash)

        return self._current_record

    def add_record(self, type, size_record=None, make_hash=False):
        if self._current_record:
            self._current_record.close()

        format = self._record_formats[type]
        self._current_record = Record(
            self, 'w', size_record, 0, None, type, format, make_hash)
        return self._current_record

    def close(self):
        if self._current_record:
            self._current_record.close()
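
# Illustrative usage sketch (not part of the module; the file name
# 'example.bin' and the 'data' record format are made-up examples):
#
#     record_formats = {'data': {'name': 'string', 'value': 'f8'}}
#
#     with open('example.bin', 'wb') as fo:
#         f = File(fo, type_label='TEST', version='0000',
#                  record_formats=record_formats)
#         r = f.add_record('data', make_hash=True)
#         r.pack({'name': 'alpha', 'value': 1.5})
#         f.close()
#
#     with open('example.bin', 'rb') as fi:
#         f = File(fi, type_label='TEST', version='0000',
#                  record_formats=record_formats)
#         r = f.next_record(check_hash=True)
#         print(r.unpack())  # {'name': 'alpha', 'value': 1.5}
#         r.close()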