1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6# The container format:
7# * A file consists of records.
8# * A record consists of a record header a record payload, and possibly
9# padding.
10# * A record header consists of a label, a version, the record size, the
11# payload size, a hash, and a record type.
12# * A record payload consists of a sequence of record entries.
13# * A record entry consists of a key, a type, and a value.
15from struct import unpack, pack
16from io import BytesIO
17import numpy as num
18try:
19 from hashlib import sha1
20except ImportError:
21 from sha import new as sha1
23try:
24 from os import SEEK_CUR
25except ImportError:
26 SEEK_CUR = 1
28from . import util
30try:
31 range = xrange
32except NameError:
33 pass
# Fixed size in bytes of a record header on disk ('>4s4sQQ20s20s').
size_record_header = 64

# Sentinel stored in the header's hash field when no hash was computed.
# Must be bytes: it is packed with struct format '20s' and compared
# against the raw 20-byte hash field read back from the file; as a str
# it would fail to pack under Python 3 and never compare equal.
no_hash = b'\0' * 20
# Array entry type codes mapped to (native numpy dtype, big-endian
# on-disk dtype string).
numtypes = {
    '@i2': (num.int16, '>i2'),
    '@i4': (num.int32, '>i4'),
    '@i8': (num.int64, '>i8'),
    '@u2': (num.uint16, '>u2'),
    '@u4': (num.uint32, '>u4'),
    '@u8': (num.uint64, '>u8'),
    '@f4': (num.float32, '>f4'),
    '@f8': (num.float64, '>f8'),
}

# Reverse lookup: native numpy dtype -> array entry type code.
numtype2type = {v[0]: k for (k, v) in numtypes.items()}
def packer(fmt):
    '''
    Return a ``(pack, unpack)`` pair of callables for one scalar value.

    The value is stored big-endian using struct format character *fmt*.
    '''
    bigendian_fmt = '>' + fmt

    def pack_one(value):
        return pack(bigendian_fmt, value)

    def unpack_one(data):
        return unpack(bigendian_fmt, data)[0]

    return pack_one, unpack_one
def unpack_array(fmt, data):
    '''
    Decode big-endian bytes *data* as an array of type code *fmt*,
    converted to the corresponding native numpy dtype.
    '''
    native_dtype, wire_dtype = numtypes[fmt]
    return num.frombuffer(data, dtype=wire_dtype).astype(native_dtype)
def pack_array(fmt, data):
    '''
    Encode numpy array *data* as big-endian bytes for on-disk storage
    according to type code *fmt*.
    '''
    wire_dtype = numtypes[fmt][1]
    return data.astype(wire_dtype).tobytes()
def array_packer(fmt):
    '''
    Return a ``(pack, unpack)`` pair of callables for array values of
    type code *fmt*.
    '''
    def pack_one(data):
        return pack_array(fmt, data)

    def unpack_one(data):
        return unpack_array(fmt, data)

    return pack_one, unpack_one
def encoding_packer(enc):
    '''
    Return a ``(pack, unpack)`` pair converting between ``str`` and
    bytes in the given text encoding *enc*.
    '''
    def encode_string(s):
        return s.encode(enc)

    def decode_string(data):
        return str(data.decode(enc))

    return encode_string, decode_string
def noop(x):
    '''
    Identity function: return the argument unchanged.
    '''
    return x
def time_to_str_ns(x):
    '''
    Format time value *x* with nanosecond precision (format=9) and
    encode the result as UTF-8 bytes.
    '''
    s = util.time_to_str(x, format=9)
    return s.encode('utf-8')
def str_to_time(x):
    '''
    Parse UTF-8 encoded time string *x* into a timestamp.
    '''
    s = str(x.decode('utf8'))
    return util.str_to_time(s)
# Table of (pack, unpack) callable pairs keyed by entry type code:
# scalar codes use big-endian struct packers, 'string'/'time_string'
# use text converters, '@'-prefixed codes use array packers.
castings = dict(
    [(code, packer(struct_code)) for code, struct_code in zip(
        ['i2', 'i4', 'i8', 'u2', 'u4', 'u8', 'f4', 'f8'],
        ['h', 'i', 'q', 'H', 'I', 'Q', 'f', 'd'])]
    + [('string', encoding_packer('utf-8')),
       ('time_string', (time_to_str_ns, str_to_time))]
    + [(code, array_packer(code)) for code in (
        '@i2', '@i4', '@i8', '@u2', '@u4', '@u8', '@f4', '@f8')])
def pack_value(type, value):
    '''
    Serialize *value* according to entry type code *type*.

    :raises: :py:exc:`FileError` if the value cannot be packed.
    '''
    try:
        caster = castings[type][0]
        return caster(value)
    except Exception as e:
        raise FileError(
            'Packing value failed (type=%s, value=%s, error=%s).' %
            (type, str(value)[:500], e))
def unpack_value(type, value):
    '''
    Deserialize raw bytes *value* according to entry type code *type*.

    :raises: :py:exc:`FileError` if the value cannot be unpacked.
    '''
    try:
        caster = castings[type][1]
        return caster(value)
    except Exception as e:
        raise FileError(
            'Unpacking value failed (type=%s, error=%s).' % (type, e))
class FileError(Exception):
    '''
    Raised when the container file is malformed or an operation on it
    fails.
    '''
class NoDataAvailable(Exception):
    '''
    Raised when attempting to read a record header but no more data is
    available (end of file).
    '''
class WrongRecordType(Exception):
    '''
    Raised when a record of an unexpected type is encountered.
    '''
class MissingRecordValue(Exception):
    '''
    Raised when a value required by the record format is missing from
    the data to be packed.
    '''
class Record(object):
    '''
    A single record of a container file.

    In read mode (``mode='r'``), the payload can be consumed
    sequentially with :py:meth:`read`, :py:meth:`seek` and
    :py:meth:`skip`. In write mode (``mode='w'``), data passed to
    :py:meth:`write` is buffered in memory and flushed to file together
    with the record header and padding when :py:meth:`close` is called.
    '''

    def __init__(
            self, parent, mode, size_record, size_payload, hash, type, format,
            do_hash):

        self.mode = mode
        self.size_record = size_record
        self.size_payload = size_payload
        self.hash = hash
        self.type = type
        if mode == 'w':
            # Actual payload size and hash are accumulated while writing.
            self.size_payload = 0
            self.hash = None
            self._out = BytesIO()
        else:
            self.size_remaining = self.size_record - size_record_header
            self.size_padding = self.size_record - size_record_header - \
                self.size_payload

        self._f = parent._f
        self._parent = parent
        self._hasher = None
        self.format = format
        if do_hash and (self.mode == 'w' or self.hash):
            self._hasher = sha1()
        self._closed = False

    def read(self, n=None):
        '''
        Read up to *n* bytes of payload (all remaining payload if *n*
        is ``None``).

        :raises: :py:exc:`FileError` if less data than requested is
            available.
        '''
        assert not self._closed
        assert self.mode == 'r'

        if n is None:
            n = self.size_payload

        # Never read into the padding area at the end of the record.
        n = min(n, self.size_remaining - self.size_padding)
        data = self._f.read(n)
        self.size_remaining -= len(data)

        if len(data) != n:
            raise FileError('Read returned less data than expected.')

        if self._hasher:
            self._hasher.update(data)

        return data

    def write(self, data):
        '''
        Append *data* to the record's payload buffer (write mode only).
        '''
        assert not self._closed
        assert self.mode == 'w'
        self._out.write(data)
        if self._hasher:
            self._hasher.update(data)

        self.size_payload += len(data)

    def seek(self, n, whence):
        '''
        Skip forward *n* bytes within the payload (read mode only).

        Only relative seeks (``whence == SEEK_CUR``) with non-negative
        *n* are supported.
        '''
        assert not self._closed
        assert self.mode == 'r'
        assert whence == SEEK_CUR
        assert n >= 0

        n = min(n, self.size_remaining - self.size_padding)
        self._f.seek(n, whence)
        # Skipped data cannot be hashed, so hash verification is
        # disabled for this record.
        self._hasher = None
        self.size_remaining -= n

    def skip(self, n):
        '''
        Skip *n* bytes of payload.
        '''
        self.seek(n, SEEK_CUR)

    def close(self):
        '''
        Finish the record.

        In read mode, advance the file position past any unread payload
        and padding, verifying the payload hash if one was computed. In
        write mode, emit the record header, the buffered payload and
        padding to file.

        :raises: :py:exc:`FileError` if the hash does not match the
            header value, or if the payload does not fit into a
            size-limited record.
        '''
        if self._closed:
            return

        if self.mode == 'r':
            if self._hasher and self._hasher.digest() != self.hash:
                self.read(self.size_remaining)
                raise FileError(
                    'Hash computed from record data does not match value '
                    'given in header.')
            else:
                self.seek(self.size_remaining, SEEK_CUR)

            if self.size_padding:
                self._f.seek(self.size_padding, SEEK_CUR)
        else:
            if self.size_record is not None and \
                    self.size_payload > self.size_record - size_record_header:

                raise FileError(
                    'Too much data to fit into size-limited record.')

            if self.size_record is None:
                self.size_record = self.size_payload + size_record_header

            self.size_padding = self.size_record - self.size_payload - \
                size_record_header

            if self._hasher is not None:
                self.hash = self._hasher.digest()

            self._parent.write_record_header(
                self.size_record, self.size_payload, self.hash, self.type)

            self._f.write(self._out.getvalue())
            self._out.close()
            self._f.write(b'\0' * self.size_padding)

        self._closed = True
        self._parent = None
        self._f = None

    def entries(self):
        '''
        Iterate over the entries of this record (read mode).

        Yields ``(key, type, size)`` tuples; the value data of each
        entry must then be consumed with :py:meth:`read` or skipped
        with :py:meth:`skip`.
        '''
        # The payload starts with a table of entry sizes, followed by
        # all keys, then all type codes, then all values.
        sizes = []
        sum = 0
        while sum < self.size_payload:
            size = unpack('>Q', self.read(8))[0]
            sum += size + 8
            sizes.append(size)

        n = len(sizes) // 3
        keys = [str(self.read(sizes[j]).decode('ascii'))
                for j in range(n)]
        types = [str(self.read(sizes[j]).decode('ascii'))
                 for j in range(n, 2*n)]
        for key, type, j in zip(keys, types, range(2*n, 3*n)):
            yield key, type, sizes[j]

    def unpack(self, exclude=None):
        '''
        Read and decode all entries of this record into a dict.

        :param exclude: optional collection of keys whose values are
            skipped; excluded keys map to ``None`` in the result
        :raises: :py:exc:`FileError` if an entry's type does not match
            the record format or a required entry is missing
        '''
        d = {}
        for key, type, size in self.entries():
            if self.format[key] != type:
                # Was silently ignored before: the exception was
                # constructed but never raised.
                raise FileError('Record value in unexpected format.')

            if not exclude or key not in exclude:
                d[key] = unpack_value(type, self.read(size))
            else:
                self.skip(size)
                d[key] = None

        for key in self.format:
            if key not in d:
                raise FileError('Missing record entry: %s.' % key)

        return d

    def pack(self, d):
        '''
        Encode dict *d* as record entries and write them (write mode).

        Keys not present in the record format are silently ignored.

        :raises: :py:exc:`MissingRecordValue` if a key required by the
            record format is missing in *d*.
        '''
        for key in self.format:
            if key not in d:
                raise MissingRecordValue()

        keys = []
        types = []
        values = []
        for key in d.keys():
            if key in self.format:
                type = self.format[key]
                if isinstance(type, tuple):
                    # Ambiguous format entry: parent decides the
                    # concrete type (get_type is expected to be provided
                    # by the parent object -- not defined in this
                    # module's File class).
                    type = self._parent.get_type(key, d[key])

                keys.append(key.encode('ascii'))
                types.append(type.encode('ascii'))
                values.append(pack_value(type, d[key]))

        sizes = [len(x) for x in keys+types+values]

        self.write(pack('>%iQ' % len(sizes), *sizes))
        for x in keys+types+values:
            self.write(x)
class File(object):
    '''
    Container file consisting of sequentially stored records.

    Records are accessed one at a time through :py:meth:`next_record`
    (reading) or :py:meth:`add_record` (writing); opening a new record
    implicitly closes the previous one.
    '''

    def __init__(
            self, f,
            type_label='TEST',
            version='0000',
            record_formats=None):

        '''
        :param f: open file-like object to read from or write to
        :param type_label: 4-character file type label expected/written
            in every record header
        :param version: 4-character version string expected/written in
            every record header
        :param record_formats: dict mapping record type names to entry
            format dicts (``{key: type_code}``)
        '''
        assert len(type_label) == 4
        assert len(version) == 4

        self._file_type_label = type_label
        self._file_version = version
        # Create a fresh dict per instance; a mutable default argument
        # ({}) would be shared between all File instances.
        self._record_formats = \
            {} if record_formats is None else record_formats
        self._current_record = None
        self._f = f

    def read_record_header(self):
        '''
        Read and decode the record header at the current file position.

        :returns: tuple ``(size_record, size_payload, hash, type)``;
            ``hash`` is ``None`` when the header carries no hash
        :raises: :py:exc:`NoDataAvailable` at end of file,
            :py:exc:`FileError` on short reads or label/version mismatch
        '''
        data = self._f.read(size_record_header)

        if len(data) == 0:
            raise NoDataAvailable()

        if len(data) != size_record_header:
            raise FileError('Read returned less data than expected.')

        label, version, size_record, size_payload, hash, type = unpack(
            '>4s4sQQ20s20s', data)

        label = str(label.decode('ascii'))
        version = str(version.decode('ascii'))
        # Type field is space-padded to 20 bytes on write.
        type = str(type.rstrip().decode('ascii'))

        if label != self._file_type_label:
            raise FileError('record file type label missing.')

        if version != self._file_version:
            raise FileError('file version %s not supported.' % version)

        if hash == no_hash:
            hash = None

        return size_record, size_payload, hash, type

    def write_record_header(self, size_record, size_payload, hash, type):
        '''
        Write a record header at the current file position.

        A ``hash`` of ``None`` is stored as the all-zero sentinel.
        '''
        if hash is None:
            hash = no_hash

        data = pack(
            '>4s4sQQ20s20s',
            self._file_type_label.encode('ascii'),
            self._file_version.encode('ascii'),
            size_record,
            size_payload,
            hash,
            type.encode('ascii').ljust(20)[:20])

        self._f.write(data)

    def next_record(self, check_hash=False):
        '''
        Advance to and return the next record for reading.

        Any currently open record is closed first.

        :param check_hash: if true, verify the record's payload hash on
            close (when the header carries one)
        '''
        if self._current_record:
            self._current_record.close()

        size_record, size_payload, hash, type = self.read_record_header()
        format = self._record_formats[type]
        self._current_record = Record(
            self, 'r', size_record, size_payload, hash, type, format,
            check_hash)

        return self._current_record

    def add_record(self, type, size_record=None, make_hash=False):
        '''
        Open and return a new record for writing.

        Any currently open record is closed first.

        :param size_record: if given, fix the record's total on-disk
            size (the payload must fit and is padded); otherwise the
            size is derived from the payload on close
        :param make_hash: if true, compute and store a SHA1 hash of the
            payload in the record header
        '''
        if self._current_record:
            self._current_record.close()

        format = self._record_formats[type]
        self._current_record = Record(
            self, 'w', size_record, 0, None, type, format, make_hash)
        return self._current_record

    def close(self):
        '''
        Close the currently open record, if any.
        '''
        if self._current_record:
            self._current_record.close()