drop support for Python 2.x
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python3
2 """This script takes a directory or an HTTP base URL of a mirror and imports all
3 packages contained. It makes rather strong assumptions about the working directory.
4 """
5
6 import argparse
7 import contextlib
8 import errno
9 import multiprocessing
10 import os
11 import sqlite3
12 import subprocess
13 import sys
14 import tempfile
15 import urllib.parse
16 import concurrent.futures
17 from debian import deb822
18 from debian.debian_support import version_compare
19
20 from dedup.utils import open_compressed_mirror_url
21
22 from readyaml import readyaml
23
def process_http(pkgs, url, addhash=True):
    """Merge the packages of a mirror's sid/main/binary-amd64 index into *pkgs*.

    *url* is the mirror base URL.  For every paragraph of the Packages index
    the entry for that package in *pkgs* is (re)placed unless *pkgs* already
    holds a strictly newer version.  Stored values are dicts with "version"
    and "filename" (the index's Filename field appended to *url*); when
    *addhash* is true the index's SHA256 field is stored as "sha256hash".
    """
    index_url = url + "/dists/sid/main/binary-amd64/Packages"
    with contextlib.closing(open_compressed_mirror_url(index_url)) as index:
        for para in deb822.Packages.iter_paragraphs(index):
            pkgname = para["Package"]
            candidate = para["Version"]
            already_newer = pkgname in pkgs and \
                version_compare(pkgs[pkgname]["version"], candidate) > 0
            if already_newer:
                continue
            entry = {"version": candidate,
                     "filename": "%s/%s" % (url, para["Filename"])}
            if addhash:
                entry["sha256hash"] = para["SHA256"]
            pkgs[pkgname] = entry
37
def process_file(pkgs, filename):
    """Record one .deb file in *pkgs* unless a newer version is already known.

    The package name and version are parsed from the basename, which must
    look like name_version_arch.deb; the version field is URL-unquoted
    (e.g. "%3a" becomes ":").  The stored value is a dict with "version"
    and "filename" (the path exactly as given).

    Raises:
        ValueError: if the basename does not end in ".deb" or does not
            consist of exactly three "_"-separated fields.
    """
    basename = os.path.basename(filename)
    if not basename.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = basename.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    pkgname, quoted_version = fields[0], fields[1]
    version = urllib.parse.unquote(quoted_version)
    # Store unless an existing entry is strictly newer.
    if pkgname not in pkgs or \
            version_compare(pkgs[pkgname]["version"], version) <= 0:
        pkgs[pkgname] = {"version": version, "filename": filename}
50
def process_dir(pkgs, d):
    """Record the .deb files found directly in directory *d* into *pkgs*.

    Only the top level of *d* is listed.  Any entry for which process_file
    raises ValueError (e.g. a name not of the form name_version_arch.deb)
    is skipped.
    """
    paths = (os.path.join(d, entry) for entry in os.listdir(d))
    for path in paths:
        try:
            process_file(pkgs, path)
        except ValueError:
            # Not something we can import; skip it.
            continue
57
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py for one package and write its stdout to *outpath*.

    *pkgdict* provides "filename" and optionally "sha256hash", which is
    passed to importpkg.py via -H.  A filename with an http, https, ftp or
    file scheme is passed to importpkg.py as an argument; anything else is
    opened locally and fed to importpkg.py on stdin.

    Raises:
        subprocess.CalledProcessError: if importpkg.py exits non-zero.
        OSError: if the local input file or *outpath* cannot be opened.
    """
    source = pkgdict["filename"]
    print("importing %s" % source)
    cmd = [sys.executable, "importpkg.py"]
    if "sha256hash" in pkgdict:
        cmd += ["-H", pkgdict["sha256hash"]]
    is_url = source.startswith(("http://", "https://", "ftp://", "file://"))
    if not is_url:
        # The input is opened before the output, so a missing input file
        # does not leave an empty output file behind.
        with open(source) as inp, open(outpath, "w") as outp:
            subprocess.check_call(cmd, stdin=inp, stdout=outp, close_fds=True)
    else:
        cmd.append(source)
        with open(outpath, "w") as outp:
            subprocess.check_call(cmd, stdout=outp, close_fds=True)
    print("preprocessed %s" % name)
74
def _collect_packages(sources, verify):
    """Return the newest available version of every package in *sources*.

    Each source is either a mirror URL (http, https, ftp or file scheme), a
    directory containing .deb files, or a single .deb file.  The result maps
    package names to dicts with at least "version" and "filename"; mirror
    entries also get "sha256hash" when *verify* is true.
    """
    pkgs = {}
    for source in sources:
        print("processing %s" % source)
        if source.startswith(("http://", "https://", "ftp://", "file://")):
            process_http(pkgs, source, verify)
        elif os.path.isdir(source):
            process_dir(pkgs, source)
        else:
            process_file(pkgs, source)
    return pkgs


def _drop_already_imported(pkgs, knownpkgs):
    """Remove from *pkgs* each package whose version is not newer than the
    version recorded for it in *knownpkgs* (a name -> version mapping)."""
    # Iterate over a snapshot of the names because entries get deleted.
    for name in list(pkgs):
        if name in knownpkgs and \
                version_compare(pkgs[name]["version"], knownpkgs[name]) <= 0:
            del pkgs[name]


def _import_packages(db, pkgs, tmpdir):
    """Preprocess *pkgs* in parallel and load the results into *db*.

    process_pkg runs in a thread pool with one worker per CPU and writes one
    file per package into *tmpdir*.  Each finished file is then loaded with
    readyaml on the calling thread and deleted only after a successful load,
    so files left in *tmpdir* belong to packages that failed.  Failures are
    reported and otherwise skipped.
    """
    with concurrent.futures.ThreadPoolExecutor(
            multiprocessing.cpu_count()) as executor:
        futures = {}
        for name, pkg in pkgs.items():
            outpath = os.path.join(tmpdir, name)
            futures[executor.submit(process_pkg, name, pkg, outpath)] = name

        for future in concurrent.futures.as_completed(futures):
            name = futures[future]
            failure = future.exception()
            if failure is not None:
                print("%s failed to import: %r" % (name, failure))
                continue
            inf = os.path.join(tmpdir, name)
            print("sqlimporting %s" % name)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    # Report and continue; the leftover file marks the
                    # failed package.
                    print("%s failed sql with exception %r" % (name, exc))
                else:
                    os.unlink(inf)


def _prune_packages(db, cur, stale):
    """Delete the packages named in *stale* from the database and commit."""
    print("clearing packages %s" % " ".join(stale))
    cur.executemany("DELETE FROM package WHERE name = ?;",
                    ((pkg,) for pkg in stale))
    # Tables content, dependency and sharing will also be pruned
    # due to ON DELETE CASCADE clauses.
    db.commit()


def _remove_tmpdir(tmpdir):
    """Remove *tmpdir* if it is empty; otherwise keep it and list its files.

    Raises:
        OSError: for any rmdir failure other than the directory being
            non-empty.
    """
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))


def main():
    """Command-line entry point: import Debian packages into the database.

    Collects the newest version of every package from the given files,
    directories or mirror URLs.  It preprocesses them in parallel with
    importpkg.py and loads the results with readyaml.  With --prune it also
    deletes database packages that none of the sources offer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--new", action="store_true",
                        help="avoid reimporting same versions")
    parser.add_argument("-p", "--prune", action="store_true",
                        help="prune packages missing from the given sources")
    parser.add_argument("-d", "--database", action="store",
                        default="test.sqlite3",
                        help="path to the sqlite3 database file")
    parser.add_argument("--noverify", action="store_true",
                        help="do not verify binary package hashes")
    parser.add_argument("files", nargs='+',
                        help="files or directories or repository urls")
    args = parser.parse_args()
    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    # closing() ensures the connection is closed even on error.  close()
    # does not commit, so anything not committed is discarded.
    with contextlib.closing(sqlite3.connect(args.database)) as db:
        cur = db.cursor()
        # Required for the ON DELETE CASCADE clauses used when pruning.
        cur.execute("PRAGMA foreign_keys = ON;")
        pkgs = _collect_packages(args.files, not args.noverify)

        print("reading database")
        cur.execute("SELECT name, version FROM package;")
        knownpkgs = {name: version for name, version in cur.fetchall()}
        # Snapshot every package the sources offer before --new filtering,
        # so --prune never deletes packages that were merely skipped.
        distpkgs = set(pkgs)
        if args.new:
            _drop_already_imported(pkgs, knownpkgs)

        _import_packages(db, pkgs, tmpdir)

        if args.prune:
            _prune_packages(db, cur, set(knownpkgs) - distpkgs)
    _remove_tmpdir(tmpdir)


if __name__ == "__main__":
    main()