From 20fab9b16df63e3246238d5f3a5e6d70d225fb9b Mon Sep 17 00:00:00 2001 From: Zlatin Balevsky Date: Sat, 6 Jul 2019 00:17:46 +0100 Subject: [PATCH] work on partial piece persistence --- .../core/download/DownloadSession.groovy | 29 +++++++++++-------- .../muwire/core/download/Downloader.groovy | 14 +++++---- .../com/muwire/core/download/Pieces.groovy | 28 ++++++++++++++---- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy index 5f54ce0e..d510dccf 100644 --- a/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy +++ b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy @@ -69,21 +69,22 @@ class DownloadSession { OutputStream os = endpoint.getOutputStream() InputStream is = endpoint.getInputStream() - int piece + int[] pieceAndPosition if (available.isEmpty()) - piece = pieces.claim() + pieceAndPosition = pieces.claim() else - piece = pieces.claim(new HashSet<>(available)) - if (piece == -1) + pieceAndPosition = pieces.claim(new HashSet<>(available)) + if (pieceAndPosition == null) return false + int piece = pieceAndPosition[0] + int position = pieceAndPosition[1] boolean unclaim = true - log.info("will download piece $piece") - - long start = piece * pieceSize - long end = Math.min(fileLength, start + pieceSize) - 1 - long length = end - start + 1 + log.info("will download piece $piece from position $position") + long pieceStart = piece * pieceSize + long end = Math.min(fileLength, pieceStart + pieceSize) - 1 + long start = pieceStart + position String root = Base64.encode(infoHash.getRoot()) try { @@ -172,8 +173,9 @@ class DownloadSession { FileChannel channel try { channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ, StandardOpenOption.WRITE, - StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) // TODO: double-check, maybe CREATE_NEW - mapped = channel.map(FileChannel.MapMode.READ_WRITE, start, end - start + 1) + StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) + mapped = channel.map(FileChannel.MapMode.READ_WRITE, pieceStart, end - start + 1) + mapped.position(position) byte[] tmp = new byte[0x1 << 13] while(mapped.hasRemaining()) { @@ -185,6 +187,7 @@ class DownloadSession { synchronized(this) { mapped.put(tmp, 0, read) dataSinceLastRead += read + pieces.markPartial(piece, mapped.position()) } } @@ -194,8 +197,10 @@ class DownloadSession { byte [] hash = digest.digest() byte [] expected = new byte[32] System.arraycopy(infoHash.getHashList(), piece * 32, expected, 0, 32) - if (hash != expected) + if (hash != expected) { + pieces.markPartial(piece, 0) throw new BadHashException() + } } finally { try { channel?.close() } catch (IOException ignore) {} } diff --git a/core/src/main/groovy/com/muwire/core/download/Downloader.groovy b/core/src/main/groovy/com/muwire/core/download/Downloader.groovy index 554647aa..4612d8ff 100644 --- a/core/src/main/groovy/com/muwire/core/download/Downloader.groovy +++ b/core/src/main/groovy/com/muwire/core/download/Downloader.groovy @@ -111,8 +111,14 @@ public class Downloader { if (!piecesFile.exists()) return piecesFile.eachLine { - int piece = Integer.parseInt(it) - pieces.markDownloaded(piece) + String [] split = it.split(",") + int piece = Integer.parseInt(split[0]) + if (split.length == 1) + pieces.markDownloaded(piece) + else { + int position = Integer.parseInt(split[1]) + pieces.markPartial(piece, position) + } } } @@ -121,9 +127,7 @@ public class Downloader { if (piecesFileClosed) return piecesFile.withPrintWriter { writer -> - pieces.getDownloaded().each { piece -> - writer.println(piece) - } + pieces.write(writer) } } } diff --git a/core/src/main/groovy/com/muwire/core/download/Pieces.groovy b/core/src/main/groovy/com/muwire/core/download/Pieces.groovy index f3eecedb..95d00f59 100644 --- a/core/src/main/groovy/com/muwire/core/download/Pieces.groovy +++ b/core/src/main/groovy/com/muwire/core/download/Pieces.groovy @@ -5,6 +5,7 @@ class Pieces { private final int nPieces private final float ratio private final Random random = new Random() + private final Map partials = new HashMap<>() Pieces(int nPieces) { this(nPieces, 1.0f) @@ -17,16 +18,16 @@ class Pieces { claimed = new BitSet(nPieces) } - synchronized int claim() { + synchronized int[] claim() { int claimedCardinality = claimed.cardinality() if (claimedCardinality == nPieces) - return -1 + return null // if fuller than ratio just do sequential if ( (1.0f * claimedCardinality) / nPieces > ratio) { int rv = claimed.nextClearBit(0) claimed.set(rv) - return rv + return [rv, partials.getOrDefault(rv, 0)] } while(true) { @@ -34,11 +35,11 @@ class Pieces { if (claimed.get(start)) continue claimed.set(start) - return start + return [start, partials.getOrDefault(start,0)] } } - synchronized int claim(Set available) { + synchronized int[] claim(Set available) { for (int i = claimed.nextSetBit(0); i >= 0; i = claimed.nextSetBit(i+1)) available.remove(i) if (available.isEmpty()) @@ -47,7 +48,7 @@ class Pieces { Collections.shuffle(toList) int rv = toList[0] claimed.set(rv) - rv + [rv, partials.getOrDefault(rv, 0)] } synchronized def getDownloaded() { @@ -61,6 +62,11 @@ class Pieces { synchronized void markDownloaded(int piece) { done.set(piece) claimed.set(piece) + partials.remove(piece) + } + + synchronized void markPartial(int piece, int position) { + partials.put(piece, position) } synchronized void unclaim(int piece) { @@ -82,5 +88,15 @@ class Pieces { synchronized void clearAll() { done.clear() claimed.clear() + partials.clear() + } + + synchronized void write(PrintWriter writer) { + for (int i = done.nextSetBit(0); i >= 0; i = done.nextSetBit(i+1)) { + writer.println(i) + } + partials.each { piece, position -> + writer.println("$piece,$position") + } } }