8d1912bbedebae782cc6dd09e0487d12e00955f6
[~helmut/debian-dedup.git] / dedup / compression.py
1 import bz2
2 import struct
3 import sys
4 import zlib
5
6 import lzma
7
# zlib.crc32 yields a signed value on Python 2 but an unsigned one on
# Python 3, so the struct format code used to pack it differs ("l" vs "L").
crc32_type = "l" if sys.version_info[0] < 3 else "L"
9
class GzipDecompressor(object):
    """An interface to gzip which is similar to bz2.BZ2Decompressor and
    lzma.LZMADecompressor.

    The input may consist of several concatenated gzip members (RFC 1952
    section 2.2); their decompressed contents are returned back to back.
    NUL bytes following a member's trailer are accepted as padding.
    """
    # ID1, ID2 and CM (8 = deflate): the first bytes of every gzip member.
    _magic = b"\037\213\010"
    # Fixed part of a member header: magic, FLG, MTIME, XFL and OS.
    _header_size = 10

    def __init__(self):
        # True once the header of the current member has been parsed.
        self.sawheader = False
        # Input not handed to zlib (yet): an incomplete member header or,
        # after a member's deflate data ended, its trailer and what follows.
        self.inbuffer = b""
        # zlib decompressor for the deflate data of the current member.
        self.decompressor = None
        # Running CRC-32 and length of the current member's output.
        self.crc = 0
        self.size = 0

    def _account(self, data):
        """Fold decompressed data into the running CRC-32 and size."""
        self.crc = zlib.crc32(data, self.crc)
        self.size += len(data)

    def _expected_trailer(self):
        """Return the 8 byte trailer (CRC32, ISIZE) matching the current
        member's output. Both fields are unsigned 32 bit little-endian
        values; ISIZE is the size modulo 2**32."""
        # Masking normalizes Python 2's signed crc32 and large sizes, which
        # struct would otherwise reject.
        return struct.pack("<LL", self.crc & 0xffffffff,
                           self.size & 0xffffffff)

    def _header_length(self):
        """Return the length of the member header at the start of
        self.inbuffer, or None if more input is needed to tell.

        @raises ValueError: if no gzip magic is found
        """
        buf = self.inbuffer
        skip = self._header_size
        if len(buf) < skip:
            return None
        if not buf.startswith(self._magic):
            raise ValueError("gzip magic not found")
        flag = ord(buf[3:4])
        if flag & 4:  # FEXTRA: 2 byte length followed by that many bytes
            if len(buf) < skip + 2:
                return None
            length, = struct.unpack("<H", buf[skip:skip + 2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, FCOMMENT: NUL-terminated strings
            if flag & field:
                end = buf.find(b"\0", skip)
                if end < 0:
                    return None
                skip = end + 1
        if flag & 2:  # FHCRC: 2 byte header checksum
            skip += 2
        if len(buf) < skip:
            return None
        return skip

    def _next_member(self):
        """Check the trailer of the finished member at the start of
        self.inbuffer and look past any NUL padding. Return the input
        starting at the next member's header, or None if there is none yet.

        @raises ValueError: if the trailer does not match the decompressed
            data or neither padding nor another member follows it
        """
        expect = self._expected_trailer()
        if len(self.inbuffer) < len(expect):
            return None
        if not self.inbuffer.startswith(expect):
            raise ValueError("gzip trailer does not match decompressed data")
        rest = self.inbuffer[len(expect):].lstrip(b"\0")
        # Only padding so far, or too little input to recognize the magic.
        if len(rest) < len(self._magic) and self._magic.startswith(rest):
            return None
        if not rest.startswith(self._magic):
            raise ValueError("gzip magic not found")
        return rest

    def decompress(self, data):
        """Feed data and return whatever can be decompressed so far.

        @raises ValueError: if no gzip magic is found where a member must
            start or a member's trailer does not match its data
        @raises zlib.error: from zlib invocations
        """
        output = []
        while True:
            if self.decompressor is not None:
                chunk = self.decompressor.decompress(data)
                self._account(chunk)
                output.append(chunk)
                data = self.decompressor.unused_data
                if not data:
                    return b"".join(output)
                # The deflate data ended; what is left begins with the
                # member's trailer.
                self.decompressor = None
            self.inbuffer += data
            data = b""
            if self.sawheader:
                rest = self._next_member()
                if rest is None:
                    return b"".join(output)
                # Another member follows: checksums start over.
                self.inbuffer = rest
                self.sawheader = False
                self.crc = 0
                self.size = 0
            skip = self._header_length()
            if skip is None:
                return b"".join(output)
            data = self.inbuffer[skip:]
            self.inbuffer = b""
            self.sawheader = True
            # Negative wbits: raw deflate data, the gzip framing is ours.
            self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)

    @property
    def unused_data(self):
        """Input that was not consumed as part of a valid gzip stream;
        b"" when everything (including trailer and padding) was consumed.

        NOTE(review): while zlib is still active (e.g. input cut off before
        a member's trailer), this reports b"" as well.
        """
        if self.decompressor is not None:
            return self.decompressor.unused_data
        if not self.sawheader:
            return self.inbuffer
        expect = self._expected_trailer()
        if self.inbuffer.startswith(expect) and \
                not self.inbuffer[len(expect):].strip(b"\0"):
            return b""
        return self.inbuffer

    def flush(self):
        """Return any data zlib still holds for the current member.

        @raises zlib.error: from zlib invocations
        """
        if self.decompressor is None:
            return b""
        data = self.decompressor.flush()
        # Keep the trailer check in unused_data accurate.
        self._account(data)
        return data

    def copy(self):
        """Return an independent decompressor in the same state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor is not None:
            new.decompressor = self.decompressor.copy()
        new.sawheader = self.sawheader
        new.crc = self.crc
        new.size = self.size
        return new
92
class DecompressedStream(object):
    """Turn a readable file-like into a decompressed file-like. It supports
    read(optional length), readline, iteration over lines, tell,
    seek(forward only) and close."""
    blocksize = 65536

    def __init__(self, fileobj, decompressor):
        """
        @param fileobj: a file-like object providing read(size)
        @param decompressor: a bz2.BZ2Decompressor or lzma.LZMADecompressor
            like object providing a method decompress and optionally a
            method flush, which is called once at the end of the input
        """
        self.fileobj = fileobj
        self.decompressor = decompressor
        # Decompressed data; only self.buff[self.buffoff:] is still unread.
        # The offset avoids copying the remainder on every small read.
        self.buff = b""
        self.buffoff = 0
        # Number of decompressed bytes consumed so far (see tell()).
        self.pos = 0
        # Set once fileobj is exhausted and the decompressor was flushed.
        self.eof = False
        self.closed = False

    def _decompress_block(self):
        """Return the decompressed data for the next block of input. At the
        end of the input, flush the decompressor (if it supports that) and
        set self.eof so this happens only once."""
        data = self.fileobj.read(self.blocksize)
        if data:
            return self.decompressor.decompress(data)
        self.eof = True
        if hasattr(self.decompressor, "flush"):
            return self.decompressor.flush()
        return b""

    def _fill_buff(self, size=None, delim=None):
        """Decompress until the unread part of self.buff holds at least size
        bytes or contains the single byte delim, or the input is exhausted.
        Without size and delim, decompress all remaining input."""
        assert not self.closed
        unread = len(self.buff) - self.buffoff
        if size is not None and unread >= size:
            return
        if delim is not None and self.buff.find(delim, self.buffoff) >= 0:
            return
        if self.eof:
            return
        # Collect chunks and join once: repeated bytes += is quadratic.
        parts = [self.buff[self.buffoff:]]
        while not self.eof:
            chunk = self._decompress_block()
            parts.append(chunk)
            unread += len(chunk)
            if size is not None and unread >= size:
                break
            # A single byte delimiter cannot straddle two chunks.
            if delim is not None and delim in chunk:
                break
        self.buff = b"".join(parts)
        self.buffoff = 0

    def _read_from_buff(self, length):
        """Consume and return up to length unread bytes from self.buff."""
        ret = self.buff[self.buffoff:self.buffoff + length]
        self.buffoff += len(ret)
        # Fewer bytes than requested are available at the end of the
        # stream; the position must only reflect what was returned.
        self.pos += len(ret)
        if self.buffoff >= len(self.buff):
            # Release fully consumed data early.
            self.buff = b""
            self.buffoff = 0
        return ret

    def read(self, length=None):
        """Return up to length decompressed bytes, or everything up to the
        end of the stream if length is None or negative."""
        if length is None or length < 0:
            self._fill_buff()
            length = len(self.buff) - self.buffoff
        else:
            self._fill_buff(size=length)
        return self._read_from_buff(length)

    def readline(self):
        """Return the next line including its terminating newline (absent
        on an unterminated last line), or b"" at the end of the stream."""
        self._fill_buff(delim=b"\n")
        end = self.buff.find(b"\n", self.buffoff)
        if end < 0:
            end = len(self.buff)
        else:
            end += 1
        return self._read_from_buff(end - self.buffoff)

    def __iter__(self):
        return iter(self.readline, b'')

    def tell(self):
        assert not self.closed
        return self.pos

    def seek(self, pos):
        """Forward seeks by absolute position only. As with regular files,
        seeking past the end is permitted; later reads return b""."""
        assert not self.closed
        if pos < self.pos:
            raise ValueError("negative seek not allowed on decompressed stream")
        while self.pos < pos:
            unread = len(self.buff) - self.buffoff
            # Skip what is buffered, otherwise decompress a block's worth.
            if not self.read(min(pos - self.pos, unread or self.blocksize)):
                # The stream ended before pos.
                self.pos = pos
                return

    def close(self):
        if not self.closed:
            self.fileobj.close()
            self.fileobj = None
            self.decompressor = None
            self.buff = b""
            self.buffoff = 0
            self.closed = True
174
# Maps each supported file extension (including its leading dot) to a
# callable that creates a fresh decompressor object for that format.
decompressors = dict((
    (u'.gz', GzipDecompressor),
    (u'.bz2', bz2.BZ2Decompressor),
    (u'.lzma', lzma.LZMADecompressor),
    (u'.xz', lzma.LZMADecompressor),
))
181
def decompress(filelike, extension):
    """Wrap a compressed byte-stream so that reading from the result yields
    the decompressed bytes, picking the algorithm by file extension.

    @param filelike: a read-only byte-stream supporting read(size) and
                     close()
    @param extension: one of "" (uncompressed), ".gz", ".bz2", ".lzma" and
                      ".xz"; any false value also means uncompressed
    @type extension: unicode
    @returns: filelike itself when extension is empty. Otherwise a
              DecompressedStream over it, supporting read(size), readline,
              close(), tell() and forward-only seek(pos).
    @raises ValueError: on unknown extensions
    """
    if extension:
        try:
            factory = decompressors[extension]
        except KeyError:
            raise ValueError("unknown compression format with extension %r" %
                             extension)
        return DecompressedStream(filelike, factory())
    return filelike