download code and tests
This commit is contained in:
@@ -5,4 +5,7 @@ import net.i2p.crypto.SigType
|
|||||||
class Constants {
|
class Constants {
|
||||||
public static final byte PERSONA_VERSION = (byte)1
|
public static final byte PERSONA_VERSION = (byte)1
|
||||||
public static final SigType SIG_TYPE = SigType.ECDSA_SHA512_P521 // TODO: decide which
|
public static final SigType SIG_TYPE = SigType.ECDSA_SHA512_P521 // TODO: decide which
|
||||||
|
|
||||||
|
public static final int MAX_HEADER_SIZE = 0x1 << 14
|
||||||
|
public static final int MAX_HEADERS = 16
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
package com.muwire.core.download;
|
||||||
|
|
||||||
|
import net.i2p.data.Base64
|
||||||
|
|
||||||
|
import com.muwire.core.Constants
|
||||||
|
import com.muwire.core.InfoHash
|
||||||
|
import com.muwire.core.connection.Endpoint
|
||||||
|
import static com.muwire.core.util.DataUtil.readTillRN
|
||||||
|
|
||||||
|
import groovy.util.logging.Log
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
import java.nio.channels.FileChannel
|
||||||
|
import java.nio.charset.StandardCharsets
|
||||||
|
import java.nio.file.Files
|
||||||
|
import java.nio.file.StandardOpenOption
|
||||||
|
import java.security.MessageDigest
|
||||||
|
import java.security.NoSuchAlgorithmException
|
||||||
|
|
||||||
|
@Log
|
||||||
|
class DownloadSession {
|
||||||
|
|
||||||
|
private final Pieces pieces
|
||||||
|
private final InfoHash infoHash
|
||||||
|
private final Endpoint endpoint
|
||||||
|
private final File file
|
||||||
|
private final int pieceSize
|
||||||
|
private final long fileLength
|
||||||
|
private final MessageDigest digest
|
||||||
|
|
||||||
|
private ByteBuffer mapped
|
||||||
|
|
||||||
|
DownloadSession(Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file,
|
||||||
|
int pieceSize, long fileLength) {
|
||||||
|
this.pieces = pieces
|
||||||
|
this.endpoint = endpoint
|
||||||
|
this.infoHash = infoHash
|
||||||
|
this.file = file
|
||||||
|
this.pieceSize = pieceSize
|
||||||
|
this.fileLength = fileLength
|
||||||
|
try {
|
||||||
|
digest = MessageDigest.getInstance("SHA-256")
|
||||||
|
} catch (NoSuchAlgorithmException impossible) {
|
||||||
|
digest = null
|
||||||
|
System.exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void request() throws IOException {
|
||||||
|
OutputStream os = endpoint.getOutputStream()
|
||||||
|
InputStream is = endpoint.getInputStream()
|
||||||
|
|
||||||
|
int piece = pieces.getRandomPiece()
|
||||||
|
long start = piece * pieceSize
|
||||||
|
long end = Math.min(fileLength, start + pieceSize) - 1
|
||||||
|
long length = end - start + 1
|
||||||
|
|
||||||
|
String root = Base64.encode(infoHash.getRoot())
|
||||||
|
|
||||||
|
FileChannel channel
|
||||||
|
try {
|
||||||
|
os.write("GET $root\r\n".getBytes(StandardCharsets.US_ASCII))
|
||||||
|
os.write("Range: $start-$end\r\n\r\n".getBytes(StandardCharsets.US_ASCII))
|
||||||
|
os.flush()
|
||||||
|
String code = readTillRN(is)
|
||||||
|
if (code.startsWith("404 ")) {
|
||||||
|
log.warning("file not found")
|
||||||
|
endpoint.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (code.startsWith("416 ")) {
|
||||||
|
log.warning("range $start-$end cannot be satisfied")
|
||||||
|
return // leave endpoint open
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!code.startsWith("200 ")) {
|
||||||
|
log.warning("unknown code $code")
|
||||||
|
endpoint.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse all headers
|
||||||
|
Set<String> headers = new HashSet<>()
|
||||||
|
String header
|
||||||
|
while((header = readTillRN(is)) != "" && headers.size() < Constants.MAX_HEADERS)
|
||||||
|
headers.add(header)
|
||||||
|
|
||||||
|
long receivedStart = -1
|
||||||
|
long receivedEnd = -1
|
||||||
|
for (String receivedHeader : headers) {
|
||||||
|
def group = (receivedHeader =~ /^Content-Range: (\d+)-(\d+)$/)
|
||||||
|
if (group.size() != 1) {
|
||||||
|
log.info("ignoring header $receivedHeader")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedStart = Long.parseLong(group[0][1])
|
||||||
|
receivedEnd = Long.parseLong(group[0][2])
|
||||||
|
}
|
||||||
|
|
||||||
|
if (receivedStart != start || receivedEnd != end) {
|
||||||
|
log.warning("We don't support mismatching ranges yet")
|
||||||
|
endpoint.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// start the download
|
||||||
|
channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ, StandardOpenOption.WRITE,
|
||||||
|
StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) // TODO: double-check, maybe CREATE_NEW
|
||||||
|
mapped = channel.map(FileChannel.MapMode.READ_WRITE, start, end - start + 1)
|
||||||
|
|
||||||
|
byte[] tmp = new byte[0x1 << 13]
|
||||||
|
while(mapped.hasRemaining()) {
|
||||||
|
int read = is.read(tmp)
|
||||||
|
if (read == -1)
|
||||||
|
throw new IOException()
|
||||||
|
synchronized(this) {
|
||||||
|
mapped.put(tmp, 0, read)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mapped.clear()
|
||||||
|
digest.update(mapped)
|
||||||
|
byte [] hash = digest.digest()
|
||||||
|
byte [] expected = new byte[32]
|
||||||
|
System.arraycopy(infoHash.getHashList(), piece * 32, expected, 0, 32)
|
||||||
|
if (hash != expected) {
|
||||||
|
log.warning("hash mismatch")
|
||||||
|
endpoint.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pieces.markDownloaded(piece)
|
||||||
|
} finally {
|
||||||
|
try { channel?.close() } catch (IOException ignore) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package com.muwire.core.upload
|
|||||||
|
|
||||||
import java.nio.charset.StandardCharsets
|
import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
|
import com.muwire.core.Constants
|
||||||
import com.muwire.core.InfoHash
|
import com.muwire.core.InfoHash
|
||||||
|
|
||||||
import groovy.util.logging.Log
|
import groovy.util.logging.Log
|
||||||
@@ -19,8 +20,8 @@ class Request {
|
|||||||
|
|
||||||
static Request parse(InfoHash infoHash, InputStream is) throws IOException {
|
static Request parse(InfoHash infoHash, InputStream is) throws IOException {
|
||||||
Map<String,String> headers = new HashMap<>()
|
Map<String,String> headers = new HashMap<>()
|
||||||
byte [] tmp = new byte[0x1 << 14]
|
byte [] tmp = new byte[Constants.MAX_HEADER_SIZE]
|
||||||
while(true) {
|
while(headers.size() < Constants.MAX_HEADERS) {
|
||||||
boolean r = false
|
boolean r = false
|
||||||
boolean n = false
|
boolean n = false
|
||||||
int idx = 0
|
int idx = 0
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package com.muwire.core.util
|
|||||||
|
|
||||||
import java.nio.charset.StandardCharsets
|
import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
|
import com.muwire.core.Constants
|
||||||
|
|
||||||
class DataUtil {
|
class DataUtil {
|
||||||
|
|
||||||
private final static int MAX_SHORT = (0x1 << 16) - 1
|
private final static int MAX_SHORT = (0x1 << 16) - 1
|
||||||
@@ -61,4 +63,20 @@ class DataUtil {
|
|||||||
daos.close()
|
daos.close()
|
||||||
baos.toByteArray()
|
baos.toByteArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String readTillRN(InputStream is) {
|
||||||
|
def baos = new ByteArrayOutputStream()
|
||||||
|
while(baos.size() < (Constants.MAX_HEADER_SIZE)) {
|
||||||
|
byte read = is.read()
|
||||||
|
if (read == -1)
|
||||||
|
throw new IOException()
|
||||||
|
if (read == '\r') {
|
||||||
|
if (is.read() != '\n')
|
||||||
|
throw new IOException("invalid header")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
baos.write(read)
|
||||||
|
}
|
||||||
|
new String(baos.toByteArray(), StandardCharsets.US_ASCII)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package com.muwire.core.download
|
||||||
|
|
||||||
|
import org.junit.After
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
import com.muwire.core.InfoHash
|
||||||
|
import com.muwire.core.connection.Endpoint
|
||||||
|
import com.muwire.core.files.FileHasher
|
||||||
|
import static com.muwire.core.util.DataUtil.readTillRN
|
||||||
|
|
||||||
|
import net.i2p.data.Base64
|
||||||
|
|
||||||
|
class DownloadSessionTest {
|
||||||
|
|
||||||
|
private File source, target
|
||||||
|
private InfoHash infoHash
|
||||||
|
private Endpoint endpoint
|
||||||
|
private Pieces pieces
|
||||||
|
private String rootBase64
|
||||||
|
|
||||||
|
private DownloadSession session
|
||||||
|
private Thread downloadThread
|
||||||
|
|
||||||
|
private InputStream fromDownloader, fromUploader
|
||||||
|
private OutputStream toDownloader, toUploader
|
||||||
|
|
||||||
|
private void initSession(int size) {
|
||||||
|
Random r = new Random()
|
||||||
|
byte [] content = new byte[size]
|
||||||
|
r.nextBytes(content)
|
||||||
|
|
||||||
|
source = File.createTempFile("source", "tmp")
|
||||||
|
source.deleteOnExit()
|
||||||
|
def fos = new FileOutputStream(source)
|
||||||
|
fos.write(content)
|
||||||
|
fos.close()
|
||||||
|
|
||||||
|
def hasher = new FileHasher()
|
||||||
|
infoHash = hasher.hashFile(source)
|
||||||
|
rootBase64 = Base64.encode(infoHash.getRoot())
|
||||||
|
|
||||||
|
target = File.createTempFile("target", "tmp")
|
||||||
|
int pieceSize = 1 << FileHasher.getPieceSize(size)
|
||||||
|
|
||||||
|
int nPieces
|
||||||
|
if (size % pieceSize == 0)
|
||||||
|
nPieces = size / pieceSize
|
||||||
|
else
|
||||||
|
nPieces = size / pieceSize + 1
|
||||||
|
pieces = new Pieces(nPieces)
|
||||||
|
|
||||||
|
fromDownloader = new PipedInputStream()
|
||||||
|
fromUploader = new PipedInputStream()
|
||||||
|
toDownloader = new PipedOutputStream(fromUploader)
|
||||||
|
toUploader = new PipedOutputStream(fromDownloader)
|
||||||
|
endpoint = new Endpoint(null, fromUploader, toUploader, null)
|
||||||
|
|
||||||
|
session = new DownloadSession(pieces, infoHash, endpoint, target, pieceSize, size)
|
||||||
|
downloadThread = new Thread( { session.request() } as Runnable)
|
||||||
|
downloadThread.setDaemon(true)
|
||||||
|
downloadThread.start()
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void teardown() {
|
||||||
|
source?.delete()
|
||||||
|
target?.delete()
|
||||||
|
downloadThread?.interrupt()
|
||||||
|
Thread.sleep(50)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSmallFile() {
|
||||||
|
initSession(20)
|
||||||
|
assert "GET $rootBase64" == readTillRN(fromDownloader)
|
||||||
|
assert "Range: 0-19" == readTillRN(fromDownloader)
|
||||||
|
assert "" == readTillRN(fromDownloader)
|
||||||
|
|
||||||
|
toDownloader.write("200 OK\r\n".bytes)
|
||||||
|
toDownloader.write("Content-Range: 0-19\r\n\r\n".bytes)
|
||||||
|
toDownloader.write(source.bytes)
|
||||||
|
toDownloader.flush()
|
||||||
|
|
||||||
|
Thread.sleep(150)
|
||||||
|
|
||||||
|
assert pieces.isComplete()
|
||||||
|
assert target.bytes == source.bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user