4277d57e51bbc0c824cb16e09edca319f9e7d856
[~helmut/debian-dedup.git] / test.py
1 #!/usr/bin/python
2 """
3 CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
4 CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
5 CREATE INDEX content_package_index ON content (package);
6 CREATE INDEX content_hash_index ON content (hash);
7 """
8
9 import hashlib
10 import sqlite3
11 import struct
12 import sys
13 import tarfile
14 import zlib
15
16 from debian.debian_support import version_compare
17 from debian import deb822
18 import lzma
19
class ArReader(object):
    """Minimal sequential reader for Unix ar(1) archives (e.g. .deb files).

    Usage: call read_magic() once, then alternate read_entry() and
    read()/skip_current_entry() until read_entry() raises EOFError.
    """
    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj):
        """@param fileobj: binary file-like object positioned at the
        start of the archive"""
        self.fileobj = fileobj
        self.remaining = None  # bytes left in current entry; None before read_magic
        self.padding = 0       # 1 when the current entry needs an alignment byte

    def skip(self, length):
        """Discard exactly length bytes of the underlying file.
        @raises ValueError: if the file ends prematurely
        """
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def read_magic(self):
        """Consume and verify the global archive header.
        @raises ValueError: if the magic is not found
        """
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        self.remaining = 0

    def read_entry(self):
        """Advance to the next member and return its (bytes) name.
        @raises EOFError: when the end of the archive is reached
        @raises ValueError: on malformed headers or missing padding
        """
        self.skip_current_entry()
        if self.padding:
            # members are 2-byte aligned; the filler byte is a newline.
            # Compare against a bytes literal for consistency with the
            # bytes class attributes (identical on Python 2).
            if self.fileobj.read(1) != b"\n":
                raise ValueError("missing ar padding")
            self.padding = 0
        file_header = self.fileobj.read(60)
        if not file_header:
            raise EOFError("end of archive found")
        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
        parts = [p.rstrip(b" ") for p in parts]
        if parts.pop() != self.file_magic:
            raise ValueError("ar file header not found")
        # field 5 is the decimal member size; decode before int() so the
        # conversion also works on bytes fields
        self.remaining = int(parts[5].decode("ascii"))
        self.padding = self.remaining % 2
        return parts[0] # name

    def skip_current_entry(self):
        """Discard whatever is left of the current member."""
        self.skip(self.remaining)
        self.remaining = 0

    def read(self, length=None):
        """Read up to length bytes (default: the rest) of the current member."""
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data
71
class XzStream(object):
    """File-like wrapper that decompresses an xz stream on the fly.

    Only read() is provided; input is pulled from the underlying file
    in fixed-size chunks and buffered until a request can be satisfied.
    """
    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        self.buff = b""

    def read(self, length):
        chunk = True
        while True:
            if len(self.buff) >= length:
                # enough buffered: hand out exactly length bytes
                result = self.buff[:length]
                self.buff = self.buff[length:]
                return result
            if not chunk:
                # hit EOF on the previous iteration: return what is left
                result = self.buff
                self.buff = b""
                return result
            chunk = self.fileobj.read(self.blocksize)
            if chunk:
                self.buff += self.decomp.decompress(chunk)
            else:
                # end of input: drain anything the decompressor still holds
                self.buff += self.decomp.flush()
96
class MultiHash(object):
    """Fan a single update() stream out to several hash objects."""

    def __init__(self, *hashes):
        # kept as the public attribute callers iterate over
        self.hashes = hashes

    def update(self, data):
        """Feed the same chunk to every wrapped hasher."""
        for wrapped in self.hashes:
            wrapped.update(data)
104
class HashBlacklist(object):
    """Wrap a hash object so blacklisted digests are reported as None.

    update() and name are forwarded unchanged from the wrapped hasher.
    """
    def __init__(self, hasher, blacklist=frozenset()):
        """@param hasher: hash object providing update, hexdigest and name
        @param blacklist: digests to suppress; the default is an immutable
            frozenset to avoid the shared-mutable-default-argument pitfall
        """
        self.hasher = hasher
        self.blacklist = blacklist
        self.update = self.hasher.update
        self.name = hasher.name

    def hexdigest(self):
        """Return the wrapped hexdigest, or None if it is blacklisted."""
        digest = self.hasher.hexdigest()
        if digest in self.blacklist:
            return None
        return digest
117
class GzipDecompressor(object):
    """zlib.decompressobj work-alike that understands the gzip wrapper.

    Header bytes are buffered in self.inbuffer until the complete,
    variable-length gzip header has been seen; only then is a raw-deflate
    decompressor created.  NOTE(review): written for Python 2 — single
    indexing yields str characters (hence ord()) and str/bytes literals
    are interchangeable; this does not run unchanged on Python 3.
    """
    def __init__(self):
        self.inbuffer = b""  # unparsed gzip header bytes
        self.decompressor = None # zlib.decompressobj(-zlib.MAX_WBITS)

    def decompress(self, data):
        """Decompress a chunk; returns b"" while the header is incomplete.
        @raises ValueError: if the gzip magic is missing
        """
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            # trailing bytes after the deflate stream: restart header
            # parsing, which supports concatenated gzip members
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        self.inbuffer += data
        skip = 10  # fixed-size part of the gzip header
        if len(self.inbuffer) < skip:
            return b""
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        flag = ord(self.inbuffer[3])
        if flag & 4:  # FEXTRA: 2-byte little-endian length plus payload
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, FCOMMENT: NUL-terminated strings
            if flag & field:
                length = self.inbuffer.find("\0", skip)
                if length < 0:
                    return b""  # terminator not buffered yet; wait for more
                skip = length + 1
        if flag & 2:  # FHCRC: 2-byte header checksum (not verified here)
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        # header complete: everything after it is raw deflate data
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        # mirrors the zlib decompressobj attribute of the same name
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        """Return any data the inner decompressor still holds back."""
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Clone the current state, like zlib decompressobj.copy()."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new
176
class DecompressedHash(object):
    """Hash the decompressed form of the data fed to update().

    hexdigest() is non-destructive: pending decompressor output is
    obtained from copies, so further update() calls remain possible.
    """
    def __init__(self, decompressor, hashobj):
        self.decompressor = decompressor
        self.hashobj = hashobj

    def update(self, data):
        decompressed = self.decompressor.decompress(data)
        self.hashobj.update(decompressed)

    def hexdigest(self):
        if not hasattr(self.decompressor, "flush"):
            # nothing can be buffered, so the digest is already complete
            return self.hashobj.hexdigest()
        # flush a copy so our own streaming state stays usable
        pending = self.decompressor.copy().flush()
        snapshot = self.hashobj.copy()
        snapshot.update(pending)
        return snapshot.hexdigest()
193
class SuppressingHash(object):
    """Hash wrapper that converts listed exceptions into a None digest.

    Once a wrapped call raises one of the given exceptions, the hasher
    is dropped and every later hexdigest() returns None.
    """
    def __init__(self, hashobj, exceptions=()):
        self.hashobj = hashobj
        self.exceptions = exceptions

    def update(self, data):
        if not self.hashobj:
            return
        try:
            self.hashobj.update(data)
        except self.exceptions:
            self.hashobj = None

    def hexdigest(self):
        if not self.hashobj:
            return None
        try:
            return self.hashobj.hexdigest()
        except self.exceptions:
            self.hashobj = None
        return None
213
def hash_file(hashobj, filelike, blocksize=65536):
    """Feed the entire contents of filelike into hashobj in blocksize
    chunks and return hashobj for chaining."""
    while True:
        chunk = filelike.read(blocksize)
        if not chunk:
            break
        hashobj.update(chunk)
    return hashobj
220
# sha512 digests of trivially common file contents; HashBlacklist reports
# these as None so they never make it into the content table
boring_sha512_hashes = set((
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09"))
226
def sha512_nontrivial():
    """Return a sha512 hasher that reports boring digests as None."""
    hashobj = hashlib.sha512()
    return HashBlacklist(hashobj, boring_sha512_hashes)
229
def gziphash():
    """Return a hasher for the sha512 of gzip-decompressed content.

    Decompression failures are suppressed (digest becomes None), and
    boring digests are likewise reported as None.
    """
    inner = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    guarded = SuppressingHash(inner, (ValueError, zlib.error))
    guarded.name = "gzip_sha512"
    return HashBlacklist(guarded, boring_sha512_hashes)
235
def get_hashes(tar):
    """Yield (name, size, hash function name, digest) for every regular
    member of the given tar archive; None digests are skipped."""
    for entry in tar:
        # isreg() excludes directories, symlinks and hard links alike
        if not entry.isreg():
            continue
        multihash = MultiHash(sha512_nontrivial(), gziphash())
        multihash = hash_file(multihash, tar.extractfile(entry))
        for hashobj in multihash.hashes:
            digest = hashobj.hexdigest()
            if digest:
                yield (entry.name, entry.size, hashobj.name, digest)
246
def process_package(db, filelike):
    """Read a Debian binary package (an ar archive) from filelike and
    record its metadata and per-file content hashes in the database.

    The newest version wins: if the database already holds a strictly
    newer version of the same package, the archive is skipped.  Commits
    the transaction after the data member has been processed.
    @param db: sqlite3 connection with the schema from the module docstring
    @param filelike: binary file-like object containing the .deb
    @raises ValueError: on malformed archives or unexpected member order
    """
    cur = db.cursor()
    af = ArReader(filelike)
    af.read_magic()
    # tiny state machine: "start" -> "control" -> "control_file"
    state = "start"
    while True:
        try:
            name = af.read_entry()
        except EOFError:
            break
        if name == "control.tar.gz":
            # control.tar.gz must be the first interesting member
            if state != "start":
                raise ValueError("unexpected control.tar.gz")
            state = "control"
            tf = tarfile.open(fileobj=af, mode="r|gz")
            for elem in tf:
                if elem.name != "./control":
                    continue
                if state != "control":
                    raise ValueError("duplicate control file")
                state = "control_file"
                control = tf.extractfile(elem).read()
                control = deb822.Packages(control)
                package = control["package"].encode("ascii")
                version = control["version"].encode("ascii")
                architecture = control["architecture"].encode("ascii")

                cur.execute("SELECT version FROM package WHERE package = ?;",
                            (package,))
                row = cur.fetchone()
                if row and version_compare(row[0], version) > 0:
                    return # already seen a newer package

                # replace any previously recorded version wholesale
                cur.execute("DELETE FROM package WHERE package = ?;",
                            (package,))
                cur.execute("DELETE FROM content WHERE package = ?;",
                            (package,))
                cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                            (package, version, architecture))
                # NOTE(review): depends is computed but never used below —
                # presumably groundwork for a later extension; confirm
                depends = control.relations.get("depends", [])
                depends = set(dep[0]["name"].encode("ascii")
                              for dep in depends if len(dep) == 1)
                break
            continue
        elif name == "data.tar.gz":
            tf = tarfile.open(fileobj=af, mode="r|gz")
        elif name == "data.tar.bz2":
            tf = tarfile.open(fileobj=af, mode="r|bz2")
        elif name == "data.tar.xz":
            # xz is decompressed manually; tarfile then reads the plain stream
            zf = XzStream(af)
            tf = tarfile.open(fileobj=zf, mode="r|")
        else:
            continue
        # the control file must precede the data member
        if state != "control_file":
            raise ValueError("missing control file")
        # one row per regular file and hash function
        for name, size, function, hexhash in get_hashes(tf):
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                        (package, name.decode("utf8"), size, function, hexhash))
        db.commit()
        return
    raise ValueError("data.tar not found")
308
def main():
    """Import the .deb named on the command line into test.sqlite3."""
    db = sqlite3.connect("test.sqlite3")
    try:
        # .deb files are binary ar archives: open in binary mode so the
        # bytes reach ArReader unmodified
        with open(sys.argv[1], "rb") as pkg:
            process_package(db, pkg)
    finally:
        # release the connection even if processing raises
        db.close()
314
315 if __name__ == "__main__":
316     main()