push more functionality into DebExtractor
[~helmut/debian-dedup.git] / importpkg.py
index e8cc2fa..933ec80 100755 (executable)
@@ -8,15 +8,15 @@ And finally a document consisting of the string "commit" is emitted."""
 import hashlib
 import optparse
 import sys
-import tarfile
 import zlib
 
 import yaml
 
-from dedup.debpkg import DebExtractor, process_control, get_tar_hashes
+from dedup.debpkg import DebExtractor, decodetarname, get_tar_hashes, \
+        process_control
 from dedup.hashing import DecompressedHash, SuppressingHash, HashedStream, \
         HashBlacklistContent
-from dedup.compression import GzipDecompressor, decompress
+from dedup.compression import GzipDecompressor
 from dedup.image import GIFHash, PNGHash
 
 boring_content = set(("", "\n"))
@@ -42,33 +42,6 @@ def gifhash():
     hashobj.name = "gif_sha512"
     return hashobj
 
-if sys.version_info.major >= 3:
-    def decompress_tar(filelike, extension):
-        filelike = decompress(filelike, extension.decode("ascii"))
-        return tarfile.open(fileobj=filelike, mode="r|")
-
-    def decodetarname(name):
-        """Decoded name of a tarinfo.
-        @raises UnicodeDecodeError:
-        """
-        try:
-            name.encode("utf8", "strict")
-        except UnicodeEncodeError as e:
-            if e.reason == "surrogates not allowed":
-                name.encode("utf8", "surrogateescape").decode("utf8", "strict")
-        return name
-else:
-    def decompress_tar(filelike, extension):
-        filelike = decompress(filelike, extension.decode("ascii"))
-        return tarfile.open(fileobj=filelike, mode="r|", encoding="utf8",
-                            errors="surrogateescape")
-
-    def decodetarname(name):
-        """Decoded name of a tarinfo.
-        @raises UnicodeDecodeError:
-        """
-        return name.decode("utf8")
-
 class ProcessingFinished(Exception):
     pass
 
@@ -76,40 +49,27 @@ class ImportpkgExtractor(DebExtractor):
     hash_functions = [sha512_nontrivial, gziphash, pnghash, gifhash]
 
     def __init__(self, callback):
-        self.state = "start"
+        DebExtractor.__init__(self)
         self.callback = callback
 
-    def handle_ar_member(self, name, filelike):
-        if name.startswith(b"control.tar"):
-            if self.state != "start":
-                raise ValueError("unexpected control.tar")
-            self.state = "control"
-            tf = decompress_tar(filelike, name[11:])
-            for elem in tf:
-                if elem.name not in ("./control", "control"):
-                    continue
-                if self.state != "control":
-                    raise ValueError("duplicate control file")
-                self.state = "control_file"
-                self.callback(process_control(tf.extractfile(elem).read()))
-                break
-        elif name.startswith(b"data.tar"):
-            if self.state != "control_file":
-                raise ValueError("missing control file")
-            self.state = "data"
-            tf = decompress_tar(filelike, name[8:])
-            for name, size, hashes in get_tar_hashes(tf, self.hash_functions):
-                try:
-                    name = decodetarname(name)
-                except UnicodeDecodeError:
-                    print("warning: skipping filename with encoding error")
-                    continue # skip files with non-utf8 encoding for now
-                self.callback(dict(name=name, size=size, hashes=hashes))
-            raise ProcessingFinished()
-
-    def handle_ar_end(self):
-        if self.state != "data":
-            raise ValueError("data.tar not found")
+    def handle_control_tar(self, tarfileobj):
+        for elem in tarfileobj:
+            if elem.name not in ("./control", "control"):
+                continue
+            self.callback(process_control(tarfileobj.extractfile(elem).read()))
+            return
+        raise ValueError("missing control file")
+
+    def handle_data_tar(self, tarfileobj):
+        for name, size, hashes in get_tar_hashes(tarfileobj,
+                                                 self.hash_functions):
+            try:
+                name = decodetarname(name)
+            except UnicodeDecodeError:
+                print("warning: skipping filename with encoding error")
+                continue # skip files with non-utf8 encoding for now
+            self.callback(dict(name=name, size=size, hashes=hashes))
+        raise ProcessingFinished()
 
 def main():
     parser = optparse.OptionParser()