# Extracted from ~helmut/debian-dedup.git, file test.py (commit "many improvements")
1 #!/usr/bin/python
2 """
3 CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
4 CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
5 CREATE INDEX content_package_index ON content (package);
6 CREATE INDEX content_hash_index ON content (hash);
7 """
8
9 import hashlib
10 import os
11 import re
12 import sqlite3
13 import struct
14 import sys
15 import tarfile
16 import zlib
17
18 import apt_pkg
19 import lzma
20
# One-time libapt initialisation; main() relies on apt_pkg.version_compare.
apt_pkg.init()
22
class ArReader(object):
    """Sequential reader for the members of a Unix ar(1) archive.

    Used here to pull the data.tar.* member out of a .deb file.  After
    skiptillmember() has positioned the reader on a member, the object
    acts as a file-like object restricted to that member (read/close).
    """

    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj, membertest):
        """
        @param fileobj: binary file-like object containing the archive
        @param membertest: predicate on a member name; skiptillmember()
            stops at the first member for which it returns true
        """
        self.fileobj = fileobj
        self.membertest = membertest
        self.remaining = None  # bytes left in the located member, set by skiptillmember()

    def skip(self, length):
        """Discard exactly length bytes from the underlying file.

        @raises ValueError: if the file ends before length bytes were read
        """
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def skiptillmember(self):
        """Advance past members until membertest accepts one.

        @returns: a (name, size) pair for the member found
        @raises ValueError: on malformed archives or premature end
        """
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        while True:
            file_header = self.fileobj.read(60)
            if not file_header:
                raise ValueError("end of archive found")
            parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
            # strip with b" ": the unpacked fields are bytes, so a str
            # argument would raise TypeError on Python 3
            parts = [p.rstrip(b" ") for p in parts]
            if parts.pop() != self.file_magic:
                # include the offending header instead of a debug print
                raise ValueError("ar file header not found in %r" %
                                 file_header)
            name = parts[0]
            length = int(parts[5])
            if self.membertest(name):
                self.remaining = length
                return name, length
            # member data is padded to an even number of bytes
            self.skip(length + length % 2)

    def read(self, length=None):
        """Read up to length bytes (all remaining if None) of the member
        located by skiptillmember()."""
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data

    def close(self):
        """Close the underlying file object."""
        self.fileobj.close()
70
class XzStream(object):
    """Minimal read-only file-like adapter that decompresses an xz
    stream on the fly, suitable for feeding tarfile in "r|" mode."""

    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        self.buff = b""  # decompressed bytes not yet handed out

    def read(self, length):
        """Return up to length decompressed bytes; a short result only
        happens at end of stream."""
        eof = False
        while len(self.buff) < length and not eof:
            chunk = self.fileobj.read(self.blocksize)
            if chunk:
                self.buff += self.decomp.decompress(chunk)
            else:
                # underlying file exhausted: drain the decompressor
                self.buff += self.decomp.flush()
                eof = True
        ret, self.buff = self.buff[:length], self.buff[length:]
        return ret
95
class MultiHash(object):
    """Broadcast update() to several hash objects, so one pass over the
    data feeds them all; results are read from the .hashes attribute."""

    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        for hashobj in self.hashes:
            hashobj.update(data)
103
class HashBlacklist(object):
    """Wrap a hash object so that blacklisted digests read as None."""

    def __init__(self, hasher, blacklist=frozenset()):
        """
        @param hasher: hash object to wrap (update/hexdigest/name)
        @param blacklist: hexdigests to suppress; the default is a
            frozenset to avoid the shared-mutable-default pitfall
        """
        self.hasher = hasher
        self.blacklist = blacklist
        self.update = self.hasher.update  # delegate directly
        self.name = hasher.name

    def hexdigest(self):
        """Return the wrapped hexdigest, or None if blacklisted."""
        digest = self.hasher.hexdigest()
        if digest in self.blacklist:
            return None
        return digest
116
class GzipDecompressor(object):
    """Streaming gzip decompressor with a zlib.decompressobj-style
    interface (decompress/flush/copy/unused_data).

    The gzip header is parsed by hand; once it is complete, the deflate
    body is handed to a raw zlib decompressor.
    """

    def __init__(self):
        self.inbuffer = b""       # unparsed gzip header bytes
        self.decompressor = None  # zlib.decompressobj(-zlib.MAX_WBITS) once the header is done

    def decompress(self, data):
        """Feed data and return whatever plaintext became available.

        @raises ValueError: if the gzip magic is not found
        @raises zlib.error: from zlib invocations
        """
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            # deflate stream ended: what follows is the trailer or a
            # concatenated gzip member; restart header parsing on it
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        self.inbuffer += data
        skip = 10  # fixed-size part of the gzip header
        if len(self.inbuffer) < skip:
            return b""  # header incomplete, wait for more input
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        # slice (not index) so this works on Python 2 and 3 alike;
        # indexing bytes yields an int on Python 3
        flag = ord(self.inbuffer[3:4])
        if flag & 4:  # FEXTRA: 2-byte length plus payload
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, FCOMMENT: zero-terminated
            if flag & field:
                length = self.inbuffer.find(b"\0", skip)
                if length < 0:
                    return b""
                skip = length + 1
        if flag & 2:  # FHCRC: 2-byte header checksum
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        """Input bytes not yet consumed, mirroring zlib's attribute."""
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        """Return any remaining buffered plaintext."""
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Return an independent clone of the current state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new
175
class DecompressedHash(object):
    """Hash the decompressed form of whatever is fed to update()."""

    def __init__(self, decompressor, hashobj):
        self.decompressor = decompressor
        self.hashobj = hashobj

    def update(self, data):
        self.hashobj.update(self.decompressor.decompress(data))

    def hexdigest(self):
        """Digest of everything decompressed so far, including pending
        decompressor output; neither object is disturbed."""
        decomp = self.decompressor
        if not hasattr(decomp, "flush"):
            return self.hashobj.hexdigest()
        # flush a *copy* so further update() calls keep working
        finalhash = self.hashobj.copy()
        finalhash.update(decomp.copy().flush())
        return finalhash.hexdigest()
192
class SuppressingHash(object):
    """Wrap a hash object; if one of the listed exception types fires
    during update() or hexdigest(), the digest becomes None forever."""

    def __init__(self, hashobj, exceptions=()):
        self.hashobj = hashobj  # set to None once an exception fired
        self.exceptions = exceptions

    def update(self, data):
        if not self.hashobj:
            return
        try:
            self.hashobj.update(data)
        except self.exceptions:
            self.hashobj = None

    def hexdigest(self):
        if not self.hashobj:
            return None
        try:
            return self.hashobj.hexdigest()
        except self.exceptions:
            self.hashobj = None
            return None
212
def hash_file(hashobj, filelike, blocksize=65536):
    """Feed filelike to hashobj in blocksize chunks; return hashobj."""
    while True:
        chunk = filelike.read(blocksize)
        if not chunk:
            break
        hashobj.update(chunk)
    return hashobj
219
# sha512 hexdigests of file contents so common that recording their
# duplication would only be noise; passed to HashBlacklist, whose
# hexdigest() reports them as None.
boring_sha512_hashes = set((
    # sha512 of the empty string ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # sha512 of a single newline "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09"))
225
def sha512_nontrivial():
    """Return a sha512 hasher whose hexdigest() is None for boring
    (empty / single-newline) contents."""
    hashobj = hashlib.sha512()
    return HashBlacklist(hashobj, boring_sha512_hashes)
228
def gziphash():
    """Return a hasher named "gzip_sha512" that digests the decompressed
    content of a gzip file; decompression failures and boring contents
    both yield a None digest."""
    inner = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    guarded = SuppressingHash(inner, (ValueError, zlib.error))
    guarded.name = "gzip_sha512"
    return HashBlacklist(guarded, boring_sha512_hashes)
234
def get_hashes(filelike):
    """Yield (filename, size, hashname, hexdigest) for every non-empty
    regular file in the data.tar member of the .deb readable from
    filelike."""
    af = ArReader(filelike, lambda name: name.startswith("data.tar"))
    name, membersize = af.skiptillmember()
    # dispatch on the member name to the matching tarfile stream mode
    openers = {
        "data.tar.gz": lambda fileobj: tarfile.open(fileobj=fileobj, mode="r|gz"),
        "data.tar.bz2": lambda fileobj: tarfile.open(fileobj=fileobj, mode="r|bz2"),
        "data.tar.xz": lambda fileobj: tarfile.open(fileobj=XzStream(fileobj), mode="r|"),
    }
    opener = openers.get(name)
    if opener is None:
        raise ValueError("unsupported compression %r" % name)
    tf = opener(af)
    for entry in tf:
        if entry.size == 0:  # boring
            continue
        if not entry.isreg():  # excludes hard links as well
            continue
        hasher = hash_file(MultiHash(sha512_nontrivial(), gziphash()),
                           tf.extractfile(entry))
        for hashobj in hasher.hashes:
            hashvalue = hashobj.hexdigest()
            if hashvalue:
                yield (entry.name, entry.size, hashobj.name, hashvalue)
258
def main():
    """Import the .deb named on the command line into test.sqlite3.

    The package is skipped when a newer version of the same package is
    already recorded; otherwise its old rows are replaced.  The package
    name, version and architecture are parsed from the filename.

    @raises ValueError: if the filename does not look like a .deb
    """
    from functools import cmp_to_key  # only needed here

    filename = sys.argv[1]
    match = re.match(r"(?:.*/)?(?P<name>[^_]+)_(?P<version>[^_]+)_(?P<architecture>[^_.]+)\.deb$", filename)
    if match is None:
        raise ValueError("cannot parse package filename %r" % filename)
    package, version, architecture = match.groups()
    db = sqlite3.connect("test.sqlite3")
    cur = db.cursor()

    cur.execute("SELECT version FROM package WHERE package = ?;", (package,))
    versions = [tpl[0] for tpl in cur.fetchall()]
    versions.append(version)
    # cmp_to_key keeps Debian version ordering working on Python 2.7
    # and 3.x alike (sort(cmp=...) is Python-2-only)
    versions.sort(key=cmp_to_key(apt_pkg.version_compare))
    if versions[-1] != version:
        db.close()
        return # not the newest version

    cur.execute("DELETE FROM package WHERE package = ?;", (package,))
    cur.execute("DELETE FROM content WHERE package = ?;", (package,))
    cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                (package, version, architecture))
    # .deb files are binary: open in "rb" so nothing mangles the bytes
    with open(filename, "rb") as pkg:
        for name, size, function, hexhash in get_hashes(pkg):
            name = name.decode("utf8")
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                    (package, name, size, function, hexhash))
    db.commit()
    db.close()
283
# Entry point: run the importer when executed as a script, but stay
# importable as a module without side effects.
if __name__ == "__main__":
    main()