5df6613837a744e46781c08a7d51bfc52cb64e2e
[~helmut/debian-dedup.git] / dedup / compression.py
1 import bz2
2 import struct
3 import sys
4 import zlib
5
6 import lzma
7
# struct format code for packing zlib.crc32 results: unsigned ("L") on
# Python 3, signed ("l") on Python 2, where crc32 may return negative values.
crc32_type = "l" if sys.version_info[0] < 3 else "L"
9
class GzipDecompressor(object):
    """An interface to gzip which is similar to bz2.BZ2Decompressor and
    lzma.LZMADecompressor.

    A gzip stream may consist of several concatenated members; they are
    decompressed one after another. Zero bytes following the last member
    are tolerated as padding."""

    # ID1, ID2 and CM (8 = deflate) bytes that start every gzip member.
    _MAGIC = b"\037\213\010"

    def __init__(self):
        self.sawheader = False    # header of the current member parsed?
        self.inbuffer = b""       # input not yet handed to zlib
        self.decompressor = None  # raw deflate decompressor of current member
        self.crc = 0              # running CRC32 of current member's output
        self.size = 0             # output length of current member so far

    def decompress(self, data):
        """Feed data and return the decompressed bytes it made available.

        @raises ValueError: if no gzip magic is found where a member header
            is expected, if non-zero data follows a member without starting
            a new one, or if a member's trailer does not match its contents
            before another member begins
        @raises zlib.error: from zlib invocations
        """
        output = []
        while True:
            if self.decompressor is not None:
                chunk = self.decompressor.decompress(data)
                self.crc = zlib.crc32(chunk, self.crc)
                self.size += len(chunk)
                output.append(chunk)
                data = self.decompressor.unused_data
                if not data:
                    break
                # The member's deflate stream ended; data now holds its
                # trailer and whatever follows it.
                self.decompressor = None
            self.inbuffer += data
            data = b""
            if self.sawheader:
                if not self._start_next_member():
                    break
            else:
                data = self._parse_header()
                if data is None:
                    break
        return b"".join(output)

    def _parse_header(self):
        """Parse a member header at the start of self.inbuffer.

        @returns: the bytes following the header (a fresh raw deflate
            decompressor is installed for them), or None if more input is
            needed
        @raises ValueError: if no gzip magic is found
        """
        skip = 10
        if len(self.inbuffer) < skip:
            return None
        if not self.inbuffer.startswith(self._MAGIC):
            raise ValueError("gzip magic not found")
        flag = ord(self.inbuffer[3:4])
        if flag & 4:  # FEXTRA: 2-byte length, then that many bytes
            if len(self.inbuffer) < skip + 2:
                return None
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, FCOMMENT: zero-terminated strings
            if flag & field:
                end = self.inbuffer.find(b"\0", skip)
                if end < 0:
                    return None
                skip = end + 1
        if flag & 2:  # FHCRC: 2-byte header checksum
            skip += 2
        if len(self.inbuffer) < skip:
            return None
        body = self.inbuffer[skip:]
        self.inbuffer = b""
        self.sawheader = True
        self.crc = 0
        self.size = 0
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return body

    def _start_next_member(self):
        """Examine the input following a finished member.

        @returns: True if self.inbuffer now starts with the next member's
            header, False if more input is needed to decide or only zero
            padding has been seen so far
        @raises ValueError: on trailing garbage or on a trailer mismatch in
            front of another member
        """
        tail = self.inbuffer[8:]  # skip the CRC32 and ISIZE trailer
        if tail.startswith(self._MAGIC):
            if self.inbuffer[:8] != self._trailer():
                raise ValueError("gzip member crc or size mismatch")
            self.inbuffer = tail
            self.sawheader = False
            return True
        if self._MAGIC.startswith(tail) or not tail.strip(b"\0"):
            return False
        raise ValueError("gzip magic not found")

    def _trailer(self):
        """Return the 8-byte trailer (CRC32, ISIZE) expected for the current
        member. Masking keeps the CRC unsigned (Python 2's crc32 may be
        negative) and reduces the size modulo 2**32 so huge members do not
        overflow struct.pack."""
        return struct.pack("<LL", self.crc & 0xffffffff,
                           self.size & 0xffffffff)

    @property
    def unused_data(self):
        """Input that was not consumed as gzip data.

        While a member is being decompressed this is the zlib object's
        unused_data; before a header is complete it is the pending header
        bytes. After the last member it is b"" if the trailer matched
        (optionally followed by zero padding), else the buffered bytes."""
        if self.decompressor is not None:
            return self.decompressor.unused_data
        if not self.sawheader:
            return self.inbuffer
        trailer = self._trailer()
        if self.inbuffer.startswith(trailer) and \
                not self.inbuffer[len(trailer):].strip(b"\0"):
            return b""
        return self.inbuffer

    def flush(self):
        """Return output still held by zlib for the current member.

        @raises zlib.error: from zlib invocations
        """
        if self.decompressor is None:
            return b""
        data = self.decompressor.flush()
        self.crc = zlib.crc32(data, self.crc)
        self.size += len(data)
        return data

    def copy(self):
        """Return an independent decompressor with identical state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor is not None:
            new.decompressor = self.decompressor.copy()
        new.sawheader = self.sawheader
        new.crc = self.crc
        new.size = self.size
        return new
92
class DecompressedStream(object):
    """Turn a readable file-like into a decompressed file-like. It supports
    read(optional length), tell, seek(forward only) and close."""
    blocksize = 65536

    def __init__(self, fileobj, decompressor):
        """
        @param fileobj: a file-like object providing read(size)
        @param decompressor: a bz2.BZ2Decompressor or lzma.LZMADecompressor
            like object providing a method decompress and an attribute
            unused_data; its flush method is called at end of input when
            it has one
        """
        self.fileobj = fileobj
        self.decompressor = decompressor
        self.buff = b""
        self.pos = 0
        self.closed = False

    def _fill_buff(self, length):
        """Decompress further input until self.buff holds at least length
        bytes, or until the input is exhausted (always when length is None).
        """
        if length is not None and len(self.buff) >= length:
            return
        # Collect pieces and join once instead of quadratic bytes +=.
        parts = [self.buff]
        have = len(self.buff)
        while length is None or have < length:
            data = self.fileobj.read(self.blocksize)
            if not data:
                # bz2.BZ2Decompressor and lzma.LZMADecompressor have no
                # flush(); only call it where it exists (e.g. zlib, gzip).
                flush = getattr(self.decompressor, "flush", None)
                if flush is not None:
                    parts.append(flush())
                break
            chunk = self.decompressor.decompress(data)
            parts.append(chunk)
            have += len(chunk)
        self.buff = b"".join(parts)

    def read(self, length=None):
        """Return up to length decompressed bytes, or all remaining bytes
        when length is None or negative. Fewer bytes are returned only at
        the end of the stream."""
        assert not self.closed
        if length is not None and length < 0:
            length = None
        self._fill_buff(length)
        if length is None:
            ret, self.buff = self.buff, b""
        else:
            ret, self.buff = self.buff[:length], self.buff[length:]
        self.pos += len(ret)
        return ret

    def tell(self):
        """Return the number of decompressed bytes consumed so far."""
        assert not self.closed
        return self.pos

    def seek(self, pos):
        """Forward seeks by absolute position only. Seeking beyond the end
        of the decompressed data stops at its end.

        @raises ValueError: if pos lies before the current position
        """
        assert not self.closed
        if pos < self.pos:
            raise ValueError("negative seek not allowed on decompressed stream")
        while self.pos < pos:
            # Reading self.buff entirely avoids string concatenation.
            step = min(pos - self.pos, len(self.buff) or self.blocksize)
            if len(self.read(step)) < step:
                break  # end of stream; without this we would spin forever

    def close(self):
        """Close the underlying file and drop all buffered state."""
        if not self.closed:
            self.fileobj.close()
            self.fileobj = None
            self.decompressor = None
            self.buff = b""
            self.closed = True
157
# Map each supported file extension to a factory that, called with no
# arguments, yields a fresh decompressor object (see decompress() below).
decompressors = {
    '.gz': GzipDecompressor,
    '.bz2': bz2.BZ2Decompressor,
}
decompressors.update(dict.fromkeys(('.lzma', '.xz'), lzma.LZMADecompressor))
164
def decompress(filelike, extension):
    """Decompress a stream according to its extension.
    @param filelike: is a read-only byte-stream. It must support read(size) and
                     close().
    @param extension: permitted values are "", ".gz", ".bz2", ".lzma", and
                      ".xz"
    @type extension: str
    @returns: a read-only byte-stream with the decompressed contents of the
              original filelike. It supports read(size) and close(). When
              extension is empty, filelike itself is returned and keeps
              whatever else it supported; otherwise the returned stream
              also supports tell() and forward-only seek(pos).
    @raises ValueError: on unknown extensions
    """
    if not extension:
        return filelike
    decompressor = decompressors.get(extension)
    if decompressor is None:
        # Looked up with get() rather than try/except KeyError so the
        # ValueError is not reported as raised while handling a KeyError.
        raise ValueError("unknown compression format with extension %r" %
                         extension)
    return DecompressedStream(filelike, decompressor())