decouple a function decompress out of decompress_tar
[~helmut/debian-dedup.git] / dedup / compression.py
index 869c49f..5df6613 100644 (file)
@@ -1,12 +1,21 @@
+import bz2
 import struct
+import sys
 import zlib
 
+import lzma
+
+crc32_type = "L" if sys.version_info.major >= 3 else "l"
+
 class GzipDecompressor(object):
     """An interface to gzip which is similar to bz2.BZ2Decompressor and
     lzma.LZMADecompressor."""
     def __init__(self):
+        self.sawheader = False
         self.inbuffer = b""
         self.decompressor = None
+        self.crc = 0
+        self.size = 0
 
     def decompress(self, data):
         """
@@ -16,6 +25,8 @@ class GzipDecompressor(object):
         while True:
             if self.decompressor:
                 data = self.decompressor.decompress(data)
+                self.crc = zlib.crc32(data, self.crc)
+                self.size += len(data)
                 unused_data = self.decompressor.unused_data
                 if not unused_data:
                     return data
@@ -27,7 +38,7 @@ class GzipDecompressor(object):
                 return b""
             if not self.inbuffer.startswith(b"\037\213\010"):
                 raise ValueError("gzip magic not found")
-            flag = ord(self.inbuffer[3])
+            flag = ord(self.inbuffer[3:4])
             if flag & 4:
                 if len(self.inbuffer) < skip + 2:
                     return b""
@@ -45,13 +56,20 @@ class GzipDecompressor(object):
                 return b""
             data = self.inbuffer[skip:]
             self.inbuffer = b""
+            self.sawheader = True
             self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
 
     @property
     def unused_data(self):
         if self.decompressor:
             return self.decompressor.unused_data
+        elif not self.sawheader:
+            return self.inbuffer
         else:
+            expect = struct.pack("<" + crc32_type + "L", self.crc, self.size)
+            if self.inbuffer.startswith(expect) and \
+                    self.inbuffer[len(expect):].replace(b"\0", b"") == b"":
+                return b""
             return self.inbuffer
 
     def flush(self):
@@ -67,11 +85,14 @@ class GzipDecompressor(object):
         new.inbuffer = self.inbuffer
         if self.decompressor:
             new.decompressor = self.decompressor.copy()
+        new.sawheader = self.sawheader
+        new.crc = self.crc
+        new.size = self.size
         return new
 
 class DecompressedStream(object):
-    """Turn a readable file-like into a decompressed file-like. Te only part
-    of being file-like consists of the read(size) method in both cases."""
+    """Turn a readable file-like into a decompressed file-like. It supports
+    read(optional length), tell, seek(forward only) and close."""
     blocksize = 65536
 
     def __init__(self, fileobj, decompressor):
@@ -84,20 +105,81 @@ class DecompressedStream(object):
         self.fileobj = fileobj
         self.decompressor = decompressor
         self.buff = b""
+        self.pos = 0
+        self.closed = False
 
     def read(self, length=None):
+        assert not self.closed
         data = True
         while True:
             if length is not None and len(self.buff) >= length:
                 ret = self.buff[:length]
                 self.buff = self.buff[length:]
-                return ret
+                break
             elif not data: # read EOF in last iteration
                 ret = self.buff
                 self.buff = b""
-                return ret
+                break
             data = self.fileobj.read(self.blocksize)
             if data:
                 self.buff += self.decompressor.decompress(data)
             else:
                 self.buff += self.decompressor.flush()
+        self.pos += len(ret)
+        return ret
+
+    def tell(self):
+        assert not self.closed
+        return self.pos
+
+    def seek(self, pos):
+        """Forward seeks by absolute position only."""
+        assert not self.closed
+        if pos < self.pos:
+            raise ValueError("negative seek not allowed on decompressed stream")
+        while True:
+            left = pos - self.pos
+            # Reading self.buff entirely avoids string concatenation.
+            size = len(self.buff) or self.blocksize
+            if left > size:
+                self.read(size)
+            else:
+                self.read(left)
+                return
+
+    def close(self):
+        if not self.closed:
+            self.fileobj.close()
+            self.fileobj = None
+            self.decompressor = None
+            self.buff = b""
+            self.closed = True
+
+decompressors = {
+    '.gz':   GzipDecompressor,
+    '.bz2':  bz2.BZ2Decompressor,
+    '.lzma': lzma.LZMADecompressor,
+    '.xz':   lzma.LZMADecompressor,
+}
+
+def decompress(filelike, extension):
+    """Decompress a stream according to its extension.
+    @param filelike: is a read-only byte-stream. It must support read(size) and
+                     close().
+    @param extension: permitted values are "", ".gz", ".bz2", ".lzma", and
+                      ".xz"
+    @type extension: str
+    @returns: a read-only byte-stream with the decompressed contents of the
+              original filelike. It supports read(size) and close(). For
+              compressed input it also supports tell() and forward-only
+              seek(pos); for "" the original filelike is returned as is.
+    @raises ValueError: on unknown extensions
+    """
+    if not extension:
+        return filelike
+    try:
+        decompressor = decompressors[extension]
+    except KeyError:
+        raise ValueError("unknown compression format with extension %r" %
+                         extension)
+    return DecompressedStream(filelike, decompressor())