dedup.arreader: remove trailing slash from ar members
[~helmut/debian-dedup.git] / autoimport.py
1 #!/usr/bin/python
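"""This script takes directories, individual .deb files or base URLs of a
mirror (http://, https://, ftp:// or file://) and imports all packages they
contain. It makes rather strong assumptions about the working directory, e.g.
that importpkg.py can be found there.
"""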
5
6 import gzip
7 import errno
8 import io
9 import multiprocessing
10 import optparse
11 import os
12 import sqlite3
13 import subprocess
14 import tempfile
15 import urllib
16
17 import concurrent.futures
18 from debian import deb822
19 from debian.debian_support import version_compare
20
21 from readyaml import readyaml
22
def process_http(pkgs, url, dist="sid", component="main", arch="amd64"):
    """Record every package listed in a Debian mirror's Packages index.

    Downloads ``<url>/dists/<dist>/<component>/binary-<arch>/Packages.gz`` and,
    for each paragraph in it, stores an entry in *pkgs* keyed by package name,
    unless *pkgs* already holds a strictly newer version of that name.

    :param pkgs: dict mapping package name to a dict with the keys
        ``version``, ``filename`` (download URL) and ``sha256hash``;
        updated in place.
    :param url: base URL of the mirror (the directory containing ``dists``).
    :param dist: suite to read; defaults to ``sid``.
    :param component: archive component; defaults to ``main``.
    :param arch: binary architecture; defaults to ``amd64``.
    """
    index_url = "%s/dists/%s/%s/binary-%s/Packages.gz" % (
        url, dist, component, arch)
    response = urllib.urlopen(index_url)
    try:
        compressed = response.read()
    finally:
        # The response holds a network connection; release it even when
        # reading fails.
        response.close()
    pkglist = gzip.GzipFile(fileobj=io.BytesIO(compressed)).read()
    for pkg in deb822.Packages.iter_paragraphs(io.BytesIO(pkglist)):
        name = pkg["Package"]
        if name in pkgs and \
                version_compare(pkgs[name]["version"], pkg["Version"]) > 0:
            continue  # a strictly newer version is already recorded
        pkgs[name] = dict(version=pkg["Version"],
                          filename="%s/%s" % (url, pkg["Filename"]),
                          sha256hash=pkg["SHA256"])
36
def process_file(pkgs, filename):
    """Register a single local ``name_version_arch.deb`` file in *pkgs*.

    The version is taken from the file name (URL-unquoted).  An existing entry
    for the same package name is replaced unless it carries a strictly newer
    version.

    :raises ValueError: if the base name does not end in ``.deb`` or does not
        split into exactly three ``_``-separated fields.
    """
    base = os.path.basename(filename)
    if not base.endswith(".deb"):
        raise ValueError("filename does not end in .deb")
    fields = base.split("_")
    if len(fields) != 3:
        raise ValueError("filename not in form name_version_arch.deb")
    name = fields[0]
    version = urllib.unquote(fields[1])
    # Store unless an already-known version compares strictly greater.
    if name not in pkgs or version_compare(pkgs[name]["version"], version) <= 0:
        pkgs[name] = {"version": version, "filename": filename}
49
def process_dir(pkgs, d):
    """Register every package file found directly inside directory *d*.

    Each entry is handed to process_file; entries it rejects with ValueError
    (for example anything not named ``name_version_arch.deb``) are skipped.
    """
    names = os.listdir(d)
    for entry in names:
        try:
            process_file(pkgs, os.path.join(d, entry))
        except ValueError:
            # Not a usable package file; ignore it deliberately.
            continue
56
def process_pkg(name, pkgdict, outpath):
    """Run importpkg.py on one package, writing its output to *outpath*.

    Remote packages (http/https/ftp/file URLs) are streamed through curl into
    importpkg.py; local files are fed to it directly.  When *pkgdict* carries
    a ``sha256hash`` it is passed to importpkg.py via ``-H``.

    :param name: package name, used only for progress output.
    :param pkgdict: dict with a ``filename`` key and optionally
        ``sha256hash``.
    :param outpath: file that receives importpkg.py's standard output.
    :raises ValueError: remote case, if importpkg.py or curl exits non-zero.
    :raises subprocess.CalledProcessError: local case, if importpkg.py fails.
    """
    filename = pkgdict["filename"]
    print("importing %s" % filename)
    importcmd = ["python", "importpkg.py"]
    if "sha256hash" in pkgdict:
        importcmd.extend(["-H", pkgdict["sha256hash"]])
    if filename.startswith(("http://", "https://", "ftp://", "file://")):
        with open(outpath, "w") as outp:
            dl = subprocess.Popen(["curl", "-s", filename],
                                  stdout=subprocess.PIPE, close_fds=True)
            try:
                imp = subprocess.Popen(importcmd, stdin=dl.stdout,
                                       stdout=outp, close_fds=True)
            finally:
                # Only importpkg.py needs the read end of the pipe.  If the
                # parent kept it open, curl could block forever on a full
                # pipe once importpkg.py stops reading; with it closed,
                # curl's writes fail and it terminates.
                dl.stdout.close()
            # Reap both children before reporting failure so that neither
            # is left behind.
            imp_status = imp.wait()
            dl_status = dl.wait()
            if imp_status:
                raise ValueError("importpkg failed")
            if dl_status:
                raise ValueError("curl failed")
    else:
        # .deb files are binary archives; open them as such.
        with open(filename, "rb") as inp:
            with open(outpath, "w") as outp:
                subprocess.check_call(importcmd, stdin=inp, stdout=outp,
                                      close_fds=True)
    print("preprocessed %s" % name)
79
def _collect_packages(sources):
    """Return a dict of the newest package per name found in *sources*.

    Each source is a mirror URL, a directory of .deb files or a single .deb
    file, dispatched to process_http, process_dir or process_file.
    """
    pkgs = {}
    for d in sources:
        print("processing %s" % d)
        if d.startswith(("http://", "https://", "ftp://", "file://")):
            process_http(pkgs, d)
        elif os.path.isdir(d):
            process_dir(pkgs, d)
        else:
            process_file(pkgs, d)
    return pkgs


def _drop_unchanged(pkgs, knownpkgs):
    """Delete from *pkgs* each package not newer than its database version."""
    for name in list(pkgs):
        if name in knownpkgs and \
                version_compare(pkgs[name]["version"], knownpkgs[name]) <= 0:
            del pkgs[name]


def _import_packages(db, pkgs, tmpdir):
    """Preprocess *pkgs* in parallel into *tmpdir*, then load them into *db*.

    A package's intermediate file is removed only after it was loaded
    successfully, so files of failed packages remain in *tmpdir*.
    """
    executor = concurrent.futures.ThreadPoolExecutor(
        multiprocessing.cpu_count())
    with executor:
        fs = {}
        for name, pkg in pkgs.items():
            outpath = os.path.join(tmpdir, name)
            fs[executor.submit(process_pkg, name, pkg, outpath)] = name

        for f in concurrent.futures.as_completed(fs):
            name = fs[f]
            exc = f.exception()
            if exc is not None:
                print("%s failed to import: %r" % (name, exc))
                continue
            inf = os.path.join(tmpdir, name)
            print("sqlimporting %s" % name)
            with open(inf) as inp:
                try:
                    readyaml(db, inp)
                except Exception as sqlexc:
                    # Report and continue with the remaining packages.
                    print("%s failed sql with exception %r" % (name, sqlexc))
                else:
                    os.unlink(inf)


def _remove_tmpdir(tmpdir):
    """Remove *tmpdir* if empty; otherwise report the leftover packages."""
    try:
        os.rmdir(tmpdir)
    except OSError as err:
        # POSIX permits rmdir() on a non-empty directory to fail with either
        # ENOTEMPTY or EEXIST.
        if err.errno not in (errno.ENOTEMPTY, errno.EEXIST):
            raise
        print("keeping temporary directory %s due to failed packages %s" %
              (tmpdir, " ".join(os.listdir(tmpdir))))


def main():
    """Import the packages from the sources named on the command line.

    Options:
      -n/--new       skip packages whose version is not newer than the one
                     already recorded in the database
      -p/--prune     delete database packages absent from all given sources
      -d/--database  path to the sqlite3 database (default test.sqlite3)
    """
    parser = optparse.OptionParser()
    parser.add_option("-n", "--new", action="store_true",
                      help="avoid reimporting same versions")
    parser.add_option("-p", "--prune", action="store_true",
                      help="prune old packages")
    parser.add_option("-d", "--database", action="store",
                      default="test.sqlite3",
                      help="path to the sqlite3 database file")
    options, args = parser.parse_args()
    tmpdir = tempfile.mkdtemp(prefix=b"debian-dedup")
    db = sqlite3.connect(options.database)
    try:
        cur = db.cursor()
        # Required for the ON DELETE CASCADE clauses relied on when pruning.
        cur.execute("PRAGMA foreign_keys = ON;")
        pkgs = _collect_packages(args)

        print("reading database")
        cur.execute("SELECT name, version FROM package;")
        knownpkgs = dict(cur.fetchall())
        # Snapshot before --new filtering so that pruning never removes
        # packages that are still distributed but merely unchanged.
        distpkgs = set(pkgs)
        if options.new:
            _drop_unchanged(pkgs, knownpkgs)

        _import_packages(db, pkgs, tmpdir)

        if options.prune:
            # NOTE(review): with no sources given, distpkgs is empty and this
            # deletes every package in the database.
            delpkgs = sorted(set(knownpkgs) - distpkgs)
            print("clearing packages %s" % " ".join(delpkgs))
            cur.executemany("DELETE FROM package WHERE name = ?;",
                            ((pkg,) for pkg in delpkgs))
            # Tables content, dependency and sharing will also be pruned
            # due to ON DELETE CASCADE clauses.
            db.commit()
    finally:
        db.close()
    _remove_tmpdir(tmpdir)
152
# Run the importer only when this file is executed as a script.
if __name__ == "__main__":
    main()