move from deprecated optparse to argparse
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
2 """This script takes a directory or an http base url to a mirror and imports all
3 packages contained. It has rather strong assumptions on the working directory.
4 """
5
6 import argparse
7 import gzip
8 import errno
9 import io
10 import multiprocessing
11 import os
12 import sqlite3
13 import subprocess
14 import sys
15 import tempfile
16 try:
17     from urllib.parse import unquote
18 except ImportError:
19     from urllib import unquote
20 try:
21     from urllib.request import urlopen
22 except ImportError:
23     from urllib import urlopen
24
25 import concurrent.futures
26 from debian import deb822
27 from debian.debian_support import version_compare
28
29 from readyaml import readyaml
30
def process_http(pkgs, url):
    """Record the packages listed in a mirror's Packages.gz index in pkgs.

    The index is fetched from url + "/dists/sid/main/binary-amd64/Packages.gz".
    For every paragraph an entry pkgs[name] with the keys "version",
    "filename" (url joined with the paragraph's Filename field) and
    "sha256hash" is stored, unless pkgs already holds a strictly newer
    version of that package. pkgs is modified in place.
    """
    response = urlopen(url + "/dists/sid/main/binary-amd64/Packages.gz")
    try:
        compressed = response.read()
    finally:
        # Release the connection as soon as the index has been downloaded
        # instead of leaving it to the garbage collector.
        response.close()
    pkglist = gzip.GzipFile(fileobj=io.BytesIO(compressed)).read()
    pkglist = io.BytesIO(pkglist)
    for pkg in deb822.Packages.iter_paragraphs(pkglist):
        name = pkg["Package"]
        if name in pkgs and \
                version_compare(pkgs[name]["version"], pkg["Version"]) > 0:
            # A newer version was already recorded from another source.
            continue
        pkgs[name] = dict(version=pkg["Version"],
                          filename="%s/%s" % (url, pkg["Filename"]),
                          sha256hash=pkg["SHA256"])
44
def process_file(pkgs, filename):
    """Record a single name_version_arch.deb file in pkgs.

    The package name and the (url-unquoted) version are taken from the base
    name of filename. The entry pkgs[name] is set to a dict with the keys
    "version" and "filename" unless pkgs already holds a strictly newer
    version of that package. pkgs is modified in place.

    Raises ValueError if the base name does not end in .deb or does not
    consist of exactly three underscore-separated parts.
    """
    basename = os.path.basename(filename)
    if not basename.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = basename.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = unquote(fields[1])
    # Store the file unless an already recorded version is strictly newer.
    if name not in pkgs or \
            version_compare(pkgs[name]["version"], version) <= 0:
        pkgs[name] = {"version": version, "filename": filename}
57
def process_dir(pkgs, d):
    """Record every package file found directly inside directory d in pkgs.

    Each directory entry is handed to process_file; entries for which it
    raises ValueError are skipped silently.
    """
    for path in (os.path.join(d, entry) for entry in os.listdir(d)):
        try:
            process_file(pkgs, path)
        except ValueError:
            # Not something process_file accepts; ignore and move on.
            continue
64
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package and write its output to outpath.

    name is only used in progress messages. pkgdict must contain a
    "filename" entry (a local path or a URL) and may contain a "sha256hash"
    entry, which is passed to importpkg.py via -H. URLs are downloaded with
    curl and piped into importpkg.py; local files are fed to it directly.

    Raises ValueError if importpkg.py or curl fails for a URL, and
    subprocess.CalledProcessError if importpkg.py fails for a local file.
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = [sys.executable, "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        with open(outpath, "w") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                try:
                    imp = subprocess.Popen(importcmd, stdin=dl.stdout,
                                           stdout=outp, close_fds=True)
                finally:
                    # Drop our copy of the pipe's read end. Otherwise curl
                    # would block forever on a full pipe if importpkg.py
                    # exits before consuming all of its input.
                    dl.stdout.close()
                imp_status = imp.wait()
            finally:
                # Always reap curl, even when importpkg.py failed.
                dl_status = dl.wait()
            if imp_status:
                raise ValueError("importpkg failed")
            if dl_status:
                raise ValueError("curl failed")
    else:
        # The package is binary data; only its file descriptor is used.
        with open(filename, "rb") as inp:
            with open(outpath, "w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)
87
def main():
    """Import the newest version of every given package into the database.

    The positional arguments are package files, directories or repository
    urls. Each selected package is preprocessed by process_pkg into a file
    inside a fresh temporary directory and then loaded into the sqlite3
    database with readyaml. With --new, packages whose version is not newer
    than the one in the database are skipped. With --prune, packages that
    are in the database but in none of the given sources are deleted.

    The temporary directory is removed at the end unless it still contains
    output of packages that failed to import.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--new", action="store_true",
                        help="avoid reimporting same versions")
    parser.add_argument("-p", "--prune", action="store_true",
                        help="prune old packages")
    parser.add_argument("-d", "--database", action="store",
                        default="test.sqlite3",
                        help="path to the sqlite3 database file")
    parser.add_argument("files", nargs='+',
                        help="files or directories or repository urls")
    args = parser.parse_args()
    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    db = sqlite3.connect(args.database)
    try:
        cur = db.cursor()
        cur.execute("PRAGMA foreign_keys = ON;")
        pkgs = {}
        for d in args.files:
            print("processing %s" % d)
            if d.startswith(("http://", "https://", "ftp://", "file://")):
                process_http(pkgs, d)
            elif os.path.isdir(d):
                process_dir(pkgs, d)
            else:
                process_file(pkgs, d)

        print("reading database")
        cur.execute("SELECT name, version FROM package;")
        knownpkgs = dict((row[0], row[1]) for row in cur.fetchall())
        # Snapshot of all package names seen in the sources. It is taken
        # before --new filtering so that pruning still sees skipped packages,
        # and it lets us delete from pkgs while iterating.
        distpkgs = set(pkgs)
        if args.new:
            for name in distpkgs:
                if name in knownpkgs and \
                        version_compare(pkgs[name]["version"],
                                        knownpkgs[name]) <= 0:
                    del pkgs[name]
        knownpkgs = set(knownpkgs)

        workers = multiprocessing.cpu_count()
        with concurrent.futures.ThreadPoolExecutor(workers) as e:
            fs = {}
            for name, pkg in pkgs.items():
                outpath = os.path.join(tmpdir, name)
                fs[e.submit(process_pkg, name, pkg, outpath)] = name

            # Database imports happen here, one at a time, as the
            # preprocessing jobs finish.
            for f in concurrent.futures.as_completed(fs):
                name = fs[f]
                failure = f.exception()
                if failure is not None:
                    print("%s failed to import: %r" % (name, failure))
                    continue
                inf = os.path.join(tmpdir, name)
                print("sqlimporting %s" % name)
                with open(inf) as inp:
                    try:
                        readyaml(db, inp)
                    except Exception as exc:
                        # Keep the preprocessed file for inspection.
                        # NOTE(review): whether readyaml leaves an open
                        # transaction behind on failure is not visible
                        # here -- confirm.
                        print("%s failed sql with exception %r" % (name, exc))
                    else:
                        os.unlink(inf)

        if args.prune:
            delpkgs = knownpkgs - distpkgs
            print("clearing packages %s" % " ".join(delpkgs))
            cur.executemany("DELETE FROM package WHERE name = ?;",
                            ((pkg,) for pkg in delpkgs))
            # Tables content, dependency and sharing will also be pruned
            # due to ON DELETE CASCADE clauses.
            db.commit()
    finally:
        # Release the database handle even if an import step raised.
        db.close()
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))
162
# Run the importer only when this file is executed as a script.
if __name__ == "__main__":
    main()