teach ArReader to read multiple entries
author     Helmut Grohne <helmut@subdivi.de>
           Wed, 20 Feb 2013 14:55:05 +0000 (15:55 +0100)
committer  Helmut Grohne <helmut@subdivi.de>
           Wed, 20 Feb 2013 14:55:05 +0000 (15:55 +0100)
test.py

diff --git a/test.py b/test.py
index 4db7d83..3939ff4 100755
--- a/test.py
+++ b/test.py
@@ -23,10 +23,10 @@ class ArReader(object):
     global_magic = b"!<arch>\n"
     file_magic = b"`\n"
 
-    def __init__(self, fileobj, membertest):
+    def __init__(self, fileobj):
         self.fileobj = fileobj
-        self.membertest = membertest
         self.remaining = None
+        self.padding = 0
 
     def skip(self, length):
         while length:
@@ -35,25 +35,32 @@ class ArReader(object):
                 raise ValueError("archive truncated")
             length -= len(data)
 
-    def skiptillmember(self):
+    def read_magic(self):
         data = self.fileobj.read(len(self.global_magic))
         if data != self.global_magic:
             raise ValueError("ar global header not found")
-        while True:
-            file_header = self.fileobj.read(60)
-            if not file_header:
-                raise ValueError("end of archive found")
-            parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
-            parts = [p.rstrip(" ") for p in parts]
-            if parts.pop() != self.file_magic:
-                print(repr(file_header))
-                raise ValueError("ar file header not found")
-            name = parts[0]
-            length = int(parts[5])
-            if self.membertest(name):
-                self.remaining = length
-                return name
-            self.skip(length + length % 2)
+        self.remaining = 0
+
+    def read_entry(self):
+        self.skip_current_entry()
+        if self.padding:
+            if self.fileobj.read(1) != b"\n":
+                raise ValueError("missing ar padding")
+            self.padding = 0
+        file_header = self.fileobj.read(60)
+        if not file_header:
+            raise EOFError("end of archive found")
+        parts = struct.unpack("16s 12s 6s 6s 8s 10s 2s", file_header)
+        parts = [p.rstrip(" ") for p in parts]
+        if parts.pop() != self.file_magic:
+            raise ValueError("ar file header not found")
+        self.remaining = int(parts[5])
+        self.padding = self.remaining % 2
+        return parts[0] # name
+
+    def skip_current_entry(self):
+        self.skip(self.remaining)
+        self.remaining = 0
 
     def read(self, length=None):
         if length is None:
@@ -229,26 +236,34 @@ def gziphash():
     return HashBlacklist(hashobj, boring_sha512_hashes)
 
 def get_hashes(filelike):
-    af = ArReader(filelike, lambda name: name.startswith("data.tar"))
-    name = af.skiptillmember()
-    if name == "data.tar.gz":
-        tf = tarfile.open(fileobj=af, mode="r|gz")
-    elif name == "data.tar.bz2":
-        tf = tarfile.open(fileobj=af, mode="r|bz2")
-    elif name == "data.tar.xz":
-        zf = XzStream(af)
-        tf = tarfile.open(fileobj=zf, mode="r|")
-    else:
-        raise ValueError("unsupported compression %r" % name)
-    for elem in tf:
-        if not elem.isreg(): # excludes hard links as well
+    af = ArReader(filelike)
+    af.read_magic()
+    tf = None
+    while True:
+        try:
+            name = af.read_entry()
+        except EOFError:
+            break
+        if name == "data.tar.gz":
+            tf = tarfile.open(fileobj=af, mode="r|gz")
+        elif name == "data.tar.bz2":
+            tf = tarfile.open(fileobj=af, mode="r|bz2")
+        elif name == "data.tar.xz":
+            zf = XzStream(af)
+            tf = tarfile.open(fileobj=zf, mode="r|")
+        else:
             continue
-        hasher = MultiHash(sha512_nontrivial(), gziphash())
-        hasher = hash_file(hasher, tf.extractfile(elem))
-        for hashobj in hasher.hashes:
-            hashvalue = hashobj.hexdigest()
-            if hashvalue:
-                yield (elem.name, elem.size, hashobj.name, hashvalue)
+        for elem in tf:
+            if not elem.isreg(): # excludes hard links as well
+                continue
+            hasher = MultiHash(sha512_nontrivial(), gziphash())
+            hasher = hash_file(hasher, tf.extractfile(elem))
+            for hashobj in hasher.hashes:
+                hashvalue = hashobj.hexdigest()
+                if hashvalue:
+                    yield (elem.name, elem.size, hashobj.name, hashvalue)
+    if not tf:
+        raise ValueError("data.tar not found")
 
 def main():
     filename = sys.argv[1]
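
For context, a rough sketch of how the reworked reader is meant to be driven after this change; it is illustrative only and not part of the patch. It assumes ArReader can be imported from test.py and simply lists every ar member of a .deb: read_magic() is called once, read_entry() is called in a loop until it raises EOFError, and skip_current_entry() steps over bodies that are not of interest.

import sys

from test import ArReader   # the class patched above; the import path is an assumption

def list_members(filename):
    # Hypothetical helper: print the name and size of every member of an ar archive (e.g. a .deb).
    with open(filename, "rb") as fileobj:
        af = ArReader(fileobj)
        af.read_magic()                  # consumes and checks the "!<arch>\n" global magic
        while True:
            try:
                name = af.read_entry()   # returns the member name, positions af at its body
            except EOFError:
                break                    # no further members
            size = af.remaining          # body length taken from the ar file header
            af.skip_current_entry()      # the next read_entry() would do this implicitly, too
            print("%s %d" % (name, size))

if __name__ == "__main__":
    list_members(sys.argv[1])

The new get_hashes() uses the same loop shape: instead of skipping the body, it wraps af in tarfile.open(fileobj=af, mode="r|...") whenever the member is one of the data.tar.* variants, and skips everything else.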