#!/usr/bin/python
"""
CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
CREATE INDEX content_package_index ON content (package);
CREATE INDEX content_hash_index ON content (hash);
"""

import hashlib
import re
import sqlite3
import struct
import sys
import tarfile
import zlib

import apt_pkg
import lzma

apt_pkg.init()

class ArReader(object):
    """Streaming reader for ar(1) archives such as .deb files.

    Members are skipped until membertest(name) is true for one of them;
    read() then yields that member's contents.
    """
    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj, membertest):
        self.fileobj = fileobj
        self.membertest = membertest
        self.remaining = None

    def skip(self, length):
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def skiptillmember(self):
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        while True:
            file_header = self.fileobj.read(60)
            if not file_header:
                raise ValueError("end of archive found")
            parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
            parts = [p.rstrip(" ") for p in parts]
            if parts.pop() != self.file_magic:
                print(repr(file_header))
                raise ValueError("ar file header not found")
            name = parts[0]
            length = int(parts[5])
            if self.membertest(name):
                self.remaining = length
                return name
            self.skip(length + length % 2) # member data is padded to even size

    def read(self, length=None):
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data

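# A .deb package is an ar(1) archive containing debian-binary, control.tar.*
# and data.tar.*.  A hypothetical stand-alone use of ArReader (the same
# pattern get_hashes uses below) would look like:
#
#     with open("some.deb", "rb") as f:
#         ar = ArReader(f, lambda name: name.startswith("data.tar"))
#         member = ar.skiptillmember()   # e.g. "data.tar.gz"
#         payload = ar.read()            # raw bytes of that member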
class XzStream(object):
    """Minimal read-only file-like object that decompresses xz data.

    Python 2's tarfile has no xz support, so data.tar.xz is decompressed
    here and handed to tarfile in plain streaming mode ("r|").
    """
    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        self.buff = b""

    def read(self, length):
        data = True
        while True:
            if len(self.buff) >= length:
                ret = self.buff[:length]
                self.buff = self.buff[length:]
                return ret
            elif not data: # read EOF in last iteration
                ret = self.buff
                self.buff = b""
                return ret
            data = self.fileobj.read(self.blocksize)
            if data:
                self.buff += self.decomp.decompress(data)
            else:
                self.buff += self.decomp.flush()

class MultiHash(object):
    """Feed the same data to several hash-like objects at once."""
    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        for hasher in self.hashes:
            hasher.update(data)

class HashBlacklist(object):
    """Wrap a hash object, but report None for blacklisted digests."""
    def __init__(self, hasher, blacklist=set()):
        self.hasher = hasher
        self.blacklist = blacklist
        self.update = self.hasher.update
        self.name = hasher.name

    def hexdigest(self):
        digest = self.hasher.hexdigest()
        if digest in self.blacklist:
            return None
        return digest

class GzipDecompressor(object):
    """Decompress gzip streams with a zlib.decompressobj-like interface.

    The gzip header is parsed by hand and the payload is handed to a raw
    deflate decompressor (zlib with -MAX_WBITS), which, unlike the gzip
    module, is incremental and supports copy().
    """
    def __init__(self):
        self.inbuffer = b""
        self.decompressor = None # zlib.decompressobj(-zlib.MAX_WBITS)

    def decompress(self, data):
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        # still parsing the gzip header; buffer data until it is complete
        self.inbuffer += data
        skip = 10
        if len(self.inbuffer) < skip:
            return b""
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        flag = ord(self.inbuffer[3])
        if flag & 4: # FEXTRA
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16): # FNAME, FCOMMENT
            if flag & field:
                length = self.inbuffer.find("\0", skip)
                if length < 0:
                    return b""
                skip = length + 1
        if flag & 2: # FHCRC
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new

class DecompressedHash(object):
    """Hash the decompressed contents of the data fed in."""
    def __init__(self, decompressor, hashobj):
        self.decompressor = decompressor
        self.hashobj = hashobj

    def update(self, data):
        self.hashobj.update(self.decompressor.decompress(data))

    def hexdigest(self):
        if not hasattr(self.decompressor, "flush"):
            return self.hashobj.hexdigest()
        # work on copies so that update() can still be called afterwards
        tmpdecomp = self.decompressor.copy()
        data = tmpdecomp.flush()
        tmphash = self.hashobj.copy()
        tmphash.update(data)
        return tmphash.hexdigest()

class SuppressingHash(object):
    """Hash object that silently turns the given exceptions into a None digest."""
    def __init__(self, hashobj, exceptions=()):
        self.hashobj = hashobj
        self.exceptions = exceptions

    def update(self, data):
        if self.hashobj:
            try:
                self.hashobj.update(data)
            except self.exceptions:
                self.hashobj = None

    def hexdigest(self):
        if self.hashobj:
            try:
                return self.hashobj.hexdigest()
            except self.exceptions:
                self.hashobj = None
        return None

def hash_file(hashobj, filelike, blocksize=65536):
    """Feed the whole contents of filelike to hashobj in blocksize chunks."""
    data = filelike.read(blocksize)
    while data:
        hashobj.update(data)
        data = filelike.read(blocksize)
    return hashobj

boring_sha512_hashes = set((
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09"))

def sha512_nontrivial():
    return HashBlacklist(hashlib.sha512(), boring_sha512_hashes)

def gziphash():
    hashobj = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    hashobj = SuppressingHash(hashobj, (ValueError, zlib.error))
    hashobj.name = "gzip_sha512"
    return HashBlacklist(hashobj, boring_sha512_hashes)

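# gzip_sha512 hashes the *decompressed* payload of gzip files (it yields no
# value for non-gzip input), so a file shipped gzip-compressed in one package
# can be matched against the plain sha512 of an uncompressed copy, e.g.:
#
#     h = gziphash()
#     hash_file(h, open("changelog.Debian.gz", "rb"))
#     print(h.hexdigest())  # sha512 of the decompressed changelog (or None)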
def get_hashes(filelike):
    """Yield (filename, size, hash function name, hash value) tuples for
    every regular file contained in the .deb given as filelike."""
    af = ArReader(filelike, lambda name: name.startswith("data.tar"))
    name = af.skiptillmember()
    if name == "data.tar.gz":
        tf = tarfile.open(fileobj=af, mode="r|gz")
    elif name == "data.tar.bz2":
        tf = tarfile.open(fileobj=af, mode="r|bz2")
    elif name == "data.tar.xz":
        zf = XzStream(af)
        tf = tarfile.open(fileobj=zf, mode="r|")
    else:
        raise ValueError("unsupported compression %r" % name)
    for elem in tf:
        if not elem.isreg(): # excludes hard links as well
            continue
        hasher = MultiHash(sha512_nontrivial(), gziphash())
        hasher = hash_file(hasher, tf.extractfile(elem))
        for hashobj in hasher.hashes:
            hashvalue = hashobj.hexdigest()
            if hashvalue:
                yield (elem.name, elem.size, hashobj.name, hashvalue)

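# get_hashes() is independent of the database; a minimal sketch of using it
# on its own:
#
#     with open(sys.argv[1], "rb") as pkg:
#         for name, size, function, hexhash in get_hashes(pkg):
#             print("%s %s %d %s" % (function, hexhash, size, name))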
def main():
    """Import the .deb named on the command line into test.sqlite3."""
    filename = sys.argv[1]
    match = re.match("(?:.*/)?(?P<name>[^_]+)_(?P<version>[^_]+)_(?P<architecture>[^_.]+)\\.deb$", filename)
    package, version, architecture = match.groups()
    db = sqlite3.connect("test.sqlite3")
    cur = db.cursor()

    cur.execute("SELECT version FROM package WHERE package = ?;", (package,))
    versions = [tpl[0] for tpl in cur.fetchall()]
    versions.append(version)
    versions.sort(cmp=apt_pkg.version_compare)
    if versions[-1] != version:
        return # not the newest version

    cur.execute("DELETE FROM package WHERE package = ?;", (package,))
    cur.execute("DELETE FROM content WHERE package = ?;", (package,))
    cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                (package, version, architecture))
    with open(filename, "rb") as pkg:
        for name, size, function, hexhash in get_hashes(pkg):
            name = name.decode("utf8")
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                    (package, name, size, function, hexhash))
    db.commit()

if __name__ == "__main__":
    main()
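# With a few packages imported, duplicated files can then be located with a
# query along these lines (a sketch against the schema in the module
# docstring):
#
#     SELECT a.package, a.filename, b.package, b.filename
#     FROM content a JOIN content b
#         ON a.function = b.function AND a.hash = b.hash
#     WHERE a.package < b.package;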