# importpkg.py — from ~helmut/debian-dedup.git
# Commit: "move hashing functions to module dedup.hashing"
1 #!/usr/bin/python
2 """
3 CREATE TABLE package (package TEXT PRIMARY KEY, version TEXT, architecture TEXT);
4 CREATE TABLE content (package TEXT, filename TEXT, size INTEGER, function TEXT, hash TEXT, FOREIGN KEY (package) REFERENCES package(package));
5 CREATE TABLE dependency (package TEXT, required TEXT, FOREIGN KEY (package) REFERENCES package(package), FOREIGN KEY (required) REFERENCES package(package));
6 CREATE INDEX content_package_index ON content (package);
7 CREATE INDEX content_hash_index ON content (hash);
8 """
9
10 import hashlib
11 import sqlite3
12 import struct
13 import sys
14 import tarfile
15 import zlib
16
17 from debian.debian_support import version_compare
18 from debian import deb822
19 import lzma
20
21 from dedup.hashing import HashBlacklist, DecompressedHash, SuppressingHash, hash_file
22
23 class ArReader(object):
24     global_magic = b"!<arch>\n"
25     file_magic = b"`\n"
26
27     def __init__(self, fileobj):
28         self.fileobj = fileobj
29         self.remaining = None
30         self.padding = 0
31
32     def skip(self, length):
33         while length:
34             data = self.fileobj.read(min(4096, length))
35             if not data:
36                 raise ValueError("archive truncated")
37             length -= len(data)
38
39     def read_magic(self):
40         data = self.fileobj.read(len(self.global_magic))
41         if data != self.global_magic:
42             raise ValueError("ar global header not found")
43         self.remaining = 0
44
45     def read_entry(self):
46         self.skip_current_entry()
47         if self.padding:
48             if self.fileobj.read(1) != '\n':
49                 raise ValueError("missing ar padding")
50             self.padding = 0
51         file_header = self.fileobj.read(60)
52         if not file_header:
53             raise EOFError("end of archive found")
54         parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
55         parts = [p.rstrip(" ") for p in parts]
56         if parts.pop() != self.file_magic:
57             raise ValueError("ar file header not found")
58         self.remaining = int(parts[5])
59         self.padding = self.remaining % 2
60         return parts[0] # name
61
62     def skip_current_entry(self):
63         self.skip(self.remaining)
64         self.remaining = 0
65
66     def read(self, length=None):
67         if length is None:
68             length = self.remaining
69         else:
70             length = min(self.remaining, length)
71         data = self.fileobj.read(length)
72         self.remaining -= len(data)
73         return data
74
class XzStream(object):
    """Minimal file-like wrapper decompressing an xz stream read from
    fileobj.  Only read() is provided, which is all that tarfile needs
    for streaming ("r|") access.
    """
    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        self.buff = b""

    def read(self, length):
        """Return up to length decompressed bytes; b"" only at EOF."""
        more = True  # becomes b"" once the compressed input is exhausted
        while True:
            if len(self.buff) >= length:
                chunk, self.buff = self.buff[:length], self.buff[length:]
                return chunk
            if not more:  # EOF was reached on the previous iteration
                chunk, self.buff = self.buff, b""
                return chunk
            more = self.fileobj.read(self.blocksize)
            if more:
                self.buff += self.decomp.decompress(more)
            else:
                self.buff += self.decomp.flush()
99
class MultiHash(object):
    """Fan a single stream of update() calls out to several hashlib-style
    hash objects.  The individual objects stay reachable via ``hashes``.
    """
    def __init__(self, *hashes):
        self.hashes = hashes

    def update(self, data):
        """Feed data to every contained hash object."""
        for hashobj in self.hashes:
            hashobj.update(data)
107
class GzipDecompressor(object):
    """Decompressor for gzip streams with a zlib.decompressobj-like API.

    Unlike zlib.decompressobj(-zlib.MAX_WBITS) it parses (and discards)
    the gzip header itself.  The header may arrive in arbitrarily small
    pieces; decompress() returns b"" until enough input has accumulated.
    Concatenated gzip members are decompressed back to back.
    """
    def __init__(self):
        self.inbuffer = b""      # buffered header bytes, not yet complete
        self.decompressor = None # zlib.decompressobj(-zlib.MAX_WBITS) once the header is done

    def decompress(self, data):
        """Feed compressed bytes, return whatever plaintext is available.

        @raises ValueError: if the gzip magic is not found
        @raises zlib.error: on invalid deflate data
        """
        if self.decompressor:
            data = self.decompressor.decompress(data)
            if not self.decompressor.unused_data:
                return data
            # leftover bytes after the deflate stream: either the gzip
            # trailer or a concatenated member -> restart header parsing
            unused_data = self.decompressor.unused_data
            self.decompressor = None
            return data + self.decompress(unused_data)
        self.inbuffer += data
        skip = 10  # fixed part of the gzip header
        if len(self.inbuffer) < skip:
            return b""
        if not self.inbuffer.startswith(b"\037\213\010"):
            raise ValueError("gzip magic not found")
        # slice instead of index: b[3] is an int on Python 3, so
        # ord(self.inbuffer[3]) would raise TypeError there
        flag = ord(self.inbuffer[3:4])
        if flag & 4: # FEXTRA: 2-byte little-endian length + payload
            if len(self.inbuffer) < skip + 2:
                return b""
            length, = struct.unpack("<H", self.inbuffer[skip:skip+2])
            skip += 2 + length
        for field in (8, 16): # FNAME, FCOMMENT: zero-terminated strings
            if flag & field:
                # bytes literal: bytes.find(str) raises on Python 3
                length = self.inbuffer.find(b"\0", skip)
                if length < 0:
                    return b"" # terminator not buffered yet
                skip = length + 1
        if flag & 2: # FHCRC: 2-byte header checksum
            skip += 2
        if len(self.inbuffer) < skip:
            return b""
        data = self.inbuffer[skip:]
        self.inbuffer = b""
        self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        return self.decompress(data)

    @property
    def unused_data(self):
        """Bytes fed in but not consumed by decompression."""
        if self.decompressor:
            return self.decompressor.unused_data
        else:
            return self.inbuffer

    def flush(self):
        """Return any remaining decompressed bytes."""
        if not self.decompressor:
            return b""
        return self.decompressor.flush()

    def copy(self):
        """Return an independent clone of the current state."""
        new = GzipDecompressor()
        new.inbuffer = self.inbuffer
        if self.decompressor:
            new.decompressor = self.decompressor.copy()
        return new
166
# SHA-512 digests of contents that carry no deduplication signal.
boring_sha512_hashes = {
    # ""
    "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce"
    "47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
    # "\n"
    "be688838ca8686e5c90689bf2ab585cef1137c999b48c70b92f67a5c34dc1569"
    "7b5d11c982ed6d71be1e1e7f7b4e0733884aa97c3f7a339a8ed03577cf74be09",
}
172
def sha512_nontrivial():
    """Return a sha512 hash object wrapped so that digests of boring
    contents (see boring_sha512_hashes) are suppressed."""
    hashobj = hashlib.sha512()
    return HashBlacklist(hashobj, boring_sha512_hashes)
175
def gziphash():
    """Return a hash object named "gzip_sha512" computing the sha512 of
    the gzip-decompressed input; decompression errors (ValueError,
    zlib.error) are absorbed via SuppressingHash and boring digests are
    suppressed via HashBlacklist."""
    hashobj = SuppressingHash(
        DecompressedHash(GzipDecompressor(), hashlib.sha512()),
        (ValueError, zlib.error))
    hashobj.name = "gzip_sha512"
    return HashBlacklist(hashobj, boring_sha512_hashes)
181
def get_hashes(tar):
    """Yield (member name, size, hash function name, hexdigest) tuples
    for every regular file in the given tarfile object."""
    for entry in tar:
        if not entry.isreg(): # excludes hard links as well
            continue
        hasher = hash_file(MultiHash(sha512_nontrivial(), gziphash()),
                           tar.extractfile(entry))
        for hashobj in hasher.hashes:
            digest = hashobj.hexdigest()
            if not digest: # blacklisted or failed hash
                continue
            yield (entry.name, entry.size, hashobj.name, digest)
192
def process_package(db, filelike):
    """Import one .deb read from filelike into the sqlite connection db.

    Streams the ar archive: control.tar.gz must appear before the data
    member and yields package/version/architecture plus dependencies;
    the data.tar member's files are then hashed and stored.  Commits on
    success; returns early (without commit) when a newer version of the
    package is already recorded.

    NOTE(review): this is Python 2 code — ar member names come back as
    bytes and the == comparisons against str below rely on py2
    str/bytes equivalence; the .encode("ascii") calls likewise convert
    py2 unicode control fields to str.

    @raises ValueError: on malformed archives or missing/duplicate members
    """
    cur = db.cursor()
    af = ArReader(filelike)
    af.read_magic()
    # state machine: "start" -> "control" (control.tar.gz seen)
    # -> "control_file" (./control parsed); data.tar requires the latter
    state = "start"
    while True:
        try:
            name = af.read_entry()
        except EOFError:
            break
        if name == "control.tar.gz":
            if state != "start":
                raise ValueError("unexpected control.tar.gz")
            state = "control"
            # "r|gz": streaming mode, since af is not seekable
            tf = tarfile.open(fileobj=af, mode="r|gz")
            for elem in tf:
                if elem.name != "./control":
                    continue
                if state != "control":
                    raise ValueError("duplicate control file")
                state = "control_file"
                control = tf.extractfile(elem).read()
                control = deb822.Packages(control)
                package = control["package"].encode("ascii")
                version = control["version"].encode("ascii")
                architecture = control["architecture"].encode("ascii")

                cur.execute("SELECT version FROM package WHERE package = ?;",
                            (package,))
                row = cur.fetchone()
                if row and version_compare(row[0], version) > 0:
                    return # already seen a newer package

                # replace any previously recorded older version
                cur.execute("DELETE FROM package WHERE package = ?;",
                            (package,))
                cur.execute("DELETE FROM content WHERE package = ?;",
                            (package,))
                cur.execute("INSERT INTO package (package, version, architecture) VALUES (?, ?, ?);",
                            (package, version, architecture))
                depends = control.relations.get("depends", [])
                # only simple dependencies (no alternatives) are recorded
                depends = set(dep[0]["name"].encode("ascii")
                              for dep in depends if len(dep) == 1)
                cur.execute("DELETE FROM dependency WHERE package = ?;",
                            (package,))
                cur.executemany("INSERT INTO dependency (package, required) VALUES (?, ?);",
                                ((package, dep) for dep in depends))
                break
            continue
        elif name == "data.tar.gz":
            tf = tarfile.open(fileobj=af, mode="r|gz")
        elif name == "data.tar.bz2":
            tf = tarfile.open(fileobj=af, mode="r|bz2")
        elif name == "data.tar.xz":
            # tarfile lacks native streaming xz support here, so wrap
            # the ar member in our own decompressor
            zf = XzStream(af)
            tf = tarfile.open(fileobj=zf, mode="r|")
        else:
            continue
        if state != "control_file":
            raise ValueError("missing control file")
        for name, size, function, hexhash in get_hashes(tf):
            cur.execute("INSERT INTO content (package, filename, size, function, hash) VALUES (?, ?, ?, ?, ?);",
                        (package, name.decode("utf8"), size, function, hexhash))
        db.commit()
        return
    raise ValueError("data.tar not found")
258
def main():
    """Read a single .deb package from stdin and import it into the
    test.sqlite3 database in the current directory."""
    database = sqlite3.connect("test.sqlite3")
    process_package(database, sys.stdin)

if __name__ == "__main__":
    main()