diff --git a/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy index c9798a1f..2b3e56a3 100644 --- a/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy +++ b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy @@ -23,7 +23,7 @@ class DownloadSession { private static int SAMPLES = 10 private final String meB64 - private final Pieces pieces + private final Pieces downloaded, claimed private final InfoHash infoHash private final Endpoint endpoint private final File file @@ -36,10 +36,11 @@ class DownloadSession { private ByteBuffer mapped - DownloadSession(String meB64, Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file, + DownloadSession(String meB64, Pieces downloaded, Pieces claimed, InfoHash infoHash, Endpoint endpoint, File file, int pieceSize, long fileLength) { this.meB64 = meB64 - this.pieces = pieces + this.downloaded = downloaded + this.claimed = claimed this.endpoint = endpoint this.infoHash = infoHash this.file = file @@ -53,11 +54,31 @@ class DownloadSession { } } - public void request() throws IOException { + /** + * @return if the request will proceed. The only time it may not + * is if all the pieces have been claimed by other sessions. + * @throws IOException + */ + public boolean request() throws IOException { OutputStream os = endpoint.getOutputStream() InputStream is = endpoint.getInputStream() - int piece = pieces.getRandomPiece() + int piece + while(true) { + piece = downloaded.getRandomPiece() + if (claimed.isMarked(piece)) { + if (downloaded.donePieces() + claimed.donePieces() == downloaded.nPieces) { + log.info("all pieces claimed") + return false + } + continue + } + break + } + claimed.markDownloaded(piece) + + log.info("will download piece $piece") + long start = piece * pieceSize long end = Math.min(fileLength, start + pieceSize) - 1 long length = end - start + 1 @@ -145,10 +166,12 @@ class DownloadSession { if (hash != expected) throw new BadHashException() - pieces.markDownloaded(piece) + downloaded.markDownloaded(piece) } finally { + claimed.clear(piece) try { channel?.close() } catch (IOException ignore) {} } + return true } synchronized int positionInPiece() { diff --git a/core/src/main/groovy/com/muwire/core/download/Downloader.groovy b/core/src/main/groovy/com/muwire/core/download/Downloader.groovy index 398b9320..43dce26e 100644 --- a/core/src/main/groovy/com/muwire/core/download/Downloader.groovy +++ b/core/src/main/groovy/com/muwire/core/download/Downloader.groovy @@ -18,7 +18,7 @@ public class Downloader { private final DownloadManager downloadManager private final String meB64 private final File file - private final Pieces pieces + private final Pieces downloaded, claimed private final long length private final InfoHash infoHash private final int pieceSize @@ -53,7 +53,8 @@ public class Downloader { nPieces = length / pieceSize + 1 this.nPieces = nPieces - pieces = new Pieces(nPieces, Constants.DOWNLOAD_SEQUENTIAL_RATIO) + downloaded = new Pieces(nPieces, Constants.DOWNLOAD_SEQUENTIAL_RATIO) + claimed = new Pieces(nPieces) currentState = DownloadState.CONNECTING } @@ -64,13 +65,18 @@ public class Downloader { try { endpoint = connector.connect(destination) currentState = DownloadState.DOWNLOADING - while(!pieces.isComplete()) { - currentSession = new DownloadSession(meB64, pieces, infoHash, endpoint, file, pieceSize, length) - currentSession.request() + boolean requestPerformed + while(!downloaded.isComplete()) { + currentSession = new DownloadSession(meB64, downloaded, claimed, infoHash, endpoint, file, pieceSize, length) + requestPerformed = currentSession.request() + if (!requestPerformed) + break writePieces() } - currentState = DownloadState.FINISHED - piecesFile.delete() + if (requestPerformed) { + currentState = DownloadState.FINISHED + piecesFile.delete() + } else log.info("request not performed") } catch (Exception bad) { log.log(Level.WARNING,"Exception while downloading",bad) if (cancelled) @@ -87,20 +93,20 @@ public class Downloader { return piecesFile.withReader { int piece = Integer.parseInt(it.readLine()) - pieces.markDownloaded(piece) + downloaded.markDownloaded(piece) } } void writePieces() { piecesFile.withPrintWriter { writer -> - pieces.getDownloaded().each { piece -> + downloaded.getDownloaded().each { piece -> writer.println(piece) } } } public long donePieces() { - pieces.donePieces() + downloaded.donePieces() } public int positionInPiece() { diff --git a/core/src/main/groovy/com/muwire/core/download/Pieces.groovy b/core/src/main/groovy/com/muwire/core/download/Pieces.groovy index 71749c58..98dd27f6 100644 --- a/core/src/main/groovy/com/muwire/core/download/Pieces.groovy +++ b/core/src/main/groovy/com/muwire/core/download/Pieces.groovy @@ -28,7 +28,8 @@ class Pieces { while(true) { int start = random.nextInt(nPieces) - while(bitSet.get(start) && ++start < nPieces); + if (bitSet.get(start)) + continue return start } } @@ -45,10 +46,18 @@ class Pieces { bitSet.set(piece) } + synchronized void clear(int piece) { + bitSet.clear(piece) + } + synchronized boolean isComplete() { bitSet.cardinality() == nPieces } + synchronized boolean isMarked(int piece) { + bitSet.get(piece) + } + synchronized int donePieces() { bitSet.cardinality() } diff --git a/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy b/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy index 75d15122..12ed6ef7 100644 --- a/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy +++ b/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy @@ -15,7 +15,7 @@ class DownloadSessionTest { private File source, target private InfoHash infoHash private Endpoint endpoint - private Pieces pieces + private Pieces pieces, claimed private String rootBase64 private DownloadSession session @@ -24,7 +24,7 @@ class DownloadSessionTest { private InputStream fromDownloader, fromUploader private OutputStream toDownloader, toUploader - private void initSession(int size) { + private void initSession(int size, def claimedPieces = []) { Random r = new Random() byte [] content = new byte[size] r.nextBytes(content) @@ -48,6 +48,8 @@ class DownloadSessionTest { else nPieces = size / pieceSize + 1 pieces = new Pieces(nPieces) + claimed = new Pieces(nPieces) + claimedPieces.each {claimed.markDownloaded(it)} fromDownloader = new PipedInputStream() fromUploader = new PipedInputStream() @@ -55,7 +57,7 @@ class DownloadSessionTest { toUploader = new PipedOutputStream(fromDownloader) endpoint = new Endpoint(null, fromUploader, toUploader, null) - session = new DownloadSession("",pieces, infoHash, endpoint, target, pieceSize, size) + session = new DownloadSession("",pieces, claimed, infoHash, endpoint, target, pieceSize, size) downloadThread = new Thread( { session.request() } as Runnable) downloadThread.setDaemon(true) downloadThread.start() @@ -138,4 +140,29 @@ class DownloadSessionTest { assert !pieces.isComplete() assert 1 == pieces.donePieces() } + + @Test + public void testSmallFileClaimed() { + initSession(20, [0]) + long now = System.currentTimeMillis() + downloadThread.join(100) + assert 100 > (System.currentTimeMillis() - now) + } + + @Test + public void testClaimedPiecesAvoided() { + int pieceSize = FileHasher.getPieceSize(1) + int size = (1 << pieceSize) * 10 + initSession(size, [1,2,3,4,5,6,7,8,9]) + assert !claimed.isMarked(0) + + assert "GET $rootBase64" == readTillRN(fromDownloader) + String range = readTillRN(fromDownloader) + def matcher = (range =~ /^Range: (\d+)-(\d+)$/) + int start = Integer.parseInt(matcher[0][1]) + int end = Integer.parseInt(matcher[0][2]) + + assert claimed.isMarked(0) + assert start == 0 && end == (1 << pieceSize) - 1 + } }