keep track of claimed pieces in preparation for multi-source downloads

This commit is contained in:
Zlatin Balevsky
2019-06-04 02:18:30 +01:00
parent c91440cbfc
commit e7240dcb6f
4 changed files with 85 additions and 20 deletions

View File

@@ -23,7 +23,7 @@ class DownloadSession {
private static int SAMPLES = 10
private final String meB64
private final Pieces pieces
private final Pieces downloaded, claimed
private final InfoHash infoHash
private final Endpoint endpoint
private final File file
@@ -36,10 +36,11 @@ class DownloadSession {
private ByteBuffer mapped
DownloadSession(String meB64, Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file,
DownloadSession(String meB64, Pieces downloaded, Pieces claimed, InfoHash infoHash, Endpoint endpoint, File file,
int pieceSize, long fileLength) {
this.meB64 = meB64
this.pieces = pieces
this.downloaded = downloaded
this.claimed = claimed
this.endpoint = endpoint
this.infoHash = infoHash
this.file = file
@@ -53,11 +54,31 @@ class DownloadSession {
}
}
public void request() throws IOException {
/**
* @return if the request will proceed. The only time it may not
* is if all the pieces have been claimed by other sessions.
* @throws IOException
*/
public boolean request() throws IOException {
OutputStream os = endpoint.getOutputStream()
InputStream is = endpoint.getInputStream()
int piece = pieces.getRandomPiece()
int piece
while(true) {
piece = downloaded.getRandomPiece()
if (claimed.isMarked(piece)) {
if (downloaded.donePieces() + claimed.donePieces() == downloaded.nPieces) {
log.info("all pieces claimed")
return false
}
continue
}
break
}
claimed.markDownloaded(piece)
log.info("will download piece $piece")
long start = piece * pieceSize
long end = Math.min(fileLength, start + pieceSize) - 1
long length = end - start + 1
@@ -145,10 +166,12 @@ class DownloadSession {
if (hash != expected)
throw new BadHashException()
pieces.markDownloaded(piece)
downloaded.markDownloaded(piece)
} finally {
claimed.clear(piece)
try { channel?.close() } catch (IOException ignore) {}
}
return true
}
synchronized int positionInPiece() {

View File

@@ -18,7 +18,7 @@ public class Downloader {
private final DownloadManager downloadManager
private final String meB64
private final File file
private final Pieces pieces
private final Pieces downloaded, claimed
private final long length
private final InfoHash infoHash
private final int pieceSize
@@ -53,7 +53,8 @@ public class Downloader {
nPieces = length / pieceSize + 1
this.nPieces = nPieces
pieces = new Pieces(nPieces, Constants.DOWNLOAD_SEQUENTIAL_RATIO)
downloaded = new Pieces(nPieces, Constants.DOWNLOAD_SEQUENTIAL_RATIO)
claimed = new Pieces(nPieces)
currentState = DownloadState.CONNECTING
}
@@ -64,13 +65,18 @@ public class Downloader {
try {
endpoint = connector.connect(destination)
currentState = DownloadState.DOWNLOADING
while(!pieces.isComplete()) {
currentSession = new DownloadSession(meB64, pieces, infoHash, endpoint, file, pieceSize, length)
currentSession.request()
boolean requestPerformed
while(!downloaded.isComplete()) {
currentSession = new DownloadSession(meB64, downloaded, claimed, infoHash, endpoint, file, pieceSize, length)
requestPerformed = currentSession.request()
if (!requestPerformed)
break
writePieces()
}
currentState = DownloadState.FINISHED
piecesFile.delete()
if (requestPerformed) {
currentState = DownloadState.FINISHED
piecesFile.delete()
} else log.info("request not performed")
} catch (Exception bad) {
log.log(Level.WARNING,"Exception while downloading",bad)
if (cancelled)
@@ -87,20 +93,20 @@ public class Downloader {
return
piecesFile.withReader {
int piece = Integer.parseInt(it.readLine())
pieces.markDownloaded(piece)
downloaded.markDownloaded(piece)
}
}
void writePieces() {
piecesFile.withPrintWriter { writer ->
pieces.getDownloaded().each { piece ->
downloaded.getDownloaded().each { piece ->
writer.println(piece)
}
}
}
public long donePieces() {
pieces.donePieces()
downloaded.donePieces()
}
public int positionInPiece() {

View File

@@ -28,7 +28,8 @@ class Pieces {
while(true) {
int start = random.nextInt(nPieces)
while(bitSet.get(start) && ++start < nPieces);
if (bitSet.get(start))
continue
return start
}
}
@@ -45,10 +46,18 @@ class Pieces {
bitSet.set(piece)
}
synchronized void clear(int piece) {
bitSet.clear(piece)
}
synchronized boolean isComplete() {
bitSet.cardinality() == nPieces
}
synchronized boolean isMarked(int piece) {
bitSet.get(piece)
}
synchronized int donePieces() {
bitSet.cardinality()
}

View File

@@ -15,7 +15,7 @@ class DownloadSessionTest {
private File source, target
private InfoHash infoHash
private Endpoint endpoint
private Pieces pieces
private Pieces pieces, claimed
private String rootBase64
private DownloadSession session
@@ -24,7 +24,7 @@ class DownloadSessionTest {
private InputStream fromDownloader, fromUploader
private OutputStream toDownloader, toUploader
private void initSession(int size) {
private void initSession(int size, def claimedPieces = []) {
Random r = new Random()
byte [] content = new byte[size]
r.nextBytes(content)
@@ -48,6 +48,8 @@ class DownloadSessionTest {
else
nPieces = size / pieceSize + 1
pieces = new Pieces(nPieces)
claimed = new Pieces(nPieces)
claimedPieces.each {claimed.markDownloaded(it)}
fromDownloader = new PipedInputStream()
fromUploader = new PipedInputStream()
@@ -55,7 +57,7 @@ class DownloadSessionTest {
toUploader = new PipedOutputStream(fromDownloader)
endpoint = new Endpoint(null, fromUploader, toUploader, null)
session = new DownloadSession("",pieces, infoHash, endpoint, target, pieceSize, size)
session = new DownloadSession("",pieces, claimed, infoHash, endpoint, target, pieceSize, size)
downloadThread = new Thread( { session.request() } as Runnable)
downloadThread.setDaemon(true)
downloadThread.start()
@@ -138,4 +140,29 @@ class DownloadSessionTest {
assert !pieces.isComplete()
assert 1 == pieces.donePieces()
}
@Test
public void testSmallFileClaimed() {
initSession(20, [0])
long now = System.currentTimeMillis()
downloadThread.join(100)
assert 100 > (System.currentTimeMillis() - now)
}
@Test
public void testClaimedPiecesAvoided() {
int pieceSize = FileHasher.getPieceSize(1)
int size = (1 << pieceSize) * 10
initSession(size, [1,2,3,4,5,6,7,8,9])
assert !claimed.isMarked(0)
assert "GET $rootBase64" == readTillRN(fromDownloader)
String range = readTillRN(fromDownloader)
def matcher = (range =~ /^Range: (\d+)-(\d+)$/)
int start = Integer.parseInt(matcher[0][1])
int end = Integer.parseInt(matcher[0][2])
assert claimed.isMarked(0)
assert start == 0 && end == (1 << pieceSize) - 1
}
}