1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6# The container format:
7# * A file consists of records.
8# * A record consists of a record header a record payload, and possibly
9# padding.
10# * A record header consists of a label, a version, the record size, the
11# payload size, a hash, and a record type.
12# * A record payload consists of a sequence of record entries.
13# * A record entry consists of a key, a type, and a value.
14from __future__ import absolute_import, division
16from struct import unpack, pack
17from io import BytesIO
18import numpy as num
19try:
20 from hashlib import sha1
21except ImportError:
22 from sha import new as sha1
24try:
25 from os import SEEK_CUR
26except ImportError:
27 SEEK_CUR = 1
29from . import util
# Py2/3 compatibility: on Python 2, rebind the lazy ``xrange`` as ``range``.
try:
    range = xrange
except NameError:
    # Python 3: builtin ``range`` is already lazy.
    pass
# Fixed size in bytes of every record header.
size_record_header = 64

# Sentinel stored in the header's hash field when no hash was computed.
# Must be a bytes literal: the ``20s`` struct field yields/accepts bytes on
# Python 3, so a str sentinel would never compare equal on read and would
# raise struct.error on write.
no_hash = b'\0' * 20
# Array type codes: each '@<kind><size>' code maps to a pair of
# (native numpy dtype used in memory, big-endian dtype string used on disk).
numtypes = {
    '@i2': (num.int16, '>i2'),
    '@i4': (num.int32, '>i4'),
    '@i8': (num.int64, '>i8'),
    '@u2': (num.uint16, '>u2'),
    '@u4': (num.uint32, '>u4'),
    '@u8': (num.uint64, '>u8'),
    '@f4': (num.float32, '>f4'),
    '@f8': (num.float64, '>f8'),
}
# Reverse lookup: native numpy dtype -> array type code.
numtype2type = {v[0]: k for (k, v) in numtypes.items()}
def packer(fmt):
    """Return a ``(pack, unpack)`` pair for one big-endian value of *fmt*.

    *fmt* is a single :py:mod:`struct` format character (e.g. ``'h'``,
    ``'Q'``, ``'d'``).
    """
    wire_fmt = '>' + fmt

    def _pack(value):
        return pack(wire_fmt, value)

    def _unpack(data):
        return unpack(wire_fmt, data)[0]

    return _pack, _unpack
def unpack_array(fmt, data):
    """Decode big-endian binary *data* into a numpy array.

    The wire dtype and the native in-memory dtype are looked up in
    ``numtypes`` under the code *fmt*.
    """
    dtype_native, dtype_wire = numtypes[fmt]
    return num.frombuffer(data, dtype=dtype_wire).astype(dtype_native)
def pack_array(fmt, data):
    """Encode numpy array *data* as big-endian bytes.

    The on-disk dtype is taken from ``numtypes`` under the code *fmt*.
    """
    wire_dtype = numtypes[fmt][1]
    return data.astype(wire_dtype).tobytes()
def array_packer(fmt):
    """Return a ``(pack, unpack)`` pair for arrays of type code *fmt*."""

    def _pack(arr):
        return pack_array(fmt, arr)

    def _unpack(data):
        return unpack_array(fmt, data)

    return _pack, _unpack
def encoding_packer(enc):
    """Return a ``(encode, decode)`` pair for strings in encoding *enc*."""

    def _encode(s):
        return s.encode(enc)

    def _decode(data):
        return str(data.decode(enc))

    return _encode, _decode
def noop(x):
    """Identity function: return *x* unchanged."""
    return x
def time_to_str_ns(x):
    """Format timestamp *x* with nanosecond precision as UTF-8 bytes."""
    text = util.time_to_str(x, format=9)
    return text.encode('utf-8')
def str_to_time(x):
    """Decode UTF-8 bytes *x* and parse them into a timestamp."""
    text = str(x.decode('utf8'))
    return util.str_to_time(text)
# Value converters for record entries, keyed by entry type code. Each value
# is a (pack_function, unpack_function) pair as produced by the helpers
# above.
castings = {
    'i2': packer('h'),
    'i4': packer('i'),
    'i8': packer('q'),
    'u2': packer('H'),
    'u4': packer('I'),
    'u8': packer('Q'),
    'f4': packer('f'),
    'f8': packer('d'),
    'string': encoding_packer('utf-8'),
    'time_string': (time_to_str_ns, str_to_time),
    '@i2': array_packer('@i2'),
    '@i4': array_packer('@i4'),
    '@i8': array_packer('@i8'),
    '@u2': array_packer('@u2'),
    '@u4': array_packer('@u4'),
    '@u8': array_packer('@u8'),
    '@f4': array_packer('@f4'),
    '@f8': array_packer('@f8'),
}
def pack_value(type, value):
    """Serialize *value* according to the type code *type*.

    Any failure (unknown type code or conversion error) is wrapped in a
    :py:exc:`FileError`.
    """
    try:
        convert = castings[type][0]
        return convert(value)
    except Exception as e:
        raise FileError(
            'Packing value failed (type=%s, value=%s, error=%s).' %
            (type, str(value)[:500], e))
def unpack_value(type, value):
    """Deserialize raw *value* according to the type code *type*.

    Any failure (unknown type code or conversion error) is wrapped in a
    :py:exc:`FileError`.
    """
    try:
        convert = castings[type][1]
        return convert(value)
    except Exception as e:
        raise FileError(
            'Unpacking value failed (type=%s, error=%s).' % (type, e))
class FileError(Exception):
    """Raised when reading or writing a container file fails."""
    pass
class NoDataAvailable(Exception):
    """Raised when no further record can be read (end of file)."""
    pass
class WrongRecordType(Exception):
    """Raised when a record of an unexpected type is encountered."""
    pass
class MissingRecordValue(Exception):
    """Raised when a value required by the record format is not supplied."""
    pass
class Record(object):
    """One record of a container file.

    In mode ``'r'`` the object wraps an already positioned region of the
    parent file and offers :py:meth:`read`/:py:meth:`seek`/:py:meth:`skip`
    access to the payload. In mode ``'w'`` payload data is buffered in
    memory and flushed — together with the record header and zero padding —
    when :py:meth:`close` is called.
    """

    def __init__(
            self, parent, mode, size_record, size_payload, hash, type, format,
            do_hash):

        self.mode = mode                  # 'r' or 'w'
        self.size_record = size_record    # total size incl. header, or None
        self.size_payload = size_payload  # net payload size in bytes
        self.hash = hash                  # SHA1 digest of payload, or None
        self.type = type                  # record type name
        if mode == 'w':
            # Payload size and hash are determined while writing.
            self.size_payload = 0
            self.hash = None
            self._out = BytesIO()
        else:
            self.size_remaining = self.size_record - size_record_header
            self.size_padding = self.size_record - size_record_header - \
                self.size_payload

        self._f = parent._f
        self._parent = parent
        self._hasher = None
        self.format = format  # mapping of entry key -> type code
        # Hash while writing, or while reading if the header carries a hash.
        if do_hash and (self.mode == 'w' or self.hash):
            self._hasher = sha1()
        self._closed = False

    def read(self, n=None):
        """Read up to *n* payload bytes (all remaining payload if ``None``).

        :raises FileError: if the underlying file yields less data than
            requested.
        """
        assert not self._closed
        assert self.mode == 'r'

        if n is None:
            n = self.size_payload

        # Never read into the padding area.
        n = min(n, self.size_remaining - self.size_padding)
        data = self._f.read(n)
        self.size_remaining -= len(data)

        if len(data) != n:
            raise FileError('Read returned less data than expected.')

        if self._hasher:
            self._hasher.update(data)

        return data

    def write(self, data):
        """Append *data* to the record's in-memory payload buffer."""
        assert not self._closed
        assert self.mode == 'w'
        self._out.write(data)
        if self._hasher:
            self._hasher.update(data)

        self.size_payload += len(data)

    def seek(self, n, whence):
        """Skip forward *n* bytes within the payload (``SEEK_CUR`` only).

        Seeking disables hash verification for this record, because the
        skipped bytes are not fed to the hasher.
        """
        assert not self._closed
        assert self.mode == 'r'
        assert whence == SEEK_CUR
        assert n >= 0

        n = min(n, self.size_remaining - self.size_padding)
        self._f.seek(n, whence)
        self._hasher = None
        self.size_remaining -= n

    def skip(self, n):
        """Skip forward *n* payload bytes."""
        self.seek(n, SEEK_CUR)

    def close(self):
        """Finish the record.

        In read mode: consume any remaining payload (verifying the hash if
        enabled) and skip the padding. In write mode: emit the header, the
        buffered payload and zero padding to the parent file.

        :raises FileError: on hash mismatch or when the payload does not
            fit into a size-limited record.
        """
        if self._closed:
            return

        if self.mode == 'r':
            if self._hasher and self._hasher.digest() != self.hash:
                self.read(self.size_remaining)
                raise FileError(
                    'Hash computed from record data does not match value '
                    'given in header.')
            else:
                self.seek(self.size_remaining, SEEK_CUR)

            if self.size_padding:
                self._f.seek(self.size_padding, SEEK_CUR)
        else:
            if self.size_record is not None and \
                    self.size_payload > self.size_record - size_record_header:

                raise FileError(
                    'Too much data to fit into size-limited record.')

            if self.size_record is None:
                self.size_record = self.size_payload + size_record_header

            self.size_padding = self.size_record - self.size_payload - \
                size_record_header

            if self._hasher is not None:
                self.hash = self._hasher.digest()

            self._parent.write_record_header(
                self.size_record, self.size_payload, self.hash, self.type)

            self._f.write(self._out.getvalue())
            self._out.close()
            self._f.write(b'\0' * self.size_padding)

        self._closed = True
        self._parent = None
        self._f = None

    def entries(self):
        """Iterate over ``(key, type, size)`` triples of the record entries.

        The payload starts with all entry sizes (8 bytes big-endian each):
        n key sizes, n type sizes, n value sizes, followed by the keys,
        types, and values themselves.
        """
        sizes = []
        sum = 0
        while sum < self.size_payload:
            size = unpack('>Q', self.read(8))[0]
            sum += size + 8
            sizes.append(size)

        n = len(sizes) // 3
        keys = [str(self.read(sizes[j]).decode('ascii'))
                for j in range(n)]
        types = [str(self.read(sizes[j]).decode('ascii'))
                 for j in range(n, 2*n)]
        for key, type, j in zip(keys, types, range(2*n, 3*n)):
            yield key, type, sizes[j]

    def unpack(self, exclude=None):
        """Read and decode all entries into a dict.

        Keys listed in *exclude* are skipped (their value is set to
        ``None``).

        :raises FileError: if an entry's type disagrees with the record
            format, or if an entry required by the format is missing.
        """
        d = {}
        for key, type, size in self.entries():
            if self.format[key] != type:
                # Bug fix: the exception was previously constructed but
                # never raised, silently ignoring format mismatches.
                raise FileError('Record value in unexpected format.')

            if not exclude or key not in exclude:
                d[key] = unpack_value(type, self.read(size))
            else:
                self.skip(size)
                d[key] = None

        for key in self.format:
            if key not in d:
                raise FileError('Missing record entry: %s.' % key)

        return d

    def pack(self, d):
        """Encode dict *d* and write it as the record's payload.

        :raises MissingRecordValue: if a key required by the record format
            is absent from *d*.
        """
        for key in self.format:
            if key not in d:
                raise MissingRecordValue()

        keys = []
        types = []
        values = []
        for key in d.keys():
            if key in self.format:
                type = self.format[key]
                if isinstance(type, tuple):
                    # Variable-typed entry: the parent file chooses the
                    # concrete type code based on the value.
                    type = self._parent.get_type(key, d[key])

                keys.append(key.encode('ascii'))
                types.append(type.encode('ascii'))
                values.append(pack_value(type, d[key]))

        sizes = [len(x) for x in keys+types+values]
        self.write(pack('>%iQ' % len(sizes), *sizes))
        for x in keys+types+values:
            self.write(x)
class File(object):
    """Reader/writer for record-based container files.

    :param f: file-like object, opened in binary mode
    :param type_label: 4-character file type label checked on every header
    :param version: 4-character format version checked on every header
    :param record_formats: mapping of record type name -> entry format dict
    """

    def __init__(
            self, f,
            type_label='TEST',
            version='0000',
            record_formats=None):

        assert len(type_label) == 4
        assert len(version) == 4

        self._file_type_label = type_label
        self._file_version = version
        # Bug fix: ``record_formats={}`` was a shared mutable default
        # argument; use a per-instance dict when none is supplied.
        self._record_formats = {} if record_formats is None else \
            record_formats
        self._current_record = None
        self._f = f

    def read_record_header(self):
        """Read and validate one record header.

        :returns: tuple ``(size_record, size_payload, hash, type)``
        :raises NoDataAvailable: at end of file
        :raises FileError: on short read, wrong label or wrong version
        """
        data = self._f.read(size_record_header)

        if len(data) == 0:
            raise NoDataAvailable()

        if len(data) != size_record_header:
            raise FileError('Read returned less data than expected.')

        label, version, size_record, size_payload, hash, type = unpack(
            '>4s4sQQ20s20s', data)

        label = str(label.decode('ascii'))
        version = str(version.decode('ascii'))
        # Type is space-padded to 20 bytes on write; strip the padding.
        # (A second, redundant rstrip was removed here.)
        type = str(type.rstrip().decode('ascii'))

        if label != self._file_type_label:
            raise FileError('record file type label missing.')

        if version != self._file_version:
            raise FileError('file version %s not supported.' % version)

        if hash == no_hash:
            hash = None

        return size_record, size_payload, hash, type

    def write_record_header(self, size_record, size_payload, hash, type):
        """Write one 64-byte record header to the underlying file."""
        if hash is None:
            hash = no_hash

        data = pack(
            '>4s4sQQ20s20s',
            self._file_type_label.encode('ascii'),
            self._file_version.encode('ascii'),
            size_record,
            size_payload,
            hash,
            type.encode('ascii').ljust(20)[:20])

        self._f.write(data)

    def next_record(self, check_hash=False):
        """Close any open record and open the next one for reading."""
        if self._current_record:
            self._current_record.close()

        size_record, size_payload, hash, type = self.read_record_header()
        format = self._record_formats[type]
        self._current_record = Record(
            self, 'r', size_record, size_payload, hash, type, format,
            check_hash)

        return self._current_record

    def add_record(self, type, size_record=None, make_hash=False):
        """Close any open record and start a new one for writing."""
        if self._current_record:
            self._current_record.close()

        format = self._record_formats[type]
        self._current_record = Record(
            self, 'w', size_record, 0, None, type, format, make_hash)
        return self._current_record

    def close(self):
        """Close the currently open record, if any."""
        if self._current_record:
            self._current_record.close()