split content table into a separate hash table
[~helmut/debian-dedup.git] / importpkg.py
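The database schema is not part of this file. The sketch below is a minimal guess at the tables the importer expects, reconstructed from the SQL statements in importpkg.py; names such as content.id are assumptions (inferred from the cid foreign key and the use of cur.lastrowid), not the project's actual schema definition.

# Hypothetical schema sketch, inferred from the queries in importpkg.py below.
# Column names and types are guesses, not the project's shipped schema.
import sqlite3

schema = """
CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT,
                      architecture TEXT, source TEXT);
CREATE TABLE content (id INTEGER PRIMARY KEY, package TEXT, filename TEXT,
                      size INTEGER,
                      FOREIGN KEY (package) REFERENCES package (package)
                          ON DELETE CASCADE);
CREATE TABLE hash (cid INTEGER NOT NULL, function TEXT, hash TEXT,
                   FOREIGN KEY (cid) REFERENCES content (id)
                       ON DELETE CASCADE);
CREATE TABLE dependency (package TEXT, required TEXT,
                         FOREIGN KEY (package) REFERENCES package (package)
                             ON DELETE CASCADE);
"""

if __name__ == "__main__":
    db = sqlite3.connect("test.sqlite3")
    db.executescript(schema)
    db.commit()

The ON DELETE CASCADE clauses are likewise assumptions; they keep the importer's DELETE FROM content and INSERT OR REPLACE INTO package statements working while PRAGMA foreign_keys = ON is in effect.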
#!/usr/bin/python
"""Read a Debian binary package (.deb) from standard input and record its
metadata and per-file hashes in the sqlite database test.sqlite3."""

import hashlib
import sqlite3
import struct
import sys
import tarfile
import zlib

from debian.debian_support import version_compare
from debian import deb822
import lzma

from dedup.hashing import HashBlacklist, DecompressedHash, SuppressingHash, hash_file
from dedup.compression import GzipDecompressor, DecompressedStream
from dedup.image import ImageHash

class ArReader(object):
    """Sequential reader for ar(1) archives such as Debian .deb files."""
    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.remaining = None
        self.padding = 0

    def skip(self, length):
        """Consume length bytes from the underlying file object."""
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def read_magic(self):
        """Verify the global archive header; must be called before read_entry."""
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        self.remaining = 0

    def read_entry(self):
        """Skip the remainder of the current member, parse the next member
        header and return its name. Raises EOFError at the end of the archive."""
        self.skip_current_entry()
        if self.padding:
            if self.fileobj.read(1) != b"\n":
                raise ValueError("missing ar padding")
            self.padding = 0
        file_header = self.fileobj.read(60)
        if not file_header:
            raise EOFError("end of archive found")
        # ar member header: name 16, mtime 12, uid 6, gid 6, mode 8,
        # size 10, trailing magic 2 bytes
        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
        parts = [p.rstrip(b" ") for p in parts]
        if parts.pop() != self.file_magic:
            raise ValueError("ar file header not found")
        self.remaining = int(parts[5])
        self.padding = self.remaining % 2 # members are padded to even length
        return parts[0] # name

    def skip_current_entry(self):
        self.skip(self.remaining)
        self.remaining = 0

    def read(self, length=None):
        """Read from the current member, never past its end."""
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data

class MultiHash(object):
    """Feed the same data to several hash objects at once."""
    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        for hasher in self.hashes:
            hasher.update(data)

# sha512 hexdigests of contents too trivial to be worth recording
boring_sha512_hashes = set((
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09"))

def sha512_nontrivial():
    """Plain sha512, with hashes of trivial contents suppressed."""
    return HashBlacklist(hashlib.sha512(), boring_sha512_hashes)

def gziphash():
    """sha512 of the decompressed contents of gzip-compressed files."""
    hashobj = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    hashobj = SuppressingHash(hashobj, (ValueError, zlib.error))
    hashobj.name = "gzip_sha512"
    return HashBlacklist(hashobj, boring_sha512_hashes)

def imagehash():
    """sha512 computed via ImageHash for supported image formats."""
    hashobj = ImageHash(hashlib.sha512())
    hashobj = SuppressingHash(hashobj, (ValueError,))
    hashobj.name = "image_sha512"
    return hashobj

def get_hashes(tar):
    """Yield (name, size, hashes) for each regular file in the tar stream,
    where hashes maps hash function names to hexdigests."""
    for elem in tar:
        if not elem.isreg(): # excludes hard links as well
            continue
        hasher = MultiHash(sha512_nontrivial(), gziphash(), imagehash())
        hasher = hash_file(hasher, tar.extractfile(elem))
        hashes = {}
        for hashobj in hasher.hashes:
            hashvalue = hashobj.hexdigest()
            if hashvalue: # blacklisted or failed hashes yield no value
                hashes[hashobj.name] = hashvalue
        yield (elem.name, elem.size, hashes)

def process_package(db, filelike):
    """Parse a .deb given as a file object and record the package, its
    simple dependencies and the hashes of its contents in the database."""
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    af = ArReader(filelike)
    af.read_magic()
    state = "start"
    while True:
        try:
            name = af.read_entry()
        except EOFError:
            break
        if name == "control.tar.gz":
            if state != "start":
                raise ValueError("unexpected control.tar.gz")
            state = "control"
            tf = tarfile.open(fileobj=af, mode="r|gz")
            for elem in tf:
                if elem.name != "./control":
                    continue
                if state != "control":
                    raise ValueError("duplicate control file")
                state = "control_file"
                control = tf.extractfile(elem).read()
                control = deb822.Packages(control)
                package = control["package"].encode("ascii")
                try:
                    source = control["source"].encode("ascii").split()[0]
                except KeyError:
                    source = package
                version = control["version"].encode("ascii")
                architecture = control["architecture"].encode("ascii")

                cur.execute("SELECT version FROM package WHERE package = ?;",
                            (package,))
                row = cur.fetchone()
                if row and version_compare(row[0], version) > 0:
                    return # already seen a newer version of this package

                cur.execute("DELETE FROM content WHERE package = ?;",
                            (package,))
                cur.execute("INSERT OR REPLACE INTO package (package, version, architecture, source) VALUES (?, ?, ?, ?);",
                            (package, version, architecture, source))
                depends = control.relations.get("depends", [])
                # record only simple dependencies without alternatives
                depends = set(dep[0]["name"].encode("ascii")
                              for dep in depends if len(dep) == 1)
                cur.execute("DELETE FROM dependency WHERE package = ?;",
                            (package,))
                cur.executemany("INSERT INTO dependency (package, required) VALUES (?, ?);",
                                ((package, dep) for dep in depends))
                break
            continue
        elif name == "data.tar.gz":
            tf = tarfile.open(fileobj=af, mode="r|gz")
        elif name == "data.tar.bz2":
            tf = tarfile.open(fileobj=af, mode="r|bz2")
        elif name == "data.tar.xz":
            zf = DecompressedStream(af, lzma.LZMADecompressor())
            tf = tarfile.open(fileobj=zf, mode="r|")
        else:
            continue
        if state != "control_file":
            raise ValueError("missing control file")
        for name, size, hashes in get_hashes(tf):
            try:
                name = name.decode("utf8")
            except UnicodeDecodeError:
                print("warning: skipping filename with encoding error")
                continue # skip files with non-utf8 encoding for now
            cur.execute("INSERT INTO content (package, filename, size) VALUES (?, ?, ?);",
                        (package, name, size))
            cid = cur.lastrowid
            cur.executemany("INSERT INTO hash (cid, function, hash) VALUES (?, ?, ?);",
                            ((cid, func, hexhash) for func, hexhash in hashes.items()))
        db.commit()
        return
    raise ValueError("data.tar not found")

def main():
    db = sqlite3.connect("test.sqlite3")
    process_package(db, sys.stdin)

if __name__ == "__main__":
    main()
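As a quick smoke test one might pipe a binary package into the importer and then query the new hash table for contents that occur more than once. The snippet below is only a sketch under the schema guessed above; some_package.deb stands for any locally available .deb, and it assumes the plain content hash is stored under hashlib's name "sha512" (gzip_sha512 and image_sha512 are set explicitly in importpkg.py).

# Hypothetical smoke test, not part of the repository; some_package.deb is a
# placeholder for any locally available binary package.
import sqlite3
import subprocess

with open("some_package.deb", "rb") as pkg:
    subprocess.check_call(["python", "importpkg.py"], stdin=pkg)

db = sqlite3.connect("test.sqlite3")
rows = db.execute(
    "SELECT hash, COUNT(*) FROM hash WHERE function = 'sha512' "
    "GROUP BY hash HAVING COUNT(*) > 1;")
for hashvalue, count in rows:
    print("%s is shared by %d files" % (hashvalue, count))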