autoimport: avoid hard coded temporary directory
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
"""This script takes a directory or an http base url to a mirror and imports all
packages contained. It has rather strong assumptions on the working directory.
"""
5
6 import gzip
7 import errno
8 import io
9 import multiprocessing
10 import optparse
11 import os
12 import sqlite3
13 import subprocess
14 import tempfile
15 import urllib
16
17 import concurrent.futures
18 from debian import deb822
19 from debian.debian_support import version_compare
20
21 from readyaml import readyaml
22
def process_http(pkgs, url):
    """Merge the sid/main/binary-amd64 package index of a mirror into pkgs.

    The gzip-compressed Packages file below url is downloaded and every
    paragraph in it is recorded in pkgs (a dict keyed by package name) with
    its version, download url and SHA256 hash. An entry that is already in
    pkgs is kept only when its version is strictly newer than the listed one.
    """
    index_url = url + "/dists/sid/main/binary-amd64/Packages.gz"
    compressed = io.BytesIO(urllib.urlopen(index_url).read())
    index = io.BytesIO(gzip.GzipFile(fileobj=compressed).read())
    for para in deb822.Packages.iter_paragraphs(index):
        name = para["Package"]
        version = para["Version"]
        newer_known = name in pkgs and \
            version_compare(pkgs[name]["version"], version) > 0
        if not newer_known:
            pkgs[name] = {
                "version": version,
                "filename": "%s/%s" % (url, para["Filename"]),
                "sha256hash": para["SHA256"],
            }
36
def process_file(pkgs, filename):
    """Record a single .deb file in pkgs unless a newer version is known.

    The basename of filename has to be of the form name_version_arch.deb;
    the version part is url-unquoted before it is compared and stored.
    Raises ValueError when the basename does not have that form.
    """
    base = os.path.basename(filename)
    if not base.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = base.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = urllib.unquote(fields[1])
    superseded = name in pkgs and \
        version_compare(pkgs[name]["version"], version) > 0
    if not superseded:
        pkgs[name] = {"version": version, "filename": filename}
49
def process_dir(pkgs, d):
    """Feed every entry of directory d to process_file.

    Entries that process_file rejects with a ValueError are skipped.
    """
    for entry in os.listdir(d):
        path = os.path.join(d, entry)
        try:
            process_file(pkgs, path)
        except ValueError:
            # Rejected by process_file; ignore and go on with the next entry.
            continue
56
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package and write its output to outpath.

    name is only used for the progress message. pkgdict must contain
    "filename" (a local path or an http:// url) and may contain "sha256hash",
    which is passed to importpkg.py using -H. A package given as http:// url
    is streamed from curl into importpkg.py.

    Raises ValueError when importpkg.py or curl exit with a non-zero status
    for a http:// package and subprocess.CalledProcessError when importpkg.py
    fails for a local file.
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = ["python", "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith("http://"):
        with open(outpath, "w") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                imp = subprocess.Popen(importcmd, stdin=dl.stdout, stdout=outp,
                                       close_fds=True)
            finally:
                # Drop our copy of the read end. Otherwise curl never sees
                # SIGPIPE when the importer exits early and waiting for it
                # could block forever.
                dl.stdout.close()
            # Always reap both children before reporting a failure so that no
            # process is left behind.
            imp_status = imp.wait()
            dl_status = dl.wait()
            if imp_status:
                raise ValueError("importpkg failed")
            if dl_status:
                raise ValueError("curl failed")
    else:
        with open(filename) as inp:
            with open(outpath, "w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)
79
def main():
    """Import the packages found in the sources given on the command line.

    Each positional argument is either an http:// mirror url, a directory
    containing .deb files or a single .deb file. The selected packages are
    preprocessed by process_pkg in a thread pool, each into its own file in
    a temporary directory, and the results are loaded into the sqlite3
    database test.sqlite3 using readyaml.

    With --new, packages whose version is not newer than the one stored in
    the database are skipped. With --prune, packages stored in the database
    but absent from the given sources are deleted.
    """
    parser = optparse.OptionParser()
    parser.add_option("-n", "--new", action="store_true",
                      help="avoid reimporting same versions")
    parser.add_option("-p", "--prune", action="store_true",
                      help="prune old packages")
    options, args = parser.parse_args()
    db = sqlite3.connect("test.sqlite3")
    cur = db.cursor()
    cur.execute("PRAGMA foreign_keys = ON;")
    pkgs = {}
    for d in args:
        print("processing %s" % d)
        if d.startswith("http://"):
            process_http(pkgs, d)
        elif os.path.isdir(d):
            process_dir(pkgs, d)
        else:
            process_file(pkgs, d)

    print("reading database")
    cur.execute("SELECT name, version FROM package;")
    knownpkgs = dict((row[0], row[1]) for row in cur.fetchall())
    # Remember everything the sources offer before --new filters pkgs, since
    # pruning must not remove packages that were merely skipped.
    distpkgs = set(pkgs.keys())
    if options.new:
        for name in distpkgs:
            if name in knownpkgs and version_compare(pkgs[name]["version"],
                    knownpkgs[name]) <= 0:
                del pkgs[name]
    knownpkgs = set(knownpkgs)

    # The temporary directory and the executor are created only now, so that
    # a failure while scanning the sources or reading the database does not
    # leave an empty temporary directory behind.
    tmpdir = tempfile.mkdtemp(prefix=b"debian-dedup")
    e = concurrent.futures.ThreadPoolExecutor(multiprocessing.cpu_count())
    with e:
        fs = {}
        for name, pkg in pkgs.items():
            outpath = os.path.join(tmpdir, name)
            fs[e.submit(process_pkg, name, pkg, outpath)] = name

        for f in concurrent.futures.as_completed(fs.keys()):
            name = fs[f]
            if f.exception():
                print("%s failed to import: %r" % (name, f.exception()))
                continue
            inf = os.path.join(tmpdir, name)
            print("sqlimporting %s" % name)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as exc:
                    # Leave the file in place; it is listed when the
                    # temporary directory is kept at the end.
                    print("%s failed sql with exception %r" % (name, exc))
                else:
                    os.unlink(inf)

    if options.prune:
        # Note that without any source on the command line distpkgs is empty
        # and therefore every package in the database is deleted.
        delpkgs = knownpkgs - distpkgs
        print("clearing packages %s" % " ".join(delpkgs))
        cur.executemany("DELETE FROM package WHERE name = ?;",
                        ((pkg,) for pkg in delpkgs))
        # Tables content, dependency and sharing will also be pruned
        # due to ON DELETE CASCADE clauses.
        db.commit()
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        # A non-empty directory holds the output of packages that could not
        # be imported; keep it and report it instead of failing.
        if err.errno != errno.ENOTEMPTY:
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))
149
# Run the import only when executed as a script, not when imported.
if __name__ == "__main__":
    main()