drop support for Python 2.x
[~helmut/debian-dedup.git] / webapp.py
1 #!/usr/bin/python3
2
3 import argparse
4 import contextlib
5 import datetime
6 import sqlite3
7 from wsgiref.simple_server import make_server
8
9 import jinja2
10 from werkzeug.exceptions import HTTPException, NotFound
11 from werkzeug.routing import Map, Rule
12 from werkzeug.utils import redirect
13 from werkzeug.wrappers import Request, Response
14 from werkzeug.wsgi import SharedDataMiddleware
15
16 from dedup.utils import fetchiter
17
# Shared Jinja2 environment; page templates are loaded from the "templates"
# directory shipped inside the dedup package.
jinjaenv = jinja2.Environment(
    loader=jinja2.PackageLoader("dedup", "templates"),
)
19
def format_size(size):
    """Render a byte count as a short human-readable string.

    Values below 1024 are shown as whole bytes ("%d B"). Otherwise the value
    is divided by 1024 up to three times and shown with one decimal place
    in KB, MB or GB. GB is the largest unit used.

    :param size: anything float() accepts (int, float, numeric string).
    :returns: the formatted string, e.g. "1.5 KB".
    """
    value = float(size)
    pattern = "%d B"
    # Each step is checked independently, mirroring three sequential ifs.
    for unit in ("KB", "MB", "GB"):
        if value >= 1024:
            value /= 1024
            pattern = "%.1f " + unit
    return pattern % value
33
def function_combination(function1, function2):
    """Return a display label for a pair of hash function names.

    Identical names collapse to the single name. Differing names are joined
    as "function1 -> function2".
    """
    if function1 != function2:
        return "%s -> %s" % (function1, function2)
    return function1
38
# Workaround for jinja bug #59 (broken filesizeformat): templates calling the
# filesizeformat filter get this module's format_size instead of the builtin.
jinjaenv.filters.update(filesizeformat=format_size)
41
# Load every page template once at import time, in this order. base_template
# is not referenced elsewhere in this module.
(base_template, package_template, detail_template, hash_template,
 index_template, source_template) = map(jinjaenv.get_template, (
    "base.html",
    "binary.html",
    "compare.html",
    "hash.html",
    "index.html",
    "source.html",
))
48
def encode_and_buffer(iterator, bufsize=2048, encoding="utf8"):
    """Encode an iterable of str fragments into reasonably sized byte chunks.

    Fragments are encoded and accumulated until at least *bufsize* bytes are
    pending. The pending bytes are then yielded as one chunk. Any remainder
    is yielded at the end; empty chunks are never yielded.

    :param iterator: iterable of str (note that iterating a plain str yields
        single characters).
    :param bufsize: minimum number of bytes per chunk, except the last.
    :param encoding: text encoding used for every fragment.
    :returns: a generator of bytes objects.
    """
    # Collect parts in a list and join once per chunk. Repeated
    # ``bytes += bytes`` would copy the whole pending buffer on every
    # fragment.
    pending = []
    pending_len = 0
    for elem in iterator:
        data = elem.encode(encoding)
        pending.append(data)
        pending_len += len(data)
        if pending_len >= bufsize:
            yield b"".join(pending)
            pending = []
            pending_len = 0
    if pending_len:
        yield b"".join(pending)
58
def html_response(unicode_iterator, max_age=24 * 60 * 60):
    """Wrap HTML text in a cacheable text/html werkzeug Response.

    :param unicode_iterator: either a complete str (e.g. the result of a
        template's render()) or an iterable of str fragments (e.g. a
        template stream). The text is encoded to UTF-8 lazily.
    :param max_age: number of seconds the response may be cached. It is used
        for both ``Cache-Control: max-age`` and the ``Expires`` header.
    :returns: the Response object.
    """
    if isinstance(unicode_iterator, str):
        # Iterating a str yields single characters. Hand it over as one
        # fragment so it is encoded in a single call instead of per character.
        unicode_iterator = (unicode_iterator,)
    resp = Response(encode_and_buffer(unicode_iterator), mimetype="text/html")
    resp.cache_control.max_age = max_age
    # Expires is an absolute HTTP date in GMT. werkzeug formats naive
    # datetimes as if they were UTC, so a naive local now() would be off by
    # the server's UTC offset. Use an aware UTC timestamp instead.
    resp.expires = (datetime.datetime.now(datetime.timezone.utc) +
                    datetime.timedelta(seconds=max_age))
    return resp
64
class InternalRedirect(Exception):
    """Raised by a request handler to ask for an HTTP redirect.

    :param target: path of the redirect target, relative to the application
        root (the dispatcher prepends SCRIPT_NAME).
    :param code: HTTP status code for the redirect; defaults to 301.
    """

    def __init__(self, target, code=301):
        super().__init__()
        self.code = code
        self.target = target
70
class Application:
    """WSGI application serving the dedup web interface.

    Pages are rendered from the database connection passed to the
    constructor. URL dispatch uses a werkzeug routing map.
    """

    def __init__(self, db):
        """Create the application.

        :param db: an open DB-API connection to the dedup database (main()
            passes a sqlite3 connection).
        """
        self.db = db
        self.routingmap = Map([
            Rule("/", methods=("GET",), endpoint="index"),
            Rule("/binary/<package>", methods=("GET",), endpoint="package"),
            Rule("/compare/<package1>/<package2>", methods=("GET",), endpoint="detail"),
            Rule("/hash/<function>/<hashvalue>", methods=("GET",), endpoint="hash"),
            Rule("/source/<package>", methods=("GET",), endpoint="source"),
        ])

    @Request.application
    def __call__(self, request):
        """Dispatch *request* to the handler for its routed endpoint.

        An InternalRedirect raised by a handler becomes an HTTP redirect
        relative to the application's mount point (SCRIPT_NAME). Any werkzeug
        HTTPException (e.g. NotFound) is returned as the response.
        """
        mapadapter = self.routingmap.bind_to_environ(request.environ)
        try:
            endpoint, args = mapadapter.match()
            if endpoint == "package":
                return self.show_package(args["package"])
            elif endpoint == "detail":
                return self.show_detail(args["package1"], args["package2"])
            elif endpoint == "hash":
                if args["function"] == "image_sha512":
                    # backwards compatibility
                    raise InternalRedirect("/hash/png_sha512/%s" %
                                           args["hashvalue"])
                return self.show_hash(args["function"], args["hashvalue"])
            elif endpoint == "index":
                # PEP 3333 lets servers omit PATH_INFO when it is empty.
                if not request.environ.get("PATH_INFO"):
                    raise InternalRedirect("/")
                return html_response(index_template.render(dict(urlroot="")))
            elif endpoint == "source":
                return self.show_source(args["package"])
            raise NotFound()
        except InternalRedirect as r:
            # SCRIPT_NAME may likewise be absent when mounted at the root.
            return redirect(request.environ.get("SCRIPT_NAME", "") + r.target,
                            r.code)
        except HTTPException as e:
            return e

    def get_details(self, package):
        """Look up basic information about a binary package by name.

        :returns: dict with keys pid, package, version, architecture,
            num_files and total_size (0 when the package has no content rows).
        :raises NotFound: if no package with that name exists.
        """
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT id, version, architecture FROM package WHERE name = ?;",
                        (package,))
            row = cur.fetchone()
            if not row:
                raise NotFound()
            pid, version, architecture = row
            details = dict(pid=pid,
                           package=package,
                           version=version,
                           architecture=architecture)
            cur.execute("SELECT count(filename), sum(size) FROM content WHERE pid = ?;",
                        (pid,))
            num_files, total_size = cur.fetchone()
        # sum() over zero rows is NULL in SQL.
        if total_size is None:
            total_size = 0
        details.update(dict(num_files=num_files, total_size=total_size))
        return details

    def get_dependencies(self, pid):
        """Return the set of ``required`` values stored for package *pid*."""
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT required FROM dependency WHERE pid = ?;",
                        (pid,))
            return {row[0] for row in fetchiter(cur)}

    def cached_sharedstats(self, pid):
        """Read the sharing statistics of package *pid* from the sharing table.

        :returns: dict mapping a hash-function label (see
            function_combination) to a list of dicts with keys package (None
            when the other side is the package itself), duplicate and
            savable. Only function pairs of the same eqclass are included.
        """
        sharedstats = {}
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT pid2, package.name, f1.name, f2.name, files, size FROM sharing JOIN package ON sharing.pid2 = package.id JOIN function AS f1 ON sharing.fid1 = f1.id JOIN function AS f2 ON sharing.fid2 = f2.id WHERE pid1 = ? AND f1.eqclass = f2.eqclass;",
                        (pid,))
            for pid2, package2, func1, func2, files, size in fetchiter(cur):
                curstats = sharedstats.setdefault(
                        function_combination(func1, func2), list())
                if pid2 == pid:
                    package2 = None
                curstats.append(dict(package=package2, duplicate=files,
                                     savable=size))
        return sharedstats

    def show_package(self, package):
        """Render the page for binary *package* (raises NotFound if unknown)."""
        params = self.get_details(package)
        params["dependencies"] = self.get_dependencies(params["pid"])
        params["shared"] = self.cached_sharedstats(params["pid"])
        params["urlroot"] = ".."
        # closing() releases the cursor even if the query raises.
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT content.filename, issue.issue FROM content JOIN issue ON content.id = issue.cid WHERE content.pid = ?;",
                        (params["pid"],))
            params["issues"] = dict(cur.fetchall())
        return html_response(package_template.render(params))

    def compute_comparison(self, pid1, pid2):
        """Compute a sequence of comparison objects ordered by the size of the
        object in the first package. Each element of the sequence is a dict
        defining the following keys:
         * filenames: A set of filenames in package 1 (pid1) all referring to
           the same object.
         * size: Size of the object in bytes.
         * matches: A mapping from filenames in package 2 (pid2) to a mapping
           from hash function pairs to hash values.

        Only objects with at least one match are produced. When pid1 == pid2
        at least two matches are required.

        This is a generator; both database cursors are closed when it
        finishes or is closed early.
        """
        # When comparing a package with itself, every file trivially matches
        # itself, so a second match is required before it is interesting.
        minmatch = 2 if pid1 == pid2 else 1
        cursize = -1
        files = dict()
        with contextlib.closing(self.db.cursor()) as cur, \
                contextlib.closing(self.db.cursor()) as cur2:
            cur.execute("SELECT content.id, content.filename, content.size, hash.hash FROM content JOIN hash ON content.id = hash.cid JOIN duplicate ON content.id = duplicate.cid JOIN function ON hash.fid = function.id WHERE pid = ? AND function.name = 'sha512' ORDER BY size DESC;",
                        (pid1,))
            for cid, filename, size, hashvalue in fetchiter(cur):
                if cursize != size:
                    # Rows arrive ordered by size. Assuming equal sha512
                    # hashes imply equal sizes, the buffered group is
                    # complete once the size changes.
                    for entry in files.values():
                        if len(entry["matches"]) >= minmatch:
                            yield entry
                    files.clear()
                    cursize = size

                if hashvalue in files:
                    files[hashvalue]["filenames"].add(filename)
                    continue

                entry = dict(filenames=set((filename,)), size=size, matches={})
                files[hashvalue] = entry

                cur2.execute("SELECT fa.name, ha.hash, fb.name, filename FROM hash AS ha JOIN hash AS hb ON ha.hash = hb.hash JOIN content ON hb.cid = content.id JOIN function AS fa ON ha.fid = fa.id JOIN function AS fb ON hb.fid = fb.id WHERE ha.cid = ? AND pid = ? AND fa.eqclass = fb.eqclass;",
                             (cid, pid2))
                for func1, otherhash, func2, otherfile in fetchiter(cur2):
                    entry["matches"].setdefault(otherfile, {})[func1, func2] = \
                            otherhash

        for entry in files.values():
            if len(entry["matches"]) >= minmatch:
                yield entry

    def show_detail(self, package1, package2):
        """Render the comparison page of *package1* against *package2*.

        The comparison is streamed: compute_comparison is consumed lazily
        while the template renders.
        """
        details1 = details2 = self.get_details(package1)
        if package1 != package2:
            details2 = self.get_details(package2)

        shared = self.compute_comparison(details1["pid"], details2["pid"])
        params = dict(
            details1=details1,
            details2=details2,
            urlroot="../..",
            shared=shared)
        return html_response(detail_template.stream(params))

    def show_hash(self, function, hashvalue):
        """Render all files whose hash under *function*'s eqclass is *hashvalue*.

        If nothing matches exactly, *hashvalue* is treated as a prefix. A
        unique completion triggers a 302 redirect to it; otherwise NotFound
        is raised.
        """
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT package.name, content.filename, content.size, f2.name FROM hash JOIN content ON hash.cid = content.id JOIN package ON content.pid = package.id JOIN function AS f2 ON hash.fid = f2.id JOIN function AS f1 ON f2.eqclass = f1.eqclass WHERE f1.name = ? AND hash = ?;",
                        (function, hashvalue,))
            entries = [dict(package=package, filename=filename, size=size,
                            function=otherfunc)
                       for package, filename, size, otherfunc in fetchiter(cur)]
            if not entries:
                # Assumption: '~' serves as an infinite character larger than
                # any other character in the hash column.
                cur.execute("SELECT DISTINCT hash.hash FROM hash JOIN function ON hash.fid = function.id WHERE function.name = ? AND hash.hash >= ? AND hash.hash <= ? LIMIT 2;",
                            (function, hashvalue, hashvalue + '~'))
                values = cur.fetchall()
                if len(values) == 1:
                    raise InternalRedirect("/hash/%s/%s" %
                                           (function, values[0][0]), 302)
                raise NotFound()
        params = dict(function=function, hashvalue=hashvalue, entries=entries,
                      urlroot="../..")
        return html_response(hash_template.render(params))

    def show_source(self, package):
        """Render the page for source package *package*.

        For each binary built from it, show the sharing entry with the
        largest savable size (or None if it shares nothing).

        :raises NotFound: if no binary package has this source.
        """
        with contextlib.closing(self.db.cursor()) as cur:
            cur.execute("SELECT name FROM package WHERE source = ?;",
                        (package,))
            binpkgs = dict.fromkeys(pkg for pkg, in fetchiter(cur))
            if not binpkgs:
                raise NotFound()
            cur.execute("SELECT p1.name, p2.name, f1.name, f2.name, sharing.files, sharing.size FROM sharing JOIN package AS p1 ON sharing.pid1 = p1.id JOIN package AS p2 ON sharing.pid2 = p2.id JOIN function AS f1 ON sharing.fid1 = f1.id JOIN function AS f2 ON sharing.fid2 = f2.id WHERE p1.source = ?;",
                        (package,))
            for binary, otherbin, func1, func2, files, size in fetchiter(cur):
                entry = dict(package=otherbin,
                             funccomb=function_combination(func1, func2),
                             duplicate=files, savable=size)
                oldentry = binpkgs.get(binary)
                # Keep only the entry with the largest savable size.
                if not (oldentry and oldentry["savable"] >= size):
                    binpkgs[binary] = entry
        params = dict(source=package, packages=binpkgs, urlroot="..")
        return html_response(source_template.render(params))
257
def main():
    """Command line entry point: serve the web interface using wsgiref.

    Files from the dedup package's "static" directory are served under
    /static. The listen address and port default to 0.0.0.0:8800, the
    previously hard-coded values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--database", action="store",
                        default="test.sqlite3",
                        help="path to the sqlite3 database file")
    parser.add_argument("--host", action="store", default="0.0.0.0",
                        help="address to listen on (default: %(default)s)")
    parser.add_argument("-p", "--port", action="store", type=int,
                        default=8800,
                        help="TCP port to listen on (default: %(default)s)")
    args = parser.parse_args()
    app = Application(sqlite3.connect(args.database))
    app = SharedDataMiddleware(app, {"/static": ("dedup", "static")})
    make_server(args.host, args.port, app).serve_forever()

if __name__ == "__main__":
    main()