3076f6ef44396427baffb5d64772ab6c7455afa2
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
"""This script takes a directory or a http base url to a mirror and imports all
3 packages contained. It has rather strong assumptions on the working directory.
4 """
5
6 import gzip
7 import errno
8 import io
9 import multiprocessing
10 import optparse
11 import os
12 import sqlite3
13 import subprocess
14 import tempfile
15 import urllib
16 try:
17     from urllib.parse import unquote
18 except ImportError:
19     from urllib import unquote
20
21 import concurrent.futures
22 from debian import deb822
23 from debian.debian_support import version_compare
24
25 from readyaml import readyaml
26
def process_http(pkgs, url):
    """Merge the package list of a mirror into pkgs.

    Downloads and decompresses dists/sid/main/binary-amd64/Packages.gz below
    the base url and records each listed package in pkgs, a dict mapping a
    package name to a dict with the keys "version", "filename" (the full
    download url) and "sha256hash". An existing entry is only kept when its
    version is strictly newer than the listed one.
    """
    # urllib.urlopen only exists on Python 2; Python 3 provides the function
    # as urllib.request.urlopen. Resolve it locally so that this function
    # works on both.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib import urlopen

    response = urlopen(url + "/dists/sid/main/binary-amd64/Packages.gz")
    try:
        compressed = response.read()
    finally:
        # Release the connection even if reading fails.
        response.close()
    pkglist = gzip.GzipFile(fileobj=io.BytesIO(compressed)).read()
    pkglist = io.BytesIO(pkglist)
    pkglist = deb822.Packages.iter_paragraphs(pkglist)
    for pkg in pkglist:
        name = pkg["Package"]
        if name in pkgs and \
                version_compare(pkgs[name]["version"], pkg["Version"]) > 0:
            # A strictly newer version was already recorded.
            continue
        pkgs[name] = dict(version=pkg["Version"],
                          filename="%s/%s" % (url, pkg["Filename"]),
                          sha256hash=pkg["SHA256"])
40
def process_file(pkgs, filename):
    """Record a single .deb file in pkgs.

    The base name of filename must have the form name_version_arch.deb,
    otherwise a ValueError is raised. The version part is url-unquoted. The
    entry pkgs[name] is set to a dict with the keys "version" and "filename"
    unless pkgs already holds a strictly newer version of that package.
    """
    leaf = os.path.basename(filename)
    if not leaf.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = leaf.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = unquote(fields[1])
    # Only an already recorded, strictly newer version wins over this file.
    superseded = (name in pkgs and
                  version_compare(pkgs[name]["version"], version) > 0)
    if not superseded:
        pkgs[name] = {"version": version, "filename": filename}
53
def process_dir(pkgs, d):
    """Record every package file found directly inside the directory d.

    Each directory entry is handed to process_file. Entries it rejects with
    a ValueError are skipped.
    """
    candidates = [os.path.join(d, entry) for entry in os.listdir(d)]
    for path in candidates:
        try:
            process_file(pkgs, path)
        except ValueError:
            # Not a name_version_arch.deb file; ignore it.
            continue
60
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package and store its output in outpath.

    pkgdict must contain "filename", which is either a local path or a url.
    If it also contains "sha256hash", that value is passed to importpkg.py
    via -H. Packages given as url are streamed through curl into
    importpkg.py, local files are fed to it directly.

    Raises ValueError if importpkg.py or curl fail for a url and
    subprocess.CalledProcessError if importpkg.py fails for a local file.
    The output file is left in place on failure.
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = ["python", "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        # Only the file descriptor is handed to the child processes, so the
        # files are opened in binary mode.
        with open(outpath, "wb") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                try:
                    imp = subprocess.Popen(importcmd, stdin=dl.stdout,
                                           stdout=outp, close_fds=True)
                finally:
                    # Drop our copy of the read end, so that curl receives
                    # SIGPIPE instead of blocking when importpkg exits early.
                    dl.stdout.close()
                imp_status = imp.wait()
            finally:
                # Always reap curl, even when importpkg could not be run.
                dl_status = dl.wait()
            if imp_status:
                raise ValueError("importpkg failed")
            if dl_status:
                raise ValueError("curl failed")
    else:
        with open(filename, "rb") as inp:
            with open(outpath, "wb") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)
83
def main():
    """Import all packages from the given directories, files and mirror urls.

    Every positional argument is a mirror base url, a directory containing
    .deb files or a single .deb file. The newest version of each package is
    preprocessed by process_pkg in a thread pool, writing one file per
    package into a temporary directory, and then loaded into the sqlite3
    database using readyaml. With --new, packages whose version is not newer
    than the one in the database are skipped. With --prune, packages that are
    in the database but not among the processed arguments are deleted. The
    temporary directory is kept if any package failed to import.
    """
    parser = optparse.OptionParser()
    parser.add_option("-n", "--new", action="store_true",
                      help="avoid reimporting same versions")
    parser.add_option("-p", "--prune", action="store_true",
                      help="prune packages old packages")
    parser.add_option("-d", "--database", action="store",
                      default="test.sqlite3",
                      help="path to the sqlite3 database file")
    options, args = parser.parse_args()
    # The prefix must be a native string: with a bytes prefix mkdtemp returns
    # a bytes path on Python 3, which cannot be joined with the (text)
    # package names below.
    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    db = sqlite3.connect(options.database)
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    e = concurrent.futures.ThreadPoolExecutor(multiprocessing.cpu_count())
    pkgs = {}
    for d in args:
        print("processing %s" % d)
        if d.startswith(("http://", "https://", "ftp://", "file://")):
            process_http(pkgs, d)
        elif os.path.isdir(d):
            process_dir(pkgs, d)
        else:
            process_file(pkgs, d)

    print("reading database")
    cur.execute("SELECT name, version FROM package;")
    knownpkgs = dict((row[0], row[1]) for row in cur.fetchall())
    # Snapshot of all names seen in the arguments. It is taken before
    # filtering, because pruning must not remove packages skipped by --new.
    distpkgs = set(pkgs.keys())
    if options.new:
        # Iterating over the snapshot makes deleting from pkgs safe.
        for name in distpkgs:
            if name in knownpkgs and version_compare(pkgs[name]["version"],
                    knownpkgs[name]) <= 0:
                del pkgs[name]
    knownpkgs = set(knownpkgs)

    with e:
        # Maps each future to the name of the package it preprocesses.
        fs = {}
        for name, pkg in pkgs.items():
            outpath = os.path.join(tmpdir, name)
            fs[e.submit(process_pkg, name, pkg, outpath)] = name

        for f in concurrent.futures.as_completed(fs.keys()):
            name = fs[f]
            if f.exception():
                # The (partial) output stays in tmpdir and is reported below.
                print("%s failed to import: %r" % (name, f.exception()))
                continue
            inf = os.path.join(tmpdir, name)
            print("sqlimporting %s" % name)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    print("%s failed sql with exception %r" % (name, exc))
                else:
                    os.unlink(inf)

    if options.prune:
        delpkgs = knownpkgs - distpkgs
        print("clearing packages %s" % " ".join(delpkgs))
        cur.executemany("DELETE FROM package WHERE name = ?;",
                        ((pkg,) for pkg in delpkgs))
        # Tables content, dependency and sharing will also be pruned
        # due to ON DELETE CASCADE clauses.
        db.commit()
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        # A non-empty directory means that some packages failed; their
        # output files are kept for inspection.
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))
156
# Run the import only when this file is executed as a script.
if __name__ == "__main__":
    main()