decouple a function decompress out of decompress_tar
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
2 """This script takes a directory or an HTTP base URL to a mirror and imports all
3 packages contained. It has rather strong assumptions on the working directory.
4 """
5
6 import gzip
7 import errno
8 import io
9 import multiprocessing
10 import optparse
11 import os
12 import sqlite3
13 import subprocess
14 import tempfile
15 try:
16     from urllib.parse import unquote
17 except ImportError:
18     from urllib import unquote
19 try:
20     from urllib.request import urlopen
21 except ImportError:
22     from urllib import urlopen
23
24 import concurrent.futures
25 from debian import deb822
26 from debian.debian_support import version_compare
27
28 from readyaml import readyaml
29
def process_http(pkgs, url):
    """Merge a mirror's sid/main/binary-amd64 package index into pkgs.

    Downloads <url>/dists/sid/main/binary-amd64/Packages.gz, decompresses it
    in memory and walks its paragraphs. For each package the entry in pkgs
    is set to a dict with the keys "version", "filename" (the mirror url
    joined with the paragraph's Filename field) and "sha256hash". An entry
    already present in pkgs is kept when version_compare ranks its version
    strictly above the one from the index.

    pkgs -- dict mapping package names to package dicts; modified in place
    url -- base url of the mirror
    """
    # Function-local import so this block does not depend on the file's
    # import section being changed.
    from contextlib import closing

    index_url = url + "/dists/sid/main/binary-amd64/Packages.gz"
    # closing() works for both the Python 2 and the Python 3 flavour of
    # urlopen and releases the connection even if read() raises.
    with closing(urlopen(index_url)) as response:
        pkglist = response.read()
    pkglist = gzip.GzipFile(fileobj=io.BytesIO(pkglist)).read()
    pkglist = io.BytesIO(pkglist)
    pkglist = deb822.Packages.iter_paragraphs(pkglist)
    for pkg in pkglist:
        name = pkg["Package"]
        if name in pkgs and \
                version_compare(pkgs[name]["version"], pkg["Version"]) > 0:
            # What we already have is newer than the mirror's copy.
            continue
        pkgs[name] = dict(version=pkg["Version"],
                          filename="%s/%s" % (url, pkg["Filename"]),
                          sha256hash=pkg["SHA256"])
43
def process_file(pkgs, filename):
    """Record a single .deb file in pkgs under its package name.

    Name and version are read from the base name of filename, which has to
    look like name_version_arch.deb; the version part is URL-unquoted. An
    entry already in pkgs is left alone when version_compare ranks its
    version strictly above this file's version. Otherwise the entry becomes
    a dict with the keys "version" and "filename".

    Raises ValueError if the base name does not end in .deb or does not
    consist of exactly three underscore separated parts.
    """
    basename = os.path.basename(filename)
    if not basename.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = basename.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = unquote(fields[1])
    if name in pkgs:
        if version_compare(pkgs[name]["version"], version) > 0:
            return
    pkgs[name] = {"version": version, "filename": filename}
56
def process_dir(pkgs, d):
    """Feed every entry of directory d to process_file.

    Entries that process_file rejects with ValueError are skipped silently;
    any other exception propagates. The listing is not recursive.
    """
    paths = (os.path.join(d, entry) for entry in os.listdir(d))
    for path in paths:
        try:
            process_file(pkgs, path)
        except ValueError:
            # process_file refused this entry; move on to the next one.
            continue
63
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package and store its output in outpath.

    name -- package name, only used for progress messages
    pkgdict -- dict with a "filename" key (local path or url) and an
        optional "sha256hash" key, which is handed to importpkg.py via -H
    outpath -- file that receives the standard output of importpkg.py

    If the filename starts with http://, https://, ftp:// or file://, it is
    fetched with curl and piped into importpkg.py. Otherwise the local file
    is opened and used as standard input.

    Raises ValueError if importpkg.py or curl exit non-zero in the url case
    and subprocess.CalledProcessError if importpkg.py fails in the local
    file case.
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = ["python", "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        with open(outpath, "w") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                imp = subprocess.Popen(importcmd, stdin=dl.stdout, stdout=outp,
                                       close_fds=True)
                if imp.wait():
                    raise ValueError("importpkg failed")
                if dl.wait():
                    raise ValueError("curl failed")
            finally:
                # Never leave the downloader behind: release our copy of the
                # pipe and, if curl is still running (the importer failed or
                # could not be started), stop and reap it.
                dl.stdout.close()
                if dl.poll() is None:
                    dl.kill()
                    dl.wait()
    else:
        with open(filename) as inp:
            with open(outpath, "w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)
86
def main():
    """Command line entry point.

    Collects packages from the given urls, directories and .deb files,
    runs process_pkg on each of them in a thread pool and feeds the results
    to readyaml. With --new, packages whose version is not newer than the
    one recorded in the database are skipped. With --prune, packages known
    to the database but absent from the given sources are deleted.
    Intermediate files live in a temporary directory, which is kept if any
    of them remain after the run.
    """
    parser = optparse.OptionParser()
    parser.add_option("-n", "--new", action="store_true",
                      help="avoid reimporting same versions")
    parser.add_option("-p", "--prune", action="store_true",
                      help="prune packages old packages")
    parser.add_option("-d", "--database", action="store",
                      default="test.sqlite3",
                      help="path to the sqlite3 database file")
    options, args = parser.parse_args()

    tmpdir = tempfile.mkdtemp(prefix="debian-dedup")
    db = sqlite3.connect(options.database)
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    executor = concurrent.futures.ThreadPoolExecutor(
        multiprocessing.cpu_count())

    # Gather the candidate packages from every source on the command line.
    pkgs = {}
    url_schemes = ("http://", "https://", "ftp://", "file://")
    for source in args:
        print("processing %s" % source)
        if source.startswith(url_schemes):
            process_http(pkgs, source)
        elif os.path.isdir(source):
            process_dir(pkgs, source)
        else:
            process_file(pkgs, source)

    print("reading database")
    cur.execute("SELECT name, version FROM package;")
    known_versions = {}
    for row in cur.fetchall():
        known_versions[row[0]] = row[1]
    # Snapshot the names now: pruning below must consider packages that
    # --new drops from pkgs as still being part of the distribution.
    distpkgs = set(pkgs)
    if options.new:
        for name in distpkgs:
            if name not in known_versions:
                continue
            if version_compare(pkgs[name]["version"],
                               known_versions[name]) <= 0:
                del pkgs[name]
    knownpkgs = set(known_versions)

    with executor:
        pending = {}
        for name, pkg in pkgs.items():
            future = executor.submit(process_pkg, name, pkg,
                                     os.path.join(tmpdir, name))
            pending[future] = name

        for future in concurrent.futures.as_completed(pending):
            name = pending[future]
            error = future.exception()
            if error:
                print("%s failed to import: %r" % (name, error))
                continue
            inf = os.path.join(tmpdir, name)
            print("sqlimporting %s" % name)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    print("%s failed sql with exception %r" % (name, exc))
                else:
                    # Only successfully imported files are removed, the
                    # rest stays behind for inspection.
                    os.unlink(inf)

    if options.prune:
        delpkgs = knownpkgs - distpkgs
        print("clearing packages %s" % " ".join(delpkgs))
        # Tables content, dependency and sharing will also be pruned
        # due to ON DELETE CASCADE clauses.
        cur.executemany("DELETE FROM package WHERE name = ?;",
                        [(pkg,) for pkg in delpkgs])
        db.commit()

    try:
        os.rmdir(tmpdir)
    except OSError as err:
        if err.errno == errno.ENOTEMPTY:
            leftovers = " ".join(os.listdir(tmpdir))
            print("keeping temporary directory %s due to failed packages %s" %
                  (tmpdir, leftovers))
        else:
            raise
159
# Only run the importer when executed as a script, not when imported.
if __name__ == "__main__":
    main()