first prototype
[~helmut/debian-dedup.git] / test.py
1 #!/usr/bin/python
2 """
3 CREATE TABLE content (package TEXT, version TEXT, architecture TEXT, filename TEXT, size INTEGER, hash TEXT);
4 CREATE INDEX content_package_index ON content (package);
5 CREATE INDEX content_hash_index ON content (hash);
6 """
7
8 import hashlib
9 import os
10 import re
11 import sqlite3
12 import struct
13 import sys
14 import tarfile
15
16 import apt_pkg
17 import lzma
18
19 apt_pkg.init()
20
class ArReader(object):
    """Streaming reader exposing a single member of a Unix ar archive.

    The archive is consumed strictly sequentially through ``fileobj.read``,
    so it also works on non-seekable input.  After :meth:`skiptillmember`
    has selected a member, :meth:`read` behaves like a file object limited
    to that member's payload, which makes the instance usable as the
    ``fileobj`` of a streaming ``tarfile.open``.

    ``membertest`` is a callable that receives each member name (the raw
    bytes from the header with the space padding stripped) and returns a
    true value for the member that should be selected.
    """

    global_magic = b"!<arch>\n"
    file_magic = b"`\n"

    def __init__(self, fileobj, membertest):
        self.fileobj = fileobj
        self.membertest = membertest
        # Number of payload bytes of the selected member that have not been
        # handed out yet; None until skiptillmember() has found a member.
        self.remaining = None

    def skip(self, length):
        """Read and discard exactly ``length`` bytes.

        Raises ValueError if the input ends before that many bytes were
        consumed.
        """
        while length:
            data = self.fileobj.read(min(4096, length))
            if not data:
                raise ValueError("archive truncated")
            length -= len(data)

    def skiptillmember(self):
        """Advance to the first member accepted by ``membertest``.

        Returns a ``(name, length)`` tuple for that member and arms
        :meth:`read` to deliver its payload.

        Raises ValueError if the global magic is missing, a member header
        is malformed or truncated, or the archive ends before a matching
        member is found.
        """
        data = self.fileobj.read(len(self.global_magic))
        if data != self.global_magic:
            raise ValueError("ar global header not found")
        while True:
            file_header = self.fileobj.read(60)
            if not file_header:
                raise ValueError("end of archive found")
            if len(file_header) != 60:
                # Without this check struct.unpack would raise struct.error,
                # which callers expecting ValueError would not catch.
                raise ValueError("archive truncated")
            # Fixed-width header fields; the first is the member name, the
            # sixth the decimal payload size and the last the header magic.
            parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
            # Fields are space padded.  A bytes literal keeps this working
            # where bytes and str are distinct types.
            parts = [p.rstrip(b" ") for p in parts]
            if parts.pop() != self.file_magic:
                raise ValueError("ar file header not found: %r" %
                                 (file_header,))
            name = parts[0]
            length = int(parts[5])
            if length < 0:
                raise ValueError("ar member has negative size")
            if self.membertest(name):
                self.remaining = length
                return name, length
            # Member payloads are padded to an even number of bytes.
            self.skip(length + length % 2)

    def read(self, length=None):
        """Return up to ``length`` bytes of the selected member.

        With ``length`` omitted the rest of the member is requested.  The
        read never extends past the end of the member; once the member is
        exhausted an empty byte string is returned.
        """
        if length is None:
            length = self.remaining
        else:
            length = min(self.remaining, length)
        data = self.fileobj.read(length)
        self.remaining -= len(data)
        return data

    def close(self):
        """Close the underlying file object."""
        self.fileobj.close()
68
class XzStream(object):
    """File-like wrapper that decompresses an LZMA/xz stream on the fly.

    Only ``read(length)`` is provided, which is what a streaming
    ``tarfile.open(fileobj=..., mode="r|")`` needs.  Compressed input is
    pulled from ``fileobj`` in chunks of ``blocksize`` bytes.
    """

    blocksize = 65536

    def __init__(self, fileobj):
        self.fileobj = fileobj
        self.decomp = lzma.LZMADecompressor()
        # Decompressed bytes that have not been returned to the caller yet.
        self.buff = b""
        # Set once fileobj has reported end of input, so the decompressor
        # is flushed at most once and fileobj is not polled again.
        self.exhausted = False

    def read(self, length):
        """Return up to ``length`` decompressed bytes.

        Fewer bytes (possibly none) are returned only when the compressed
        input is exhausted.
        """
        while len(self.buff) < length and not self.exhausted:
            data = self.fileobj.read(self.blocksize)
            if data:
                self.buff += self.decomp.decompress(data)
            else:
                self.exhausted = True
                # Not every LZMADecompressor implementation has flush();
                # the standard-library one does not and needs no flushing.
                flush = getattr(self.decomp, "flush", None)
                if flush is not None:
                    self.buff += flush()
        ret = self.buff[:length]
        self.buff = self.buff[length:]
        return ret
93
def hash_file(hashobj, filelike, blocksize=65536):
    """Feed the whole content of ``filelike`` into ``hashobj``.

    The data is pulled with ``filelike.read(blocksize)`` until a read
    returns nothing, and each chunk is passed to ``hashobj.update``.
    The same ``hashobj`` is returned so the call can be chained, e.g.
    with ``.hexdigest()``.
    """
    while True:
        chunk = filelike.read(blocksize)
        if not chunk:
            break
        hashobj.update(chunk)
    return hashobj
100
def get_hashes(filelike):
    """Yield ``(name, size, sha512_hexdigest)`` for the files in a .deb.

    ``filelike`` is a binary file object positioned at the start of the
    package.  The first ar member whose name starts with ``data.tar`` is
    located and unpacked as a streaming tar archive.  Only non-empty
    regular files are reported.

    Raises ValueError if the data member uses a compression that is not
    handled here, or if the ar container is malformed.
    """
    # Member names come straight from the ar header, i.e. they are bytes;
    # bytes literals compare correctly where bytes and str are distinct.
    af = ArReader(filelike, lambda name: name.startswith(b"data.tar"))
    name, _ = af.skiptillmember()
    if name == b"data.tar.gz":
        tf = tarfile.open(fileobj=af, mode="r|gz")
    elif name == b"data.tar.bz2":
        tf = tarfile.open(fileobj=af, mode="r|bz2")
    elif name == b"data.tar.xz":
        zf = XzStream(af)
        tf = tarfile.open(fileobj=zf, mode="r|")
    elif name == b"data.tar":
        # Uncompressed data member.
        tf = tarfile.open(fileobj=af, mode="r|")
    else:
        raise ValueError("unsupported compression %r" % name)
    try:
        for elem in tf:
            if elem.size == 0: # boring
                continue
            if not elem.isreg(): # excludes hard links as well
                continue
            hasher = hash_file(hashlib.sha512(), tf.extractfile(elem))
            yield (elem.name, elem.size, hasher.hexdigest())
    finally:
        # Runs on normal exhaustion, on errors and when the generator is
        # closed early, so the tar reader is always released.
        tf.close()
120
def main():
    """Import the file hashes of one .deb into ``test.sqlite3``.

    The package file is taken from ``sys.argv[1]``; its base name must
    have the form ``<package>_<version>_<architecture>.deb``.  If the
    database already holds a strictly newer version of the package,
    nothing is changed.  Otherwise all rows stored for the package are
    replaced by one row per non-empty regular file of this .deb.

    Raises ValueError if the file name cannot be parsed.  Exits with a
    usage message if no file name was given.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: %s <package>_<version>_<architecture>.deb" %
                 sys.argv[0])
    filename = sys.argv[1]
    match = re.match("(?:.*/)?(?P<name>[^_]+)_(?P<version>[^_]+)_(?P<architecture>[^_.]+)\\.deb$", filename)
    if match is None:
        raise ValueError("cannot parse package file name %r" % (filename,))
    package, version, architecture = match.groups()
    db = sqlite3.connect("test.sqlite3")
    try:
        cur = db.cursor()

        cur.execute("SELECT version FROM content WHERE package = ?;", (package,))
        # Skip the import when any stored version sorts strictly after the
        # one being imported; an equal version is re-imported.
        if any(apt_pkg.version_compare(row[0], version) > 0
               for row in cur.fetchall()):
            return # not the newest version

        # Every row of the package is dropped, whatever its version or
        # architecture, before the new rows are inserted.
        cur.execute("DELETE FROM content WHERE package = ?;", (package,))
        # A .deb is binary data, so it must not be opened in text mode.
        with open(filename, "rb") as pkg:
            for name, size, hexhash in get_hashes(pkg):
                # tarfile may hand out member names as bytes or as text.
                if isinstance(name, bytes):
                    name = name.decode("utf8")
                cur.execute("INSERT INTO content (package, version, architecture, filename, size, hash) VALUES (?, ?, ?, ?, ?, ?);",
                        (package, version, architecture, name, size, hexhash))
        db.commit()
    finally:
        # Closing without a commit discards a partially applied import.
        db.close()

if __name__ == "__main__":
    main()