e51d0521d071320fc10de2322d009693e57706d7
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
2 """This script takes a directory or an HTTP base URL to a mirror and imports all
3 packages contained. It has rather strong assumptions on the working directory.
4 """
5
6 import argparse
7 import contextlib
8 import errno
9 import multiprocessing
10 import os
11 import sqlite3
12 import subprocess
13 import sys
14 import tempfile
15 try:
16     from urllib.parse import unquote
17 except ImportError:
18     from urllib import unquote
19
20 import concurrent.futures
21 from debian import deb822
22 from debian.debian_support import version_compare
23
24 from dedup.utils import open_compressed_mirror_url
25
26 from readyaml import readyaml
27
def process_http(pkgs, url, addhash=True):
    """Merge the sid/main amd64 package index of a mirror into pkgs.

    The Packages list below url is read paragraph by paragraph.  For every
    listed package an entry with the keys "version" and "filename" (the
    mirror url joined with the Filename field) is stored in pkgs under the
    package name.  An entry already present in pkgs is kept only when its
    version compares strictly greater than the listed one.

    @param pkgs: dict mapping package names to entry dicts, updated in place
    @param url: base url of the mirror
    @param addhash: whether to record the SHA256 field as "sha256hash"
    """
    listurl = "%s/dists/sid/main/binary-amd64/Packages" % url
    with contextlib.closing(open_compressed_mirror_url(listurl)) as pkglist:
        for para in deb822.Packages.iter_paragraphs(pkglist):
            name = para["Package"]
            # Replace the known entry unless it is strictly newer.
            if name not in pkgs or \
                    version_compare(pkgs[name]["version"],
                                    para["Version"]) <= 0:
                entry = {
                    "version": para["Version"],
                    "filename": "%s/%s" % (url, para["Filename"]),
                }
                if addhash:
                    entry["sha256hash"] = para["SHA256"]
                pkgs[name] = entry
41
def process_file(pkgs, filename):
    """Record a single .deb file in pkgs.

    Package name and version are taken from the base name of filename, which
    has to look like name_version_arch.deb.  The version part is url-unquoted
    before use.  An entry already present in pkgs is kept only when its
    version compares strictly greater than the one of this file.

    @param pkgs: dict mapping package names to entry dicts, updated in place
    @param filename: path of the .deb file
    @raises ValueError: if the base name does not have the expected form
    """
    base = os.path.basename(filename)
    if not base.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = base.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = unquote(fields[1])
    # Replace the known entry unless it is strictly newer.
    if name not in pkgs or \
            version_compare(pkgs[name]["version"], version) <= 0:
        pkgs[name] = {"version": version, "filename": filename}
54
def process_dir(pkgs, d):
    """Record every package file found directly inside directory d in pkgs.

    Each directory entry is handed to process_file.  Entries it rejects with
    a ValueError are silently skipped.

    @param pkgs: dict mapping package names to entry dicts, updated in place
    @param d: path of the directory to scan (not recursive)
    """
    for path in (os.path.join(d, entry) for entry in os.listdir(d)):
        try:
            process_file(pkgs, path)
        except ValueError:
            # Not a name_version_arch.deb file; ignore it.
            continue
61
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package and store its output in outpath.

    importpkg.py is started with the current interpreter and is looked up
    relative to the working directory.  If pkgdict has a "sha256hash" key, it
    is passed via -H.  A filename that is a http, https, ftp or file url is
    passed as an argument; any other filename is opened and connected to the
    standard input of the child.

    @param name: package name, only used for progress output
    @param pkgdict: entry dict with "filename" and optionally "sha256hash"
    @param outpath: path of the file receiving the output of importpkg.py
    @raises subprocess.CalledProcessError: if importpkg.py exits non-zero
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = [sys.executable, "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd += ["-H", pkgdict["sha256hash"]]
    is_url = filename.startswith(("http://", "https://", "ftp://", "file://"))
    if is_url:
        # The child fetches the url itself.
        importcmd.append(filename)
        with open(outpath, "w") as outp:
            subprocess.check_call(importcmd, stdout=outp, close_fds=True)
    else:
        # Local files are fed to the child on its standard input.
        with open(filename) as inp, open(outpath, "w") as outp:
            subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                  close_fds=True)
    print("preprocessed %s" % name)
78
def main():
    """Import all packages found in the given files, directories or mirrors.

    The command line selects the sqlite3 database and the package sources.
    Sources are scanned into a dict of candidate packages.  With --new,
    candidates whose version is not newer than the one recorded in the
    database are dropped.  Every remaining candidate is preprocessed by
    process_pkg on a worker thread, writing into a temporary directory, and
    each finished output is passed to readyaml from this thread.  With
    --prune, packages recorded in the database but absent from all sources
    are deleted afterwards.  The temporary directory is removed at the end
    unless it still contains output of packages that failed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--new", action="store_true",
                        help="avoid reimporting same versions")
    parser.add_argument("-p", "--prune", action="store_true",
                        help="prune old packages")
    parser.add_argument("-d", "--database", action="store",
                        default="test.sqlite3",
                        help="path to the sqlite3 database file")
    parser.add_argument("--noverify", action="store_true",
                        help="do not verify binary package hashes")
    parser.add_argument("files", nargs='+',
                        help="files or directories or repository urls")
    args = parser.parse_args()
    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    # closing() releases the database connection on every exit path.
    with contextlib.closing(sqlite3.connect(args.database)) as db:
        cur = db.cursor()
        cur.execute("PRAGMA foreign_keys = ON;")
        pkgs = {}
        for d in args.files:
            print("processing %s" % d)
            if d.startswith(("http://", "https://", "ftp://", "file://")):
                process_http(pkgs, d, not args.noverify)
            elif os.path.isdir(d):
                process_dir(pkgs, d)
            else:
                process_file(pkgs, d)

        print("reading database")
        cur.execute("SELECT name, version FROM package;")
        knownpkgs = dict((row[0], row[1]) for row in cur.fetchall())
        # Snapshot of all names seen in the sources.  It is taken before
        # --new filtering so that pruning never removes up-to-date packages,
        # and iterating it allows deleting from pkgs inside the loop.
        distpkgs = set(pkgs.keys())
        if args.new:
            for name in distpkgs:
                if name in knownpkgs and version_compare(
                        pkgs[name]["version"], knownpkgs[name]) <= 0:
                    del pkgs[name]
        knownpkgs = set(knownpkgs)

        with concurrent.futures.ThreadPoolExecutor(
                multiprocessing.cpu_count()) as e:
            fs = {}
            for name, pkg in pkgs.items():
                outpath = os.path.join(tmpdir, name)
                fs[e.submit(process_pkg, name, pkg, outpath)] = name

            for f in concurrent.futures.as_completed(fs):
                name = fs[f]
                error = f.exception()
                if error is not None:
                    # Any output file is left behind so that the failure is
                    # reported when removing the temporary directory.
                    print("%s failed to import: %r" % (name, error))
                    continue
                inf = os.path.join(tmpdir, name)
                print("sqlimporting %s" % name)
                with open(inf) as inp:
                    try:
                        readyaml(db, inp)
                    except Exception as exc:
                        # Best effort: report and go on with the next one.
                        print("%s failed sql with exception %r" %
                              (name, exc))
                    else:
                        os.unlink(inf)

        if args.prune:
            delpkgs = knownpkgs - distpkgs
            print("clearing packages %s" % " ".join(delpkgs))
            cur.executemany("DELETE FROM package WHERE name = ?;",
                            ((pkg,) for pkg in delpkgs))
            # Tables content, dependency and sharing will also be pruned
            # due to ON DELETE CASCADE clauses.
            db.commit()
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))
155
# Only run the importer when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()