diff --git a/core/src/main/groovy/com/muwire/core/Constants.groovy b/core/src/main/groovy/com/muwire/core/Constants.groovy index de3193a7..0c9b3c54 100644 --- a/core/src/main/groovy/com/muwire/core/Constants.groovy +++ b/core/src/main/groovy/com/muwire/core/Constants.groovy @@ -5,4 +5,7 @@ import net.i2p.crypto.SigType class Constants { public static final byte PERSONA_VERSION = (byte)1 public static final SigType SIG_TYPE = SigType.ECDSA_SHA512_P521 // TODO: decide which + + public static final int MAX_HEADER_SIZE = 0x1 << 14 + public static final int MAX_HEADERS = 16 } diff --git a/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy new file mode 100644 index 00000000..1fc47c5a --- /dev/null +++ b/core/src/main/groovy/com/muwire/core/download/DownloadSession.groovy @@ -0,0 +1,139 @@ +package com.muwire.core.download; + +import net.i2p.data.Base64 + +import com.muwire.core.Constants +import com.muwire.core.InfoHash +import com.muwire.core.connection.Endpoint +import static com.muwire.core.util.DataUtil.readTillRN + +import groovy.util.logging.Log + +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.StandardOpenOption +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException + +@Log +class DownloadSession { + + private final Pieces pieces + private final InfoHash infoHash + private final Endpoint endpoint + private final File file + private final int pieceSize + private final long fileLength + private final MessageDigest digest + + private ByteBuffer mapped + + DownloadSession(Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file, + int pieceSize, long fileLength) { + this.pieces = pieces + this.endpoint = endpoint + this.infoHash = infoHash + this.file = file + this.pieceSize = pieceSize + this.fileLength = fileLength + try { + digest = MessageDigest.getInstance("SHA-256") + } catch (NoSuchAlgorithmException impossible) { + digest = null + System.exit(1) + } + } + + public void request() throws IOException { + OutputStream os = endpoint.getOutputStream() + InputStream is = endpoint.getInputStream() + + int piece = pieces.getRandomPiece() + long start = piece * pieceSize + long end = Math.min(fileLength, start + pieceSize) - 1 + long length = end - start + 1 + + String root = Base64.encode(infoHash.getRoot()) + + FileChannel channel + try { + os.write("GET $root\r\n".getBytes(StandardCharsets.US_ASCII)) + os.write("Range: $start-$end\r\n\r\n".getBytes(StandardCharsets.US_ASCII)) + os.flush() + String code = readTillRN(is) + if (code.startsWith("404 ")) { + log.warning("file not found") + endpoint.close() + return + } + + if (code.startsWith("416 ")) { + log.warning("range $start-$end cannot be satisfied") + return // leave endpoint open + } + + if (!code.startsWith("200 ")) { + log.warning("unknown code $code") + endpoint.close() + return + } + + // parse all headers + Set headers = new HashSet<>() + String header + while((header = readTillRN(is)) != "" && headers.size() < Constants.MAX_HEADERS) + headers.add(header) + + long receivedStart = -1 + long receivedEnd = -1 + for (String receivedHeader : headers) { + def group = (receivedHeader =~ /^Content-Range: (\d+)-(\d+)$/) + if (group.size() != 1) { + log.info("ignoring header $receivedHeader") + continue + } + + receivedStart = Long.parseLong(group[0][1]) + receivedEnd = Long.parseLong(group[0][2]) + } + + if (receivedStart != start || receivedEnd != end) { + log.warning("We don't support mismatching ranges yet") + endpoint.close() + return + } + + // start the download + channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ, StandardOpenOption.WRITE, + StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) // TODO: double-check, maybe CREATE_NEW + mapped = channel.map(FileChannel.MapMode.READ_WRITE, start, end - start + 1) + + byte[] tmp = new byte[0x1 << 13] + while(mapped.hasRemaining()) { + int read = is.read(tmp) + if (read == -1) + throw new IOException() + synchronized(this) { + mapped.put(tmp, 0, read) + } + } + + mapped.clear() + digest.update(mapped) + byte [] hash = digest.digest() + byte [] expected = new byte[32] + System.arraycopy(infoHash.getHashList(), piece * 32, expected, 0, 32) + if (hash != expected) { + log.warning("hash mismatch") + endpoint.close() + return + } + + pieces.markDownloaded(piece) + } finally { + try { channel?.close() } catch (IOException ignore) {} + } + } +} diff --git a/core/src/main/groovy/com/muwire/core/upload/Request.groovy b/core/src/main/groovy/com/muwire/core/upload/Request.groovy index 8238ba2e..d014c9d8 100644 --- a/core/src/main/groovy/com/muwire/core/upload/Request.groovy +++ b/core/src/main/groovy/com/muwire/core/upload/Request.groovy @@ -2,6 +2,7 @@ package com.muwire.core.upload import java.nio.charset.StandardCharsets +import com.muwire.core.Constants import com.muwire.core.InfoHash import groovy.util.logging.Log @@ -19,8 +20,8 @@ class Request { static Request parse(InfoHash infoHash, InputStream is) throws IOException { Map headers = new HashMap<>() - byte [] tmp = new byte[0x1 << 14] - while(true) { + byte [] tmp = new byte[Constants.MAX_HEADER_SIZE] + while(headers.size() < Constants.MAX_HEADERS) { boolean r = false boolean n = false int idx = 0 diff --git a/core/src/main/groovy/com/muwire/core/util/DataUtil.groovy b/core/src/main/groovy/com/muwire/core/util/DataUtil.groovy index e519cffc..25bec593 100644 --- a/core/src/main/groovy/com/muwire/core/util/DataUtil.groovy +++ b/core/src/main/groovy/com/muwire/core/util/DataUtil.groovy @@ -2,6 +2,8 @@ package com.muwire.core.util import java.nio.charset.StandardCharsets +import com.muwire.core.Constants + class DataUtil { private final static int MAX_SHORT = (0x1 << 16) - 1 @@ -61,4 +63,20 @@ class DataUtil { daos.close() baos.toByteArray() } + + public static String readTillRN(InputStream is) { + def baos = new ByteArrayOutputStream() + while(baos.size() < (Constants.MAX_HEADER_SIZE)) { + byte read = is.read() + if (read == -1) + throw new IOException() + if (read == '\r') { + if (is.read() != '\n') + throw new IOException("invalid header") + break + } + baos.write(read) + } + new String(baos.toByteArray(), StandardCharsets.US_ASCII) + } } diff --git a/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy b/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy new file mode 100644 index 00000000..f465ac46 --- /dev/null +++ b/core/src/test/groovy/com/muwire/core/download/DownloadSessionTest.groovy @@ -0,0 +1,89 @@ +package com.muwire.core.download + +import org.junit.After +import org.junit.Test + +import com.muwire.core.InfoHash +import com.muwire.core.connection.Endpoint +import com.muwire.core.files.FileHasher +import static com.muwire.core.util.DataUtil.readTillRN + +import net.i2p.data.Base64 + +class DownloadSessionTest { + + private File source, target + private InfoHash infoHash + private Endpoint endpoint + private Pieces pieces + private String rootBase64 + + private DownloadSession session + private Thread downloadThread + + private InputStream fromDownloader, fromUploader + private OutputStream toDownloader, toUploader + + private void initSession(int size) { + Random r = new Random() + byte [] content = new byte[size] + r.nextBytes(content) + + source = File.createTempFile("source", "tmp") + source.deleteOnExit() + def fos = new FileOutputStream(source) + fos.write(content) + fos.close() + + def hasher = new FileHasher() + infoHash = hasher.hashFile(source) + rootBase64 = Base64.encode(infoHash.getRoot()) + + target = File.createTempFile("target", "tmp") + int pieceSize = 1 << FileHasher.getPieceSize(size) + + int nPieces + if (size % pieceSize == 0) + nPieces = size / pieceSize + else + nPieces = size / pieceSize + 1 + pieces = new Pieces(nPieces) + + fromDownloader = new PipedInputStream() + fromUploader = new PipedInputStream() + toDownloader = new PipedOutputStream(fromUploader) + toUploader = new PipedOutputStream(fromDownloader) + endpoint = new Endpoint(null, fromUploader, toUploader, null) + + session = new DownloadSession(pieces, infoHash, endpoint, target, pieceSize, size) + downloadThread = new Thread( { session.request() } as Runnable) + downloadThread.setDaemon(true) + downloadThread.start() + } + + @After + public void teardown() { + source?.delete() + target?.delete() + downloadThread?.interrupt() + Thread.sleep(50) + } + + @Test + public void testSmallFile() { + initSession(20) + assert "GET $rootBase64" == readTillRN(fromDownloader) + assert "Range: 0-19" == readTillRN(fromDownloader) + assert "" == readTillRN(fromDownloader) + + toDownloader.write("200 OK\r\n".bytes) + toDownloader.write("Content-Range: 0-19\r\n\r\n".bytes) + toDownloader.write(source.bytes) + toDownloader.flush() + + Thread.sleep(150) + + assert pieces.isComplete() + assert target.bytes == source.bytes + } +}