3939ff4e0d8a60b9196632bf0e94cf4890bc2d44
[~helmut/debian-dedup.git] / test.py
1 #!/usr/bin/python
2 """
3 CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
4 CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
5 CREATE INDEX content_package_index ON content (package);
6 CREATE INDEX content_hash_index ON content (hash);
7 """
8
9 import hashlib
10 import re
11 import sqlite3
12 import struct
13 import sys
14 import tarfile
15 import zlib
16
17 import apt_pkg
18 import lzma
19
20 apt_pkg.init()
21
class ArReader(object):
    """Sequential reader for Unix ar(1) archives (the container of .deb files).

    Wraps a (possibly non-seekable) binary file object; read_magic() must be
    called once, then read_entry() advances to each member in turn and the
    member content is consumed via read().
    """
    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.remaining = None  # bytes left in the current member; set by read_magic
        self.padding = 0       # 1 if the current member is followed by an alignment byte

    def skip(self, length):
        """Consume and discard length bytes from the underlying file.

        Raises ValueError if the file ends before length bytes were read.
        """
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def read_magic(self):
        """Read and verify the global ar header; must be called first."""
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        self.remaining = 0

    def read_entry(self):
        """Skip the rest of the current member and return the next member's name.

        The name is returned with trailing spaces stripped.  Raises EOFError
        at the end of the archive and ValueError on a malformed header.
        """
        self.skip_current_entry()
        if self.padding:
            # members are two-byte aligned; odd-sized ones are padded with "\n"
            # (bytes literal: identical on Python 2, also correct on Python 3)
            if self.fileobj.read(1) != b"\n":
                raise ValueError("missing ar padding")
            self.padding = 0
        file_header = self.fileobj.read(60)
        if not file_header:
            raise EOFError("end of archive found")
        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
        parts = [p.rstrip(b" ") for p in parts]
        if parts.pop() != self.file_magic:
            raise ValueError("ar file header not found")
        self.remaining = int(parts[5])  # decimal size field
        self.padding = self.remaining % 2
        return parts[0] # name

    def skip_current_entry(self):
        """Discard whatever is left of the current member."""
        self.skip(self.remaining)
        self.remaining = 0

    def read(self, length=None):
        """Read up to length bytes (default: all remaining) of the current member."""
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data
73
74 class XzStream(object):
75     blocksize = 65536
76
77     def __init__(self, fileobj):
78         self.fileobj = fileobj
79         self.decomp = lzma.LZMADecompressor()
80         self.buff = b""
81
82     def read(self, length):
83         data = True
84         while True:
85             if len(self.buff) >= length:
86                 ret = self.buff[:length]
87                 self.buff = self.buff[length:]
88                 return ret
89             elif not data: # read EOF in last iteration
90                 ret = self.buff
91                 self.buff = b""
92                 return ret
93             data = self.fileobj.read(self.blocksize)
94             if data:
95                 self.buff += self.decomp.decompress(data)
96             else:
97                 self.buff += self.decomp.flush()
98
class MultiHash(object):
    """Fan a single update() stream out to several hash objects at once."""

    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        """Feed data to every wrapped hash object."""
        for hashobj in self.hashes:
            hashobj.update(data)
106
class HashBlacklist(object):
    """Wrap a hash object so that blacklisted digests are reported as None.

    Used to suppress hashes of boring contents (e.g. the empty file) that
    would otherwise flood the duplicate table.
    """

    # frozenset() instead of set(): a mutable default argument is shared
    # across all calls and invites accidental cross-instance mutation.
    def __init__(self, hasher, blacklist=frozenset()):
        """
        @param hasher: object providing update(), hexdigest() and a name attribute
        @param blacklist: hex digests for which hexdigest() returns None
        """
        self.hasher = hasher
        self.blacklist = blacklist
        self.update = self.hasher.update
        self.name = hasher.name

    def hexdigest(self):
        """Return the wrapped hexdigest, or None if it is blacklisted."""
        digest = self.hasher.hexdigest()
        if digest in self.blacklist:
            return None
        return digest
119
class GzipDecompressor(object):
    """Decompressor for gzip streams that can be fed arbitrary chunks.

    Unlike a plain zlib.decompressobj with gzip wbits, this parses the gzip
    header itself, supports copy() to fork the state, and restarts parsing
    after a finished member (concatenated gzip streams).
    """

    def __init__(self):
        self.inbuffer = b""      # header bytes accumulated so far
        self.decompressor = None # zlib.decompressobj(-zlib.MAX_WBITS) once past the header

    def decompress(self, data):
        """Return whatever decompressed bytes are available after feeding data.

        Raises ValueError if the gzip magic cannot be found.
        """
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            # deflate stream finished: the remainder is a trailer or the next member
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        self.inbuffer += data
        skip = 10  # fixed-size part of the gzip header
        if len(self.inbuffer) < skip:
            return b""
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        # slice rather than index so this works on Python 2 str and on bytes
        flag = ord(self.inbuffer[3:4])
        if flag & 4:  # FEXTRA: little-endian length-prefixed extra field
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16):  # FNAME, FCOMMENT: NUL-terminated strings
            if flag & field:
                length = self.inbuffer.find(b"\0", skip)
                if length < 0:
                    return b""  # terminator not seen yet; wait for more input
                skip = length + 1
        if flag & 2:  # FHCRC: two-byte header checksum
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        """Return any held-back output; b"" while still parsing the header."""
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Return an independent copy of the current decompression state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new
178
class DecompressedHash(object):
    """Hash the decompressed form of whatever is fed to update()."""

    def __init__(self, decompressor, hashobj):
        self.decompressor = decompressor
        self.hashobj = hashobj

    def update(self, data):
        """Decompress data and feed the result into the hash object."""
        self.hashobj.update(self.decompressor.decompress(data))

    def hexdigest(self):
        """Return the digest including any data still buffered.

        Operates on copies so that further update() calls stay possible.
        """
        if not hasattr(self.decompressor, "flush"):
            return self.hashobj.hexdigest()
        decomp_copy = self.decompressor.copy()
        hash_copy = self.hashobj.copy()
        hash_copy.update(decomp_copy.flush())
        return hash_copy.hexdigest()
195
class SuppressingHash(object):
    """Hash wrapper turning a given set of exceptions into a None digest.

    Once a wrapped call raises one of the exceptions, the hash object is
    dropped and hexdigest() reports None from then on.
    """

    def __init__(self, hashobj, exceptions=()):
        self.hashobj = hashobj
        self.exceptions = exceptions

    def update(self, data):
        if not self.hashobj:
            return
        try:
            self.hashobj.update(data)
        except self.exceptions:
            self.hashobj = None

    def hexdigest(self):
        if not self.hashobj:
            return None
        try:
            return self.hashobj.hexdigest()
        except self.exceptions:
            self.hashobj = None
            return None
215
def hash_file(hashobj, filelike, blocksize=65536):
    """Feed the entire content of filelike into hashobj and return it.

    @param blocksize: number of bytes requested per read() call
    """
    while True:
        chunk = filelike.read(blocksize)
        if not chunk:
            break
        hashobj.update(chunk)
    return hashobj
222
# sha512 hexdigests of trivial file contents, used as a blacklist so these
# near-universal "duplicates" are never recorded.
boring_sha512_hashes = set((
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc15697b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09"))
228
def sha512_nontrivial():
    """Return a sha512 hasher whose hexdigest() is None for boring contents."""
    hashobj = hashlib.sha512()
    return HashBlacklist(hashobj, boring_sha512_hashes)
231
def gziphash():
    """Return a hasher for the decompressed content of gzip files.

    Decompression failures and boring digests both make hexdigest()
    return None.
    """
    inner = DecompressedHash(GzipDecompressor(), hashlib.sha512())
    guarded = SuppressingHash(inner, (ValueError, zlib.error))
    guarded.name = "gzip_sha512"
    return HashBlacklist(guarded, boring_sha512_hashes)
237
def get_hashes(filelike):
    """Yield (name, size, hash function name, hexdigest) tuples for every
    regular file inside the data.tar member of the .deb read from filelike.

    Raises ValueError if the archive ends without any data.tar member.
    """
    af = ArReader(filelike)
    af.read_magic()
    tf = None
    while True:
        try:
            name = af.read_entry()
        except EOFError:
            if tf is None:
                # The original checked this after the loop, where it was
                # unreachable (the loop only exits via this handler).
                raise ValueError("data.tar not found")
            return
        if name == "data.tar.gz":
            tf = tarfile.open(fileobj=af, mode="r|gz")
        elif name == "data.tar.bz2":
            tf = tarfile.open(fileobj=af, mode="r|bz2")
        elif name == "data.tar.xz":
            zf = XzStream(af)
            tf = tarfile.open(fileobj=zf, mode="r|")
        else:
            continue
        for elem in tf:
            if not elem.isreg(): # excludes hard links as well
                continue
            hasher = MultiHash(sha512_nontrivial(), gziphash())
            hasher = hash_file(hasher, tf.extractfile(elem))
            for hashobj in hasher.hashes:
                hashvalue = hashobj.hexdigest()
                # hexdigest() is None for blacklisted or failed hashes
                if hashvalue:
                    yield (elem.name, elem.size, hashobj.name, hashvalue)
267
def main():
    """Record the contents of the .deb named on the command line in
    test.sqlite3, unless a newer version of the package is already stored."""
    filename = sys.argv[1]
    match = re.match(r"(?:.*/)?(?P<name>[^_]+)_(?P<version>[^_]+)_(?P<architecture>[^_.]+)\.deb$", filename)
    if match is None:
        # fail with a clear message instead of an AttributeError below
        raise ValueError("cannot parse .deb filename: %r" % filename)
    package, version, architecture = match.groups()
    db = sqlite3.connect("test.sqlite3")
    cur = db.cursor()

    cur.execute("SELECT version FROM package WHERE package = ?;", (package,))
    versions = [tpl[0] for tpl in cur.fetchall()]
    versions.append(version)
    versions.sort(cmp=apt_pkg.version_compare)
    if versions[-1] != version:
        return # not the newest version

    # replace any previously recorded state of this package
    cur.execute("DELETE FROM package WHERE package = ?;", (package,))
    cur.execute("DELETE FROM content WHERE package = ?;", (package,))
    cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                (package, version, architecture))
    # a .deb is binary data: open explicitly in binary mode
    with open(filename, "rb") as pkg:
        for name, size, function, hexhash in get_hashes(pkg):
            name = name.decode("utf8")
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                    (package, name, size, function, hexhash))
    db.commit()
292
# run only when executed as a script, not when imported
if __name__ == "__main__":
    main()