From 6bb3657de2bf7641a7c12771b3caa4a61c94108d Mon Sep 17 00:00:00 2001
From: zzz <zzz@i2pmail.org>
Date: Sat, 26 Feb 2022 08:25:55 -0500
Subject: [PATCH] SSU2: Start of packet handling

Store CipherStates in PeerState2
Add missing getVersion() overrides
---
 .../udp/OutboundEstablishState2.java          |   2 +
 .../router/transport/udp/PacketBuilder2.java  |  23 +--
 .../router/transport/udp/PacketHandler.java   |  73 +++++++-
 .../i2p/router/transport/udp/PeerState.java   |   4 +-
 .../i2p/router/transport/udp/PeerState2.java  | 159 ++++++++++++++++--
 .../router/transport/udp/SSU2Bitfield.java    |   3 +-
 .../i2p/router/transport/udp/UDPPacket.java   |   3 +
 .../router/transport/udp/UDPTransport.java    |   2 +-
 8 files changed, 229 insertions(+), 40 deletions(-)

diff --git a/router/java/src/net/i2p/router/transport/udp/OutboundEstablishState2.java b/router/java/src/net/i2p/router/transport/udp/OutboundEstablishState2.java
index 86975bec28..3d498b6ebb 100644
--- a/router/java/src/net/i2p/router/transport/udp/OutboundEstablishState2.java
+++ b/router/java/src/net/i2p/router/transport/udp/OutboundEstablishState2.java
@@ -211,6 +211,8 @@ class OutboundEstablishState2 extends OutboundEstablishState implements SSU2Payl
     // end payload callbacks
     /////////////////////////////////////////////////////////
     
+    @Override
+    public int getVersion() { return 2; }
     public long getSendConnID() { return _sendConnID; }
     public long getRcvConnID() { return _rcvConnID; }
     public long getToken() { return _token; }
diff --git a/router/java/src/net/i2p/router/transport/udp/PacketBuilder2.java b/router/java/src/net/i2p/router/transport/udp/PacketBuilder2.java
index 7c67268b5c..4540dc322f 100644
--- a/router/java/src/net/i2p/router/transport/udp/PacketBuilder2.java
+++ b/router/java/src/net/i2p/router/transport/udp/PacketBuilder2.java
@@ -12,6 +12,7 @@ import java.util.Iterator;
 import java.util.List;
 
 import com.southernstorm.noise.protocol.ChaChaPolyCipherState;
+import com.southernstorm.noise.protocol.CipherState;
 import com.southernstorm.noise.protocol.HandshakeState;
 
 import net.i2p.crypto.ChaCha20;
@@ -317,7 +318,7 @@ class PacketBuilder2 {
         SSU2Payload.writePayload(data, SHORT_HEADER_SIZE, blocks);
         pkt.setLength(off);
 
-        encryptDataPacket(packet, peer.getSendEncryptKey(), pktNum, peer.getSendHeaderEncryptKey1(), peer.getSendHeaderEncryptKey2());
+        encryptDataPacket(packet, peer.getSendCipher(), pktNum, peer.getSendHeaderEncryptKey1(), peer.getSendHeaderEncryptKey2());
         setTo(packet, peer.getRemoteIPAddress(), peer.getRemotePort());
         
         // FIXME ticket #2675
@@ -439,7 +440,7 @@ class PacketBuilder2 {
 */
         
         pkt.setLength(off);
-        encryptDataPacket(packet, peer.getSendEncryptKey(), pktNum, peer.getSendHeaderEncryptKey1(), peer.getSendHeaderEncryptKey2());
+        encryptDataPacket(packet, peer.getSendCipher(), pktNum, peer.getSendHeaderEncryptKey1(), peer.getSendHeaderEncryptKey2());
         setTo(packet, peer.getRemoteIPAddress(), peer.getRemotePort());
         packet.setMessageType(TYPE_ACK);
         packet.setPriority((fullACKCount > 0 || partialACKCount > 0) ? PRIORITY_HIGH : PRIORITY_LOW);
@@ -1131,20 +1132,20 @@ class PacketBuilder2 {
      *                length set to the end of the data.
      *                This will extend the length by 16 for the MAC.
      */
-    private void encryptDataPacket(UDPPacket packet, byte[] chachaKey, long n,
+    private void encryptDataPacket(UDPPacket packet, CipherState chacha, long n,
                                     byte[] hdrKey1, byte[] hdrKey2) {
         DatagramPacket pkt = packet.getPacket();
         byte data[] = pkt.getData();
         int off = pkt.getOffset();
         int len = pkt.getLength();
-        ChaChaPolyCipherState chacha = new ChaChaPolyCipherState();
-        chacha.initializeKey(chachaKey, 0);
-        chacha.setNonce(n);
-        try {
-            chacha.encryptWithAd(data, off, SHORT_HEADER_SIZE,
-                                 data, off + SHORT_HEADER_SIZE, data, off + SHORT_HEADER_SIZE, len - SHORT_HEADER_SIZE);
-        } catch (GeneralSecurityException e) {
-            throw new IllegalArgumentException("Bad data msg", e);
+        synchronized(chacha) {
+            chacha.setNonce(n);
+            try {
+                chacha.encryptWithAd(data, off, SHORT_HEADER_SIZE,
+                                     data, off + SHORT_HEADER_SIZE, data, off + SHORT_HEADER_SIZE, len - SHORT_HEADER_SIZE);
+            } catch (GeneralSecurityException e) {
+                throw new IllegalArgumentException("Bad data msg", e);
+            }
         }
         pkt.setLength(len + MAC_LEN);
         SSU2Header.encryptShortHeader(packet, hdrKey1, hdrKey2);
diff --git a/router/java/src/net/i2p/router/transport/udp/PacketHandler.java b/router/java/src/net/i2p/router/transport/udp/PacketHandler.java
index abcbb13082..cc53982d14 100644
--- a/router/java/src/net/i2p/router/transport/udp/PacketHandler.java
+++ b/router/java/src/net/i2p/router/transport/udp/PacketHandler.java
@@ -40,6 +40,7 @@ class PacketHandler {
     private final Map<RemoteHostId, Object> _failCache;
     private final BlockingQueue<UDPPacket> _inboundQueue;
     private static final Object DUMMY = new Object();
+    private final boolean _enableSSU2;
     
     private static final int TYPE_POISON = -99999;
     private static final int MIN_QUEUE_SIZE = 16;
@@ -56,11 +57,12 @@ class PacketHandler {
     
     private enum AuthType { NONE, INTRO, BOBINTRO, SESSION }
 
-    PacketHandler(RouterContext ctx, UDPTransport transport, EstablishmentManager establisher,
+    PacketHandler(RouterContext ctx, UDPTransport transport, boolean enableSSU2, EstablishmentManager establisher,
                   InboundMessageFragments inbound, PeerTestManager testManager, IntroductionManager introManager) {
         _context = ctx;
         _log = ctx.logManager().getLog(PacketHandler.class);
         _transport = transport;
+        _enableSSU2 = enableSSU2;
         _establisher = establisher;
         _inbound = inbound;
         _testManager = testManager;
@@ -222,8 +224,8 @@ class PacketHandler {
          * Classify the packet by source IP/port, into 4 groups:
          *<ol>
          *<li>Established session
-         *<li>Pending inbound establishement
-         *<li>Pending outbound establishement
+         *<li>Pending inbound establishment
+         *<li>Pending outbound establishment
          *<li>No established or pending session found
          *</ol>
          */
@@ -238,7 +240,10 @@ class PacketHandler {
                     // Group 2: Inbound Establishment
                     if (_log.shouldLog(Log.DEBUG))
                         _log.debug("Packet received IS for an inbound establishment");
-                    receivePacket(reader, packet, est);
+                    if (est.getVersion() == 2)
+                        receiveSSU2Packet(packet, (InboundEstablishState2) est);
+                    else
+                        receivePacket(reader, packet, est);
                 } else {
                     //if (_log.shouldLog(Log.DEBUG))
                     //    _log.debug("Packet received is not for an inbound establishment");
@@ -247,7 +252,10 @@ class PacketHandler {
                         // Group 3: Outbound Establishment
                         if (_log.shouldLog(Log.DEBUG))
                             _log.debug("Packet received IS for an outbound establishment");
-                        receivePacket(reader, packet, oest);
+                        if (oest.getVersion() == 2)
+                            receiveSSU2Packet(packet, (OutboundEstablishState2) oest);
+                        else
+                            receivePacket(reader, packet, oest);
                     } else {
                         // Group 4: New conn or needs fallback
                         if (_log.shouldLog(Log.DEBUG))
@@ -259,9 +267,12 @@ class PacketHandler {
                 }
             } else {
                 // Group 1: Established
-                if (_log.shouldLog(Log.DEBUG))
-                    _log.debug("Packet received IS for an existing peer");
-                receivePacket(reader, packet, state);
+                //if (_log.shouldLog(Log.DEBUG))
+                //    _log.debug("Packet received IS for an existing peer");
+                if (state.getVersion() == 2)
+                    receiveSSU2Packet(packet, (PeerState2) state);
+                else
+                    receivePacket(reader, packet, state);
             }
         }
 
@@ -520,6 +531,8 @@ class PacketHandler {
          * and send it to one of four places: The EstablishmentManager, IntroductionManager,
          * PeerTestManager, or InboundMessageFragments.
          *
+         * SSU1 only.
+         *
          * @param state non-null if fully established
          * @param outState non-null if outbound establishing in process
          * @param inState unused always null, TODO use for 48-byte destroys during inbound establishment
@@ -743,6 +756,50 @@ class PacketHandler {
         }
     }
 
+    //// Begin SSU2 Handling ////
+
+    /**
+     *  Hand off to the state for processing.
+     *  Packet is decrypted in-place, no fallback
+     *  processing is possible.
+     *
+     *  @param state must be version 2
+     *  @since 0.9.54
+     */
+    private void receiveSSU2Packet(UDPPacket packet, PeerState2 state) {
+        state.receivePacket(packet);
+    }
+
+    /**
+     *  Hand off to the state for processing.
+     *  Packet is decrypted in-place, no fallback
+     *  processing is possible.
+     *
+     *  @param state must be version 2
+     *  @since 0.9.54
+     */
+    private void receiveSSU2Packet(UDPPacket packet, InboundEstablishState2 state) {
+
+
+    }
+
+    /**
+     *  Hand off to the state for processing.
+     *  Packet is decrypted in-place, no fallback
+     *  processing is possible.
+     *
+     *  @param state must be version 2
+     *  @since 0.9.54
+     */
+    private void receiveSSU2Packet(UDPPacket packet, OutboundEstablishState2 state) {
+
+
+    }
+
+
+    //// End SSU2 Handling ////
+
+
     /**
      *  Mark a string for extraction by xgettext and translation.
      *  Use this only in static initializers.
diff --git a/router/java/src/net/i2p/router/transport/udp/PeerState.java b/router/java/src/net/i2p/router/transport/udp/PeerState.java
index 08f1d4a735..591a034491 100644
--- a/router/java/src/net/i2p/router/transport/udp/PeerState.java
+++ b/router/java/src/net/i2p/router/transport/udp/PeerState.java
@@ -38,8 +38,8 @@ import net.i2p.util.SimpleTimer2;
  *
  */
 public class PeerState {
-    private final RouterContext _context;
-    private final Log _log;
+    protected final RouterContext _context;
+    protected final Log _log;
     /**
      * The peer are we talking to.  This should be set as soon as this
      * state is created if we are initiating a connection, but if we are
diff --git a/router/java/src/net/i2p/router/transport/udp/PeerState2.java b/router/java/src/net/i2p/router/transport/udp/PeerState2.java
index df31c12f8c..2fff406f1b 100644
--- a/router/java/src/net/i2p/router/transport/udp/PeerState2.java
+++ b/router/java/src/net/i2p/router/transport/udp/PeerState2.java
@@ -1,14 +1,22 @@
 package net.i2p.router.transport.udp;
 
+import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
+import java.security.GeneralSecurityException;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import com.southernstorm.noise.protocol.CipherState;
+
+import net.i2p.data.DataFormatException;
 import net.i2p.data.DataHelper;
 import net.i2p.data.Hash;
+import net.i2p.data.router.RouterInfo;
 import net.i2p.data.SessionKey;
+import net.i2p.data.i2np.I2NPMessage;
 import net.i2p.router.RouterContext;
+import static net.i2p.router.transport.udp.SSU2Util.*;
+import net.i2p.util.HexDump;
 import net.i2p.util.Log;
-import net.i2p.util.SimpleTimer2;
 
 /**
  * Contain all of the state about a UDP connection to a peer.
@@ -20,17 +28,18 @@ import net.i2p.util.SimpleTimer2;
  *
  * @since 0.9.54
  */
-public class PeerState2 extends PeerState {
+public class PeerState2 extends PeerState implements SSU2Payload.PayloadCallback {
     private final long _sendConnID;
     private final long _rcvConnID;
     private final AtomicInteger _packetNumber = new AtomicInteger();
-    private final byte[] _sendEncryptKey;
-    private final byte[] _rcvEncryptKey;
+    private final CipherState _sendCha;
+    private final CipherState _rcvCha;
     private final byte[] _sendHeaderEncryptKey1;
     private final byte[] _rcvHeaderEncryptKey1;
     private final byte[] _sendHeaderEncryptKey2;
     private final byte[] _rcvHeaderEncryptKey2;
     private final SSU2Bitfield _receivedMessages;
+    private final SSU2Bitfield _sentMessages;
 
     public static final int MIN_MTU = 1280;
 
@@ -39,30 +48,146 @@ public class PeerState2 extends PeerState {
      */
     public PeerState2(RouterContext ctx, UDPTransport transport,
                      InetSocketAddress remoteAddress, Hash remotePeer, boolean isInbound, int rtt,
-                     byte[] sendKey, byte[] rcvKey, long sendID, long rcvID,
+                     CipherState sendCha, CipherState rcvCha, long sendID, long rcvID,
                      byte[] sendHdrKey1, byte[] sendHdrKey2, byte[] rcvHdrKey2) {
         super(ctx, transport, remoteAddress, remotePeer, isInbound, rtt);
         _sendConnID = sendID;
         _rcvConnID = rcvID;
-        _sendEncryptKey = sendKey;
-        _rcvEncryptKey = rcvKey;
+        _sendCha = sendCha;
+        _rcvCha = rcvCha;
         _sendHeaderEncryptKey1 = sendHdrKey1;
         _rcvHeaderEncryptKey1 = transport.getSSU2StaticIntroKey();
         _sendHeaderEncryptKey2 = sendHdrKey2;
         _rcvHeaderEncryptKey2 = rcvHdrKey2;
         _receivedMessages = new SSU2Bitfield(256, 0);
+        _sentMessages = new SSU2Bitfield(256, 0);
     }
 
-    // SSU2
+    @Override
+    public int getVersion() { return 2; }
     long getNextPacketNumber() { return _packetNumber.incrementAndGet(); }
-    public long getSendConnID() { return _sendConnID; }
-    public long getRcvConnID() { return _rcvConnID; }
-    public byte[] getSendEncryptKey() { return _sendEncryptKey; }
-    public byte[] getRcvEncryptKey() { return _rcvEncryptKey; }
-    public byte[] getSendHeaderEncryptKey1() { return _sendHeaderEncryptKey1; }
-    public byte[] getRcvHeaderEncryptKey1() { return _rcvHeaderEncryptKey1; }
-    public byte[] getSendHeaderEncryptKey2() { return _sendHeaderEncryptKey2; }
-    public byte[] getRcvHeaderEncryptKey2() { return _rcvHeaderEncryptKey2; }
-    public SSU2Bitfield getReceivedMessages() { return _receivedMessages; }
+    long getSendConnID() { return _sendConnID; }
+    long getRcvConnID() { return _rcvConnID; }
+    /** caller must sync on returned object when encrypting */
+    CipherState getSendCipher() { return _sendCha; }
+    byte[] getSendHeaderEncryptKey1() { return _sendHeaderEncryptKey1; }
+    byte[] getRcvHeaderEncryptKey1() { return _rcvHeaderEncryptKey1; }
+    byte[] getSendHeaderEncryptKey2() { return _sendHeaderEncryptKey2; }
+    byte[] getRcvHeaderEncryptKey2() { return _rcvHeaderEncryptKey2; }
+    SSU2Bitfield getReceivedMessages() { return _receivedMessages; }
+    SSU2Bitfield getSentMessages() { return _sentMessages; }
+
+    void receivePacket(UDPPacket packet) {
+        DatagramPacket dpacket = packet.getPacket();
+        byte[] data = dpacket.getData();
+        int off = dpacket.getOffset();
+        int len = dpacket.getLength();
+        try {
+            if (len < MIN_DATA_LEN) {
+                if (_log.shouldWarn())
+                    _log.warn("Inbound packet too short " + len + " on " + this);
+                return;
+            }
+            SSU2Header.Header header = SSU2Header.trialDecryptShortHeader(packet, _rcvHeaderEncryptKey1, _rcvHeaderEncryptKey2);
+            if (header == null) {
+                if (_log.shouldWarn())
+                    _log.warn("bad data header on " + this);
+                return;
+            }
+            if (header.getDestConnID() != _rcvConnID) {
+                if (_log.shouldWarn())
+                    _log.warn("bad Dest Conn id " + header.getDestConnID() + " on " + this);
+                return;
+            }
+            if (header.getType() != DATA_FLAG_BYTE) {
+                if (_log.shouldWarn())
+                    _log.warn("bad data pkt type " + (header.getType() & 0xff) + " on " + this);
+                return;
+            }
+            long n = header.getPacketNumber();
+            SSU2Header.acceptTrialDecrypt(packet, header);
+            synchronized (_rcvCha) {
+                _rcvCha.setNonce(n);
+                // decrypt in-place
+                _rcvCha.decryptWithAd(header.data, data, off + SHORT_HEADER_SIZE, data, off + SHORT_HEADER_SIZE, len - SHORT_HEADER_SIZE);
+                if (_receivedMessages.set(n)) {
+                    if (_log.shouldWarn())
+                        _log.warn("dup pkt rcvd " + n + " on " + this);
+                    return;
+                }
+            }
+            processPayload(data, off + SHORT_HEADER_SIZE, len - (SHORT_HEADER_SIZE + MAC_LEN));
+        } catch (GeneralSecurityException gse) {
+            if (_log.shouldWarn())
+                _log.warn("Bad encrypted packet:\n" + HexDump.dump(data, off, len), gse);
+        } catch (IndexOutOfBoundsException ioobe) {
+            if (_log.shouldWarn())
+                _log.warn("Bad encrypted packet:\n" + HexDump.dump(data, off, len), ioobe);
+        } finally {
+            packet.release();
+        }
+    }
+
+    private void processPayload(byte[] payload, int offset, int length) throws GeneralSecurityException {
+        try {
+            int blocks = SSU2Payload.processPayload(_context, this, payload, offset, length, false);
+        } catch (Exception e) {
+            throw new GeneralSecurityException("Session Created payload error", e);
+        }
+    }
+
+    /////////////////////////////////////////////////////////
+    // begin payload callbacks
+    /////////////////////////////////////////////////////////
+
+    public void gotDateTime(long time) {
+    }
+
+    public void gotOptions(byte[] options, boolean isHandshake) {
+    }
+
+    public void gotRI(RouterInfo ri, boolean isHandshake, boolean flood) throws DataFormatException {
+    }
+
+    public void gotRIFragment(byte[] data, boolean isHandshake, boolean flood, boolean isGzipped, int frag, int totalFrags) {
+        throw new IllegalStateException("RI fragment in Data phase");
+    }
+
+    public void gotAddress(byte[] ip, int port) {
+    }
+
+    public void gotIntroKey(byte[] key) {
+    }
+
+    public void gotRelayTagRequest() {
+    }
+
+    public void gotRelayTag(long tag) {
+    }
+
+    public void gotToken(long token, long expires) {
+    }
+
+    public void gotI2NP(I2NPMessage msg) {
+    }
+
+    public void gotFragment(byte[] data, long messageID, int type, long expires, int frag, boolean isLast) throws DataFormatException {
+    }
+
+    public void gotACK(long ackThru, int acks, byte[] ranges) {
+    }
+
+    public void gotTermination(int reason, long count) {
+    }
+
+    public void gotUnknown(int type, int len) {
+    }
+
+    public void gotPadding(int paddingLength, int frameLength) {
+    }
+
+    /////////////////////////////////////////////////////////
+    // end payload callbacks
+    /////////////////////////////////////////////////////////
 
 }
diff --git a/router/java/src/net/i2p/router/transport/udp/SSU2Bitfield.java b/router/java/src/net/i2p/router/transport/udp/SSU2Bitfield.java
index 62e2674a82..395d7463da 100644
--- a/router/java/src/net/i2p/router/transport/udp/SSU2Bitfield.java
+++ b/router/java/src/net/i2p/router/transport/udp/SSU2Bitfield.java
@@ -59,9 +59,10 @@ public class SSU2Bitfield {
      * the offset shifts up and the lowest set bits are lost.
      *
      * @throws IndexOutOfBoundsException if bit is smaller then zero
+     *                                   OR if the shift is too big
      * @return previous value, true if previously set or unknown
      */
-    public boolean set(long bit) {
+    public boolean set(long bit) throws IndexOutOfBoundsException {
         if (bit < 0)
             throw new IndexOutOfBoundsException(Long.toString(bit));
         boolean rv;
diff --git a/router/java/src/net/i2p/router/transport/udp/UDPPacket.java b/router/java/src/net/i2p/router/transport/udp/UDPPacket.java
index 3778bafeff..d48d7549bd 100644
--- a/router/java/src/net/i2p/router/transport/udp/UDPPacket.java
+++ b/router/java/src/net/i2p/router/transport/udp/UDPPacket.java
@@ -316,6 +316,9 @@ class UDPPacket implements CDPQEntry {
      * Decrypt this valid packet, overwriting the _data buffer's payload
      * with the decrypted data (leaving the MAC and IV unaltered)
      * 
+     * SSU 1 only.
+     * SSU 2 decryption is in PacketHandler.
+     * 
      */
     public synchronized void decrypt(SessionKey cipherKey) {
         verifyNotReleased(); 
diff --git a/router/java/src/net/i2p/router/transport/udp/UDPTransport.java b/router/java/src/net/i2p/router/transport/udp/UDPTransport.java
index e8b4469ec5..c6cdf18522 100644
--- a/router/java/src/net/i2p/router/transport/udp/UDPTransport.java
+++ b/router/java/src/net/i2p/router/transport/udp/UDPTransport.java
@@ -655,7 +655,7 @@ public class UDPTransport extends TransportImpl implements TimedWeightedPriority
             _establisher = new EstablishmentManager(_context, this);
         
         if (_handler == null)
-            _handler = new PacketHandler(_context, this, _establisher, _inboundFragments, _testManager, _introManager);
+            _handler = new PacketHandler(_context, this, _enableSSU2, _establisher, _inboundFragments, _testManager, _introManager);
         
         // See comments in DummyThrottle.java
         if (USE_PRIORITY && _refiller == null)
-- 
GitLab