# importpkg.py (renamed from test.py) — from ~helmut/debian-dedup.git
1 #!/usr/bin/python
2 """
3 CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
4 CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
5 CREATE TABLE dependency (package TEXT, required TEXT, FOREIGN KEY (package) REFERENCES package(package), FOREIGN KEY (required) REFERENCES package(package));
6 CREATE INDEX content_package_index ON content (package);
7 CREATE INDEX content_hash_index ON content (hash);
8 """
9
10 import hashlib
11 import sqlite3
12 import struct
13 import sys
14 import tarfile
15 import zlib
16
17 from debian.debian_support import version_compare
18 from debian import deb822
19 import lzma
20
class ArReader(object):
    """Sequential reader for Unix ar(1) archives such as .deb files.

    Wraps a binary file-like object and exposes the archive one member at
    a time: read_entry() returns the next member's name, read() yields
    that member's contents. Advancing transparently skips any unread
    remainder and the padding byte after odd-sized members.
    """
    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.remaining = None  # bytes left in the current member; None before read_magic
        self.padding = 0       # 1 if the current member is followed by a pad byte

    def skip(self, length):
        """Consume exactly length bytes from the underlying stream.

        Raises ValueError if the stream ends prematurely.
        """
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def read_magic(self):
        """Verify the global archive magic. Must be called before read_entry."""
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        self.remaining = 0

    def read_entry(self):
        """Advance to the next member and return its name as bytes.

        Raises EOFError at the end of the archive and ValueError on a
        malformed member header or missing alignment padding.
        """
        self.skip_current_entry()
        if self.padding:
            # Members are 2-byte aligned; odd-sized ones are padded with "\n".
            # Compare against a bytes literal so this also works on Python 3,
            # where read() returns bytes that never equal a str.
            if self.fileobj.read(1) != b"\n":
                raise ValueError("missing ar padding")
            self.padding = 0
        file_header = self.fileobj.read(60)
        if not file_header:
            raise EOFError("end of archive found")
        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
        # Header fields are ASCII, right-padded with spaces; strip as bytes.
        parts = [p.rstrip(b" ") for p in parts]
        if parts.pop() != self.file_magic:
            raise ValueError("ar file header not found")
        # Decimal size field; decode first since int() rejects bytes on Python 3.
        self.remaining = int(parts[5].decode("ascii"))
        self.padding = self.remaining % 2
        return parts[0] # name

    def skip_current_entry(self):
        """Discard whatever is left of the current member."""
        self.skip(self.remaining)
        self.remaining = 0

    def read(self, length=None):
        """Read up to length bytes of the current member (default: the rest)."""
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data
72
class XzStream(object):
    """Minimal file-like wrapper decompressing an xz stream on the fly.

    Only read() is provided, which is all that tarfile's streaming "r|"
    mode requires.
    """
    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        self.buff = b""  # decompressed bytes not yet handed out

    def read(self, length):
        """Return up to length decompressed bytes; b"" at end of stream."""
        data = True
        while True:
            if len(self.buff) >= length:
                ret = self.buff[:length]
                self.buff = self.buff[length:]
                return ret
            elif not data: # read EOF in last iteration
                ret = self.buff
                self.buff = b""
                return ret
            data = self.fileobj.read(self.blocksize)
            if data:
                self.buff += self.decomp.decompress(data)
            elif hasattr(self.decomp, "flush"):
                # pyliblzma (Python 2) buffers until an explicit flush; the
                # stdlib lzma module (Python 3) has no flush() and emits
                # everything eagerly, so skipping it there is correct.
                self.buff += self.decomp.flush()
97
class MultiHash(object):
    """Fan a single update() stream out to several hash objects at once."""

    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        """Feed data to every wrapped hash object."""
        for h in self.hashes:
            h.update(data)
105
class HashBlacklist(object):
    """Wrap a hash object, reporting None for blacklisted ("boring") digests.

    update() and name pass straight through to the wrapped hasher;
    hexdigest() is filtered against the blacklist.
    """

    # frozenset() avoids the mutable-default-argument anti-pattern; the
    # blacklist is only ever tested for membership, so this is fully
    # backward compatible with the old default of set().
    def __init__(self, hasher, blacklist=frozenset()):
        self.hasher = hasher
        self.blacklist = blacklist
        self.update = self.hasher.update
        self.name = hasher.name

    def hexdigest(self):
        """Return the hex digest, or None if it is blacklisted."""
        digest = self.hasher.hexdigest()
        if digest in self.blacklist:
            return None
        return digest
118
class GzipDecompressor(object):
    """Decompressor with a zlib-like interface for gzip streams.

    Parses the gzip member header by hand (RFC 1952), then delegates the
    raw deflate body to zlib. Supports incremental feeding: header bytes
    are buffered until the header is complete, and a second concatenated
    gzip member restarts header parsing via unused_data.
    """

    def __init__(self):
        self.inbuffer = b""      # header bytes collected so far
        self.decompressor = None # zlib.decompressobj(-zlib.MAX_WBITS) once the header is parsed

    def decompress(self, data):
        """Feed data; return whatever decompressed bytes are available.

        Raises ValueError if the stream does not start with a gzip header.
        """
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            # Bytes beyond the deflate stream: the gzip trailer or the start
            # of a concatenated member. Reset and reparse from the header.
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        self.inbuffer += data
        skip = 10  # fixed-size portion of the gzip header
        if len(self.inbuffer) < skip:
            return b""
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        # Slice instead of indexing: on Python 3 a bytes index yields an int,
        # which ord() would reject; a length-1 slice works on both 2 and 3.
        flag = ord(self.inbuffer[3:4])
        if flag & 4:  # FEXTRA: little-endian length-prefixed field
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, then FCOMMENT: NUL-terminated strings
            if flag & field:
                # Search for the bytes NUL terminator (b"\0" also works on
                # Python 2, where str is bytes).
                length = self.inbuffer.find(b"\0", skip)
                if length < 0:
                    return b""  # terminator not buffered yet
                skip = length + 1
        if flag & 2:  # FHCRC: 16 bit header checksum
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        """Return any remaining decompressed bytes."""
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Return an independent snapshot of the decompression state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new
177
class DecompressedHash(object):
    """Hash adapter that hashes the decompressed form of the data it is fed."""

    def __init__(self, decompressor, hashobj):
        self.decompressor = decompressor
        self.hashobj = hashobj

    def update(self, data):
        """Decompress data and feed the result into the underlying hash."""
        self.hashobj.update(self.decompressor.decompress(data))

    def hexdigest(self):
        """Digest of everything decompressed so far.

        Works on copies of the decompressor and hash so further update()
        calls remain possible afterwards.
        """
        if not hasattr(self.decompressor, "flush"):
            return self.hashobj.hexdigest()
        pending = self.decompressor.copy().flush()
        snapshot = self.hashobj.copy()
        snapshot.update(pending)
        return snapshot.hexdigest()
194
class SuppressingHash(object):
    """Hash wrapper that swallows a configured set of exceptions.

    Once one of the exceptions fires, the wrapped hash is dropped and
    hexdigest() returns None from then on.
    """

    def __init__(self, hashobj, exceptions=()):
        self.hashobj = hashobj
        self.exceptions = exceptions

    def update(self, data):
        if not self.hashobj:
            return
        try:
            self.hashobj.update(data)
        except self.exceptions:
            self.hashobj = None

    def hexdigest(self):
        if not self.hashobj:
            return None
        try:
            return self.hashobj.hexdigest()
        except self.exceptions:
            self.hashobj = None
            return None
214
def hash_file(hashobj, filelike, blocksize=65536):
    """Feed the whole of filelike into hashobj in blocksize chunks.

    Returns hashobj for convenient chaining.
    """
    while True:
        chunk = filelike.read(blocksize)
        if not chunk:
            break
        hashobj.update(chunk)
    return hashobj
221
# sha512 digests of contents so common they carry no information:
# the empty file and a file containing a single newline.
boring_sha512_hashes = {
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09",
}
227
def sha512_nontrivial():
    """Build a sha512 hasher whose hexdigest() is None for boring files."""
    hasher = hashlib.sha512()
    return HashBlacklist(hasher, boring_sha512_hashes)
230
def gziphash():
    """Build a hasher for the sha512 of gunzipped content.

    Decompression failures make hexdigest() return None instead of
    raising, as does a boring (blacklisted) digest.
    """
    inner = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    guarded = SuppressingHash(inner, (ValueError, zlib.error))
    guarded.name = "gzip_sha512"
    return HashBlacklist(guarded, boring_sha512_hashes)
236
def get_hashes(tar):
    """Yield (name, size, hash function name, hexdigest) per regular file.

    Non-regular members (including hard links) are skipped; digests the
    blacklist maps to None are dropped.
    """
    for member in tar:
        if not member.isreg(): # excludes hard links as well
            continue
        multihash = hash_file(MultiHash(sha512_nontrivial(), gziphash()),
                              tar.extractfile(member))
        for hashobj in multihash.hashes:
            digest = hashobj.hexdigest()
            if digest:
                yield (member.name, member.size, hashobj.name, digest)
247
def process_package(db, filelike):
    """Import one Debian binary package (.deb, an ar archive) into db.

    Parses ./control from control.tar.gz into the package and dependency
    tables, then hashes every regular file of the data tarball into the
    content table. Returns early (without writing) if a newer version of
    the same package is already recorded; commits after the data tarball
    has been processed.

    Raises ValueError on malformed input: control.tar.gz out of order,
    a duplicate or missing control file, or no data tarball at all.
    """
    cur = db.cursor()
    af = ArReader(filelike)
    af.read_magic()
    # Tiny state machine over the ar members:
    # "start" -> "control" (control.tar.gz seen) -> "control_file"
    # (./control parsed); a data tarball is only valid in the last state.
    state = "start"
    while True:
        try:
            name = af.read_entry()
        except EOFError:
            break
        # NOTE(review): read_entry returns bytes; these comparisons rely on
        # Python 2 str/bytes equivalence.
        if name == "control.tar.gz":
            if state != "start":
                raise ValueError("unexpected control.tar.gz")
            state = "control"
            # Streaming mode ("r|gz"): ArReader is not seekable.
            tf = tarfile.open(fileobj=af, mode="r|gz")
            for elem in tf:
                if elem.name != "./control":
                    continue
                if state != "control":
                    raise ValueError("duplicate control file")
                state = "control_file"
                control = tf.extractfile(elem).read()
                control = deb822.Packages(control)
                package = control["package"].encode("ascii")
                version = control["version"].encode("ascii")
                architecture = control["architecture"].encode("ascii")

                cur.execute("SELECT version FROM package WHERE package = ?;",
                            (package,))
                row = cur.fetchone()
                if row and version_compare(row[0], version) > 0:
                    return # already seen a newer package

                # Replace any older recorded version of this package.
                cur.execute("DELETE FROM package WHERE package = ?;",
                            (package,))
                cur.execute("DELETE FROM content WHERE package = ?;",
                            (package,))
                cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                            (package, version, architecture))
                depends = control.relations.get("depends", [])
                # Keep only simple dependencies: alternatives ("a | b")
                # have len(dep) > 1 and are ignored.
                depends = set(dep[0]["name"].encode("ascii")
                              for dep in depends if len(dep) == 1)
                cur.execute("DELETE FROM dependency WHERE package = ?;",
                            (package,))
                cur.executemany("INSERT INTO dependency (package, required) VALUES (?, ?);",
                                ((package, dep) for dep in depends))
                break
            continue
        elif name == "data.tar.gz":
            tf = tarfile.open(fileobj=af, mode="r|gz")
        elif name == "data.tar.bz2":
            tf = tarfile.open(fileobj=af, mode="r|bz2")
        elif name == "data.tar.xz":
            # tarfile has no native "r|xz" here; decompress via XzStream.
            zf = XzStream(af)
            tf = tarfile.open(fileobj=zf, mode="r|")
        else:
            continue
        if state != "control_file":
            raise ValueError("missing control file")
        for name, size, function, hexhash in get_hashes(tf):
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                        (package, name.decode("utf8"), size, function, hexhash))
        db.commit()
        return
    raise ValueError("data.tar not found")
313
def main():
    """Read a .deb from standard input and import it into test.sqlite3."""
    db = sqlite3.connect("test.sqlite3")
    try:
        # sys.stdin.buffer is the binary stream on Python 3; Python 2's
        # sys.stdin is already byte-oriented and lacks the attribute.
        process_package(db, getattr(sys.stdin, "buffer", sys.stdin))
    finally:
        db.close()

if __name__ == "__main__":
    main()