Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/file.py: 92%
213 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-06 06:59 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7IO library for simple binary files.
8'''
10# The container format:
11# * A file consists of records.
12# * A record consists of a record header a record payload, and possibly
13# padding.
14# * A record header consists of a label, a version, the record size, the
15# payload size, a hash, and a record type.
16# * A record payload consists of a sequence of record entries.
17# * A record entry consists of a key, a type, and a value.
19from struct import unpack, pack
20from io import BytesIO
21import numpy as num
22try:
23 from hashlib import sha1
24except ImportError:
25 from sha import new as sha1
27try:
28 from os import SEEK_CUR
29except ImportError:
30 SEEK_CUR = 1
32from . import util
34try:
35 range = xrange
36except NameError:
37 pass
# Fixed size in bytes of the binary record header
# (see File.read_record_header / File.write_record_header).
size_record_header = 64

# Sentinel stored in the header's 20-byte hash field when no hash is given.
# Must be bytes: it is packed into and compared against the raw '20s' struct
# field, which is bytes in Python 3. (Was '\0' * 20, a str, which could never
# match an unpacked bytes value and would make pack() fail.)
no_hash = b'\0' * 20
# Lookup table: array type code -> (native numpy dtype, big-endian wire dtype
# string). The wire dtype is what is stored on disk; values are converted to
# the native dtype after reading (see pack_array / unpack_array).
numtypes = {
    '@i2': (num.int16, '>i2'),
    '@i4': (num.int32, '>i4'),
    '@i8': (num.int64, '>i8'),
    '@u2': (num.uint16, '>u2'),
    '@u4': (num.uint32, '>u4'),
    '@u8': (num.uint64, '>u8'),
    '@f4': (num.float32, '>f4'),
    '@f8': (num.float64, '>f8'),
}

# Reverse lookup: native numpy dtype -> array type code (e.g. int16 -> '@i2').
# Dict comprehension instead of dict([...]) for clarity (flake8-comprehensions
# C404).
numtype2type = {v[0]: k for (k, v) in numtypes.items()}
def packer(fmt):
    '''
    Build a ``(pack, unpack)`` pair of callables for a single value of the
    given struct format character, using big-endian byte order.
    '''
    bigendian_fmt = '>' + fmt

    def _pack(value):
        return pack(bigendian_fmt, value)

    def _unpack(data):
        return unpack(bigendian_fmt, data)[0]

    return (_pack, _unpack)
def unpack_array(fmt, data):
    '''
    Decode ``data`` (bytes, big-endian on-disk layout) into a numpy array of
    the native dtype registered for type code ``fmt`` in ``numtypes``.
    '''
    native_dtype, wire_dtype = numtypes[fmt]
    return num.frombuffer(data, dtype=wire_dtype).astype(native_dtype)
def pack_array(fmt, data):
    '''
    Encode numpy array ``data`` into bytes using the big-endian wire dtype
    registered for type code ``fmt`` in ``numtypes``.
    '''
    wire_dtype = numtypes[fmt][1]
    return data.astype(wire_dtype).tobytes()
def array_packer(fmt):
    '''
    Build a ``(pack, unpack)`` pair of callables converting numpy arrays
    to/from their on-disk representation for type code ``fmt``.
    '''
    def _pack(arr):
        return pack_array(fmt, arr)

    def _unpack(data):
        return unpack_array(fmt, data)

    return (_pack, _unpack)
def encoding_packer(enc):
    '''
    Build a ``(pack, unpack)`` pair of callables converting ``str`` to/from
    bytes with the text encoding ``enc``.
    '''
    def _encode(text):
        return text.encode(enc)

    def _decode(data):
        return str(data.decode(enc))

    return (_encode, _decode)
def noop(x):
    '''
    Identity function: return the argument unchanged.
    '''
    return x
def time_to_str_ns(x):
    '''
    Format timestamp ``x`` via ``util.time_to_str`` with ``format=9``
    (presumably nanosecond precision -- confirm against util) and return it
    as UTF-8 encoded bytes.
    '''
    text = util.time_to_str(x, format=9)
    return text.encode('utf-8')
def str_to_time(x):
    '''
    Decode the UTF-8 bytes ``x`` and parse them into a timestamp via
    ``util.str_to_time``.
    '''
    text = str(x.decode('utf8'))
    return util.str_to_time(text)
# Lookup table: value type name -> (pack_function, unpack_function), used by
# pack_value() and unpack_value(). Scalar integer/float types are packed
# big-endian via struct; 'string' is UTF-8 text; 'time_string' round-trips a
# timestamp through a textual representation (presumably nanosecond
# precision, given format=9 in time_to_str_ns -- confirm against util);
# '@'-prefixed entries pack/unpack numpy arrays.
castings = {
    'i2': packer('h'),
    'i4': packer('i'),
    'i8': packer('q'),
    'u2': packer('H'),
    'u4': packer('I'),
    'u8': packer('Q'),
    'f4': packer('f'),
    'f8': packer('d'),
    'string': encoding_packer('utf-8'),
    'time_string': (time_to_str_ns, str_to_time),
    '@i2': array_packer('@i2'),
    '@i4': array_packer('@i4'),
    '@i8': array_packer('@i8'),
    '@u2': array_packer('@u2'),
    '@u4': array_packer('@u4'),
    '@u8': array_packer('@u8'),
    '@f4': array_packer('@f4'),
    '@f8': array_packer('@f8'),
}
def pack_value(type, value):
    '''
    Serialize ``value`` to bytes according to the named value ``type``
    registered in ``castings``.

    :raises FileError: if the type is unknown or the value cannot be packed.
    '''
    try:
        pack_func = castings[type][0]
        return pack_func(value)
    except Exception as e:
        raise FileError(
            'Packing value failed (type=%s, value=%s, error=%s).' %
            (type, str(value)[:500], e))
def unpack_value(type, value):
    '''
    Deserialize bytes ``value`` according to the named value ``type``
    registered in ``castings``.

    :raises FileError: if the type is unknown or the data cannot be unpacked.
    '''
    try:
        unpack_func = castings[type][1]
        return unpack_func(value)
    except Exception as e:
        raise FileError(
            'Unpacking value failed (type=%s, error=%s).' % (type, e))
class FileError(Exception):
    '''
    Raised when reading or writing a record file fails.
    '''
class NoDataAvailable(Exception):
    '''
    Raised when no further record header can be read (end of file).
    '''
class WrongRecordType(Exception):
    '''
    Raised when a record of an unexpected type is encountered.
    (Not raised in this module itself -- presumably used by callers.)
    '''
class MissingRecordValue(Exception):
    '''
    Raised when a value required by a record's format is missing from the
    data to be packed.
    '''
class Record(object):
    '''
    A single record of a record file, open for reading (``mode='r'``) or
    writing (``mode='w'``).

    In read mode the record wraps a region of the underlying file; unread
    payload and padding are skipped on :py:meth:`close`, and the payload
    hash is verified if one was computed. In write mode the payload is
    buffered in memory and flushed to the file, preceded by the record
    header, on :py:meth:`close`.
    '''

    def __init__(
            self, parent, mode, size_record, size_payload, hash, type, format,
            do_hash):

        self.mode = mode
        self.size_record = size_record
        self.size_payload = size_payload
        self.hash = hash
        self.type = type
        if mode == 'w':
            # Payload is accumulated in memory; size and hash are only known
            # when the record is closed.
            self.size_payload = 0
            self.hash = None
            self._out = BytesIO()
        else:
            self.size_remaining = self.size_record - size_record_header
            self.size_padding = self.size_record - size_record_header - \
                self.size_payload

        self._f = parent._f
        self._parent = parent
        self._hasher = None
        self.format = format
        # Hashing is requested explicitly (do_hash) and, in read mode, only
        # possible when the header actually carries a hash.
        if do_hash and (self.mode == 'w' or self.hash):
            self._hasher = sha1()
        self._closed = False

    def read(self, n=None):
        '''
        Read up to ``n`` payload bytes (all remaining payload by default).

        :raises FileError: if the underlying file yields less data than
            requested.
        '''
        assert not self._closed
        assert self.mode == 'r'

        if n is None:
            n = self.size_payload

        # Never read into the padding which follows the payload.
        n = min(n, self.size_remaining - self.size_padding)
        data = self._f.read(n)
        self.size_remaining -= len(data)

        if len(data) != n:
            raise FileError('Read returned less data than expected.')

        if self._hasher:
            self._hasher.update(data)

        return data

    def write(self, data):
        '''
        Append ``data`` (bytes) to the record payload (write mode only).
        '''
        assert not self._closed
        assert self.mode == 'w'

        self._out.write(data)
        if self._hasher:
            self._hasher.update(data)

        self.size_payload += len(data)

    def seek(self, n, whence):
        '''
        Skip forward ``n`` payload bytes (read mode only, relative seeks
        only).

        Seeking disables hash verification for this record, because the
        skipped bytes do not enter the hash computation.
        '''
        assert not self._closed
        assert self.mode == 'r'
        assert whence == SEEK_CUR
        assert n >= 0

        n = min(n, self.size_remaining - self.size_padding)
        self._f.seek(n, whence)
        self._hasher = None
        self.size_remaining -= n

    def skip(self, n):
        '''
        Convenience wrapper: skip ``n`` payload bytes.
        '''
        self.seek(n, SEEK_CUR)

    def close(self):
        '''
        Finish the record. Idempotent.

        Read mode: consume remaining payload and padding, verifying the
        payload hash if one was computed. Write mode: write the header, the
        buffered payload and zero padding to the file.

        :raises FileError: on hash mismatch, or when the payload does not
            fit into a size-limited record.
        '''
        if self._closed:
            return

        if self.mode == 'r':
            if self._hasher and self._hasher.digest() != self.hash:
                # Consume the rest so the file position stays consistent
                # before reporting the corruption.
                self.read(self.size_remaining)
                raise FileError(
                    'Hash computed from record data does not match value '
                    'given in header.')
            else:
                self.seek(self.size_remaining, SEEK_CUR)

            if self.size_padding:
                self._f.seek(self.size_padding, SEEK_CUR)
        else:
            if self.size_record is not None and \
                    self.size_payload > self.size_record - size_record_header:

                raise FileError(
                    'Too much data to fit into size-limited record.')

            if self.size_record is None:
                self.size_record = self.size_payload + size_record_header

            self.size_padding = self.size_record - self.size_payload - \
                size_record_header

            if self._hasher is not None:
                self.hash = self._hasher.digest()

            self._parent.write_record_header(
                self.size_record, self.size_payload, self.hash, self.type)

            self._f.write(self._out.getvalue())
            self._out.close()
            self._f.write(b'\0' * self.size_padding)

        self._closed = True
        self._parent = None
        self._f = None

    def entries(self):
        '''
        Iterate over the record's entries, yielding ``(key, type, size)``
        tuples.

        Payload layout: a table of ``3*n`` big-endian uint64 sizes, followed
        by ``n`` keys, ``n`` type names and ``n`` values (see pack()).
        '''
        sizes = []
        nread = 0  # renamed from 'sum' to avoid shadowing the builtin
        while nread < self.size_payload:
            size = unpack('>Q', self.read(8))[0]
            nread += size + 8
            sizes.append(size)

        n = len(sizes) // 3
        keys = [str(self.read(sizes[j]).decode('ascii'))
                for j in range(n)]
        types = [str(self.read(sizes[j]).decode('ascii'))
                 for j in range(n, 2*n)]

        for key, type, j in zip(keys, types, range(2*n, 3*n)):
            yield key, type, sizes[j]

    def unpack(self, exclude=None):
        '''
        Read and decode all record entries into a dict.

        Entries whose key is in ``exclude`` are skipped without decoding and
        mapped to ``None``.

        :raises FileError: if an entry's type does not match the record
            format, or a key required by the format is missing.
        '''
        d = {}
        for key, type, size in self.entries():
            want = self.format[key]
            # A format entry may be a tuple of admissible type names (see
            # pack(), which resolves tuple entries dynamically).
            if isinstance(want, tuple):
                type_ok = type in want
            else:
                type_ok = want == type

            if not type_ok:
                # Bug fix: the original constructed this exception without
                # raising it, silently accepting format mismatches.
                raise FileError('Record value in unexpected format.')

            if not exclude or key not in exclude:
                d[key] = unpack_value(type, self.read(size))
            else:
                self.skip(size)
                d[key] = None

        for key in self.format:
            if key not in d:
                raise FileError('Missing record entry: %s.' % key)

        return d

    def pack(self, d):
        '''
        Encode the values in dict ``d`` according to the record format and
        write them as the record payload. Keys not in the format are
        ignored.

        :raises MissingRecordValue: if a key required by the format is
            missing from ``d``.
        '''
        for key in self.format:
            if key not in d:
                raise MissingRecordValue()

        keys = []
        types = []
        values = []
        for key in d.keys():
            if key in self.format:
                type = self.format[key]
                if isinstance(type, tuple):
                    # Ambiguous format entry: the parent file chooses the
                    # concrete type for this value (get_type is provided by
                    # the parent -- not visible in this module).
                    type = self._parent.get_type(key, d[key])

                keys.append(key.encode('ascii'))
                types.append(type.encode('ascii'))
                values.append(pack_value(type, d[key]))

        sizes = [len(x) for x in keys+types+values]
        self.write(pack('>%iQ' % len(sizes), *sizes))
        for x in keys+types+values:
            self.write(x)
class File(object):
    '''
    Container file consisting of a sequence of records.

    :param f: open binary file object
    :param type_label: 4-character file type label written into and checked
        against every record header
    :param version: 4-character format version string
    :param record_formats: dict mapping record type names to record formats
    '''

    def __init__(
            self, f,
            type_label='TEST',
            version='0000',
            record_formats=None):

        assert len(type_label) == 4
        assert len(version) == 4

        self._file_type_label = type_label
        self._file_version = version
        # Bug fix: was a mutable default argument ({}), shared across all
        # instances constructed without record_formats.
        self._record_formats = {} if record_formats is None else \
            record_formats
        self._current_record = None
        self._f = f

    def read_record_header(self):
        '''
        Read and validate the next fixed-size record header.

        :returns: tuple ``(size_record, size_payload, hash, type)``; ``hash``
            is ``None`` when the header carries the no-hash marker
            (``no_hash``).
        :raises NoDataAvailable: at end of file (zero bytes read).
        :raises FileError: on a short read, or label/version mismatch.
        '''
        data = self._f.read(size_record_header)

        if len(data) == 0:
            raise NoDataAvailable()

        if len(data) != size_record_header:
            raise FileError('Read returned less data than expected.')

        label, version, size_record, size_payload, hash, type = unpack(
            '>4s4sQQ20s20s', data)

        label = str(label.decode('ascii'))
        version = str(version.decode('ascii'))
        # Strip the space padding added by write_record_header. (A second,
        # redundant rstrip() on the decoded string was removed.)
        type = str(type.rstrip().decode('ascii'))

        if label != self._file_type_label:
            raise FileError('record file type label missing.')

        if version != self._file_version:
            raise FileError('file version %s not supported.' % version)

        if hash == no_hash:
            hash = None

        return size_record, size_payload, hash, type

    def write_record_header(self, size_record, size_payload, hash, type):
        '''
        Write one fixed-size record header.

        ``hash=None`` is stored as the no-hash marker. The type name is
        space-padded/truncated to 20 bytes.
        '''
        if hash is None:
            hash = no_hash

        data = pack(
            '>4s4sQQ20s20s',
            self._file_type_label.encode('ascii'),
            self._file_version.encode('ascii'),
            size_record,
            size_payload,
            hash,
            type.encode('ascii').ljust(20)[:20])

        self._f.write(data)

    def next_record(self, check_hash=False):
        '''
        Close the currently open record (if any) and open the next one for
        reading.

        :raises KeyError: if the record's type has no registered format.
        '''
        if self._current_record:
            self._current_record.close()

        size_record, size_payload, hash, type = self.read_record_header()
        format = self._record_formats[type]
        self._current_record = Record(
            self, 'r', size_record, size_payload, hash, type, format,
            check_hash)

        return self._current_record

    def add_record(self, type, size_record=None, make_hash=False):
        '''
        Close the currently open record (if any) and open a new one of the
        given type for writing.
        '''
        if self._current_record:
            self._current_record.close()

        format = self._record_formats[type]
        self._current_record = Record(
            self, 'w', size_record, 0, None, type, format, make_hash)
        return self._current_record

    def close(self):
        '''
        Close the currently open record, if any.
        '''
        if self._current_record:
            self._current_record.close()