autoimport: don't fork for readyaml
[~helmut/debian-dedup.git] / importpkg.py
index dc581e7..6e22b54 100755 (executable)
@@ -1,72 +1,24 @@
 #!/usr/bin/python
+"""This tool reads a debian package from stdin and emits a yaml stream on
+stdout.  It does not access a database. Therefore it can be run in parallel and
+on multiple machines. The generated yaml conatins multiple documents. The first
+document contains package metadata. Then a document is emitted for each file.
+And finally a document consisting of the string "commit" is emitted."""
 
 import hashlib
-import sqlite3
-import struct
 import sys
 import tarfile
 import zlib
 
-from debian.debian_support import version_compare
 from debian import deb822
 import lzma
+import yaml
 
+from dedup.arreader import ArReader
 from dedup.hashing import HashBlacklist, DecompressedHash, SuppressingHash, hash_file
 from dedup.compression import GzipDecompressor, DecompressedStream
 from dedup.image import ImageHash
 
-class ArReader(object):
-    global_magic = b"!<arch>\n"
-    file_magic = b"`\n"
-
-    def __init__(self, fileobj):
-        self.fileobj = fileobj
-        self.remaining = None
-        self.padding = 0
-
-    def skip(self, length):
-        while length:
-            data = self.fileobj.read(min(4096, length))
-            if not data:
-                raise ValueError("archive truncated")
-            length -= len(data)
-
-    def read_magic(self):
-        data = self.fileobj.read(len(self.global_magic))
-        if data != self.global_magic:
-            raise ValueError("ar global header not found")
-        self.remaining = 0
-
-    def read_entry(self):
-        self.skip_current_entry()
-        if self.padding:
-            if self.fileobj.read(1) != '\n':
-                raise ValueError("missing ar padding")
-            self.padding = 0
-        file_header = self.fileobj.read(60)
-        if not file_header:
-            raise EOFError("end of archive found")
-        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
-        parts = [p.rstrip(" ") for p in parts]
-        if parts.pop() != self.file_magic:
-            raise ValueError("ar file header not found")
-        self.remaining = int(parts[5])
-        self.padding = self.remaining % 2
-        return parts[0] # name
-
-    def skip_current_entry(self):
-        self.skip(self.remaining)
-        self.remaining = 0
-
-    def read(self, length=None):
-        if length is None:
-            length = self.remaining
-        else:
-            length = min(self.remaining, length)
-        data = self.fileobj.read(length)
-        self.remaining -= len(data)
-        return data
-
 class MultiHash(object):
     def __init__(self, *hashes):
         self.hashes = hashes
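
For context: ArReader has not been deleted, it merely moved to dedup/arreader.py so other tools can reuse it. MultiHash (kept below) fans a single read of a tar member out to several hash objects, so every file in data.tar is read and decompressed only once. A minimal sketch of that fan-out, assuming an update method that forwards to every wrapped hash (the method sits outside this hunk):

    import hashlib

    class MultiHash(object):
        def __init__(self, *hashes):
            self.hashes = hashes

        def update(self, data):
            # forward each chunk to every wrapped hash object
            for hashobj in self.hashes:
                hashobj.update(data)

    mh = MultiHash(hashlib.sha512(), hashlib.md5())
    mh.update(b"example contents")
    for hashobj in mh.hashes:
        print("%s %s" % (hashobj.name, hashobj.hexdigest()))
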
@@ -102,21 +54,23 @@ def get_hashes(tar):
             continue
         hasher = MultiHash(sha512_nontrivial(), gziphash(), imagehash())
         hasher = hash_file(hasher, tar.extractfile(elem))
+        hashes = {}
         for hashobj in hasher.hashes:
             hashvalue = hashobj.hexdigest()
             if hashvalue:
-                yield (elem.name, elem.size, hashobj.name, hashvalue)
+                hashes[hashobj.name] = hashvalue
+        yield (elem.name, elem.size, hashes)
 
-def process_package(db, filelike):
-    cur = db.cursor()
+def process_package(filelike):
     af = ArReader(filelike)
     af.read_magic()
     state = "start"
-    while True:
+    while state not in ("finished", "skipped"):
         try:
             name = af.read_entry()
         except EOFError:
-            break
+            if state != "finished":
+                raise ValueError("data.tar not found")
         if name == "control.tar.gz":
             if state != "start":
                 raise ValueError("unexpected control.tar.gz")
@@ -131,28 +85,18 @@ def process_package(db, filelike):
                 control = tf.extractfile(elem).read()
                 control = deb822.Packages(control)
                 package = control["package"].encode("ascii")
+                try:
+                    source = control["source"].encode("ascii").split()[0]
+                except KeyError:
+                    source = package
                 version = control["version"].encode("ascii")
                 architecture = control["architecture"].encode("ascii")
 
-                cur.execute("SELECT version FROM package WHERE package = ?;",
-                            (package,))
-                row = cur.fetchone()
-                if row and version_compare(row[0], version) > 0:
-                    return # already seen a newer package
-
-                cur.execute("DELETE FROM package WHERE package = ?;",
-                            (package,))
-                cur.execute("DELETE FROM content WHERE package = ?;",
-                            (package,))
-                cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
-                            (package, version, architecture))
                 depends = control.relations.get("depends", [])
                 depends = set(dep[0]["name"].encode("ascii")
                               for dep in depends if len(dep) == 1)
-                cur.execute("DELETE FROM dependency WHERE package = ?;",
-                            (package,))
-                cur.executemany("INSERT INTO dependency (package, required) VALUES (?, ?);",
-                                ((package, dep) for dep in depends))
+                yield dict(package=package, source=source, version=version,
+                           architecture=architecture, depends=depends)
                 break
             continue
         elif name == "data.tar.gz":
@@ -166,21 +110,18 @@ def process_package(db, filelike):
             continue
         if state != "control_file":
             raise ValueError("missing control file")
-        for name, size, function, hexhash in get_hashes(tf):
+        for name, size, hashes in get_hashes(tf):
             try:
                 name = name.decode("utf8")
             except UnicodeDecodeError:
                 print("warning: skipping filename with encoding error")
                 continue # skip files with non-utf8 encoding for now
-            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
-                        (package, name, size, function, hexhash))
-        db.commit()
-        return
-    raise ValueError("data.tar not found")
+            yield dict(name=name, size=size, hashes=hashes)
+        state = "finished"
+        yield "commit"
 
 def main():
-    db = sqlite3.connect("test.sqlite3")
-    process_package(db, sys.stdin)
+    yaml.safe_dump_all(process_package(sys.stdin), sys.stdout)
 
 if __name__ == "__main__":
     main()
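
Since process_package is a generator, yaml.safe_dump_all serializes each document as soon as it is yielded, so the tool streams instead of buffering a whole package listing. A minimal sketch of that multi-document dumping with made-up documents:

    import sys
    import yaml

    def documents():
        yield dict(package="foo", version="1.0")  # metadata document first
        yield dict(name="./usr/bin/foo", size=123, hashes={})  # one per file
        yield "commit"  # terminating sentinel document

    # successive documents are written to stdout separated by "---" lines
    yaml.safe_dump_all(documents(), sys.stdout)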