Compare commits

...

8 Commits

Author     SHA1        Message                                                       Date
eyedeekay  52e5c6fec5  fmt                                                           2025-07-18 17:22:14 -04:00
eyedeekay  c0dcd59f9c  race condition                                                2025-07-18 17:06:00 -04:00
eyedeekay  4734f182b0  Fix double close synchronization bug with idempotent close    2025-07-18 16:37:32 -04:00
eyedeekay  1bee6672fe  Fix connection slot reservation race condition                2025-07-18 16:30:43 -04:00
eyedeekay  bbe23ddcab  Fix critical activeSet map race condition in Close() method   2025-07-18 16:16:48 -04:00
eyedeekay  894e2a0789  gofmt, eliminate redundant loop                               2025-07-18 12:49:30 -04:00
eyedeekay  defcb3d497  Fix resource leak: remove unused done channel                 2025-07-18 12:47:21 -04:00
eyedeekay  f30214ed29  Fix race condition (TOCTOU) in connection limit enforcer      2025-07-18 12:02:26 -04:00
6 changed files with 601 additions and 40 deletions

View File

@@ -15,7 +15,6 @@ type LimitedListener struct {
     activeConns int64
     limiter     *rate.Limiter
     mu          sync.Mutex
-    done        chan struct{}
     activeSet   map[net.Conn]struct{}
 }
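
The deleted done field (per the commit message, unused) was allocated in the constructor and closed in Close(), both removed elsewhere in this diff, but nothing ever received from it, so no goroutine could observe the signal; it also made a second Close() a panic risk, since closing an already-closed channel panics. For contrast, a hypothetical sketch of the contract a done channel is meant to fulfil (names are illustrative, not this package's API):

    package sketch

    import "sync"

    // closeSignal shows the usual done-channel contract: Close broadcasts
    // by closing the channel, and consumers observe it via <-Done().
    // The deleted field had the close side but no receive side.
    type closeSignal struct {
        once sync.Once
        done chan struct{}
    }

    func newCloseSignal() *closeSignal {
        return &closeSignal{done: make(chan struct{})}
    }

    // Close is safe to call multiple times; sync.Once prevents the
    // close-of-closed-channel panic.
    func (s *closeSignal) Close() {
        s.once.Do(func() { close(s.done) })
    }

    func (s *closeSignal) Done() <-chan struct{} { return s.done }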

View File

@@ -7,7 +7,6 @@ func NewLimitedListener(listener net.Listener, opts ...Option) *LimitedListener
     l := &LimitedListener{
         Listener:  listener,
         maxConns:  1000, // default limit
-        done:      make(chan struct{}),
         activeSet: make(map[net.Conn]struct{}),
     }
@@ -20,7 +19,6 @@ func NewLimitedListener(listener net.Listener, opts ...Option) *LimitedListener
 // Accept accepts a connection with limiting logic
 func (l *LimitedListener) Accept() (net.Conn, error) {
     for {
         // Check if rate limit is exceeded
         if l.limiter != nil {
             if !l.limiter.Allow() {
@@ -28,17 +26,23 @@ func (l *LimitedListener) Accept() (net.Conn, error) {
             }
         }
-        // Check concurrent connection limit
+        // Atomically reserve a connection slot before calling Accept()
         l.mu.Lock()
         if l.maxConns > 0 && l.activeConns >= int64(l.maxConns) {
             l.mu.Unlock()
             return nil, ErrMaxConnsReached
         }
+        // Reserve the slot by incrementing the counter
+        l.activeConns++
         l.mu.Unlock()
-        // Accept the connection
+        // Now call Accept() with the slot already reserved
         conn, err := l.Listener.Accept()
         if err != nil {
+            // Accept failed, release the reserved slot
+            l.mu.Lock()
+            l.activeConns--
+            l.mu.Unlock()
             return nil, err
         }
@@ -48,11 +52,10 @@ func (l *LimitedListener) Accept() (net.Conn, error) {
             listener: l,
         }
+        // Add to active set (connection count already incremented above)
         l.mu.Lock()
-        l.activeConns++
         l.activeSet[tracked] = struct{}{}
         l.mu.Unlock()
         return tracked, nil
     }
 }
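
Together these two hunks close a classic check-then-act (TOCTOU) window: previously the limit was checked under l.mu, the lock released, Accept() called, and only then the counter incremented, so several goroutines could pass the check simultaneously and overshoot maxConns. The fix makes the check and the increment one atomic step, then rolls the reservation back if Accept() fails. The pattern in isolation, as a minimal sketch with hypothetical names (not this package's API):

    package sketch

    import (
        "errors"
        "sync"
    )

    var errFull = errors.New("limit reached")

    // slotLimiter reserves capacity before the slow operation runs, so the
    // check and the increment happen under one mutex hold.
    type slotLimiter struct {
        mu    sync.Mutex
        used  int
        limit int
    }

    func (s *slotLimiter) acquire() error {
        s.mu.Lock()
        defer s.mu.Unlock()
        if s.used >= s.limit {
            return errFull
        }
        s.used++ // reserve: no other goroutine can now pass the check
        return nil
    }

    // release rolls back a reservation, e.g. when the accept fails.
    func (s *slotLimiter) release() {
        s.mu.Lock()
        s.used--
        s.mu.Unlock()
    }

The caller acquires first, performs the blocking call, and releases only on failure; on success the slot belongs to the connection, which gives it back in its own Close().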

listener_test.go Normal file (407 lines)
View File

@@ -0,0 +1,407 @@
package limitedlistener
import (
"fmt"
"net"
"sync"
"testing"
"time"
)
// mockListener implements net.Listener for testing
type mockListener struct {
connChan chan net.Conn
errChan chan error
closed bool
mu sync.Mutex
}
func newMockListener() *mockListener {
return &mockListener{
connChan: make(chan net.Conn, 100),
errChan: make(chan error, 10),
}
}
func (m *mockListener) Accept() (net.Conn, error) {
select {
case conn := <-m.connChan:
return conn, nil
case err := <-m.errChan:
return nil, err
}
}
func (m *mockListener) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.closed {
m.closed = true
close(m.connChan)
close(m.errChan)
}
return nil
}
func (m *mockListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}
func (m *mockListener) sendConn(conn net.Conn) {
m.connChan <- conn
}
// mockConn implements net.Conn for testing
type mockConn struct {
net.Conn
closed bool
}
func (m *mockConn) Close() error {
m.closed = true
return nil
}
func (m *mockConn) Read(b []byte) (int, error) { return 0, nil }
func (m *mockConn) Write(b []byte) (int, error) { return len(b), nil }
func (m *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080} }
func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8081} }
func (m *mockConn) SetDeadline(t time.Time) error { return nil }
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestConnectionLimitRaceCondition(t *testing.T) {
mockL := newMockListener()
defer mockL.Close()
// Create a limited listener with a small connection limit
limited := NewLimitedListener(mockL, WithMaxConnections(2))
// Pre-populate the mock with connections
for i := 0; i < 10; i++ {
mockL.sendConn(&mockConn{})
}
var wg sync.WaitGroup
results := make(chan struct {
conn net.Conn
err error
}, 10)
// Start multiple goroutines trying to accept connections concurrently
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := limited.Accept()
results <- struct {
conn net.Conn
err error
}{conn, err}
}()
}
// Wait for all goroutines to complete with a timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// All completed
case <-time.After(2 * time.Second):
t.Fatal("Test timed out - possible deadlock")
}
close(results)
// Count results
var successful, failed int
var connections []net.Conn
for result := range results {
if result.err == nil {
successful++
connections = append(connections, result.conn)
} else if result.err == ErrMaxConnsReached {
failed++
} else {
t.Errorf("Unexpected error: %v", result.err)
}
}
// Close all successful connections
for _, conn := range connections {
conn.Close()
}
// The key test: successful connections should not exceed the limit
if successful > 2 {
t.Errorf("Race condition detected: Expected at most 2 successful connections, but %d were accepted", successful)
}
t.Logf("Successful: %d, Failed: %d", successful, failed)
}
func TestConnectionLimitEnforcement(t *testing.T) {
mockL := newMockListener()
defer mockL.Close()
limited := NewLimitedListener(mockL, WithMaxConnections(1))
// Send mock connections
for i := 0; i < 3; i++ {
mockL.sendConn(&mockConn{})
}
// Accept first connection - should succeed
conn1, err := limited.Accept()
if err != nil {
t.Fatalf("First connection should be accepted: %v", err)
}
defer conn1.Close()
// Try to accept second connection - should fail with limit error
conn2, err := limited.Accept()
if err != ErrMaxConnsReached {
t.Errorf("Expected ErrMaxConnsReached, got: %v", err)
}
if conn2 != nil {
t.Error("Connection should be nil when limit is reached")
conn2.Close()
}
}
func TestActiveSetDataRace(t *testing.T) {
mockL := newMockListener()
defer mockL.Close()
limited := NewLimitedListener(mockL, WithMaxConnections(10))
// Pre-populate with exactly enough connections to avoid blocking
for i := 0; i < 15; i++ {
mockL.sendConn(&mockConn{})
}
var wg sync.WaitGroup
connections := make([]net.Conn, 0, 10)
connMutex := sync.Mutex{}
// Accept 10 connections concurrently
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := limited.Accept()
if err == nil && conn != nil {
connMutex.Lock()
connections = append(connections, conn)
connMutex.Unlock()
}
}()
}
// Wait for accepts to complete
wg.Wait()
// Now close some connections concurrently to test the race condition
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
defer wg.Done()
connMutex.Lock()
if len(connections) > 0 {
conn := connections[len(connections)-1]
connections = connections[:len(connections)-1]
connMutex.Unlock()
conn.Close()
} else {
connMutex.Unlock()
}
}()
}
// Also test listener close concurrent with connection operations
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(10 * time.Millisecond) // Brief delay to let some closes happen
limited.Close()
}()
// Wait for all operations to complete
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success - no race detected
case <-time.After(1 * time.Second):
t.Fatal("Test timed out - possible deadlock or race")
}
// Clean up any remaining connections
connMutex.Lock()
for _, conn := range connections {
conn.Close()
}
connMutex.Unlock()
}
func TestSlotReservationRaceCondition(t *testing.T) {
mockL := newMockListener()
defer mockL.Close()
// Create a listener with a very small limit
limited := NewLimitedListener(mockL, WithMaxConnections(3))
// Pre-populate with enough connections to stress test the limit
for i := 0; i < 20; i++ {
mockL.sendConn(&mockConn{})
}
var wg sync.WaitGroup
var successfulConnections []net.Conn
var connMutex sync.Mutex
// Start many goroutines trying to accept connections concurrently
numGoroutines := 15
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := limited.Accept()
if err == nil && conn != nil {
connMutex.Lock()
successfulConnections = append(successfulConnections, conn)
connMutex.Unlock()
}
}()
}
// Wait for all accepts to complete
wg.Wait()
// Verify that we didn't exceed the connection limit
connMutex.Lock()
actualCount := len(successfulConnections)
connMutex.Unlock()
// Check that activeConns matches the actual number of successful connections
stats := limited.GetStats()
if actualCount != int(stats.ActiveConnections) {
t.Errorf("Slot reservation race detected: actualCount=%d, activeConns=%d",
actualCount, stats.ActiveConnections)
}
if actualCount > 3 {
t.Errorf("Connection limit exceeded: expected max 3, got %d successful connections", actualCount)
}
// Clean up
connMutex.Lock()
for _, conn := range successfulConnections {
conn.Close()
}
connMutex.Unlock()
t.Logf("Successfully limited to %d connections (limit: 3)", actualCount)
}
// doubleCloseMockConn tracks how many times Close() is called
type doubleCloseMockConn struct {
*mockConn
closeCount int32
mu sync.Mutex
}
func (d *doubleCloseMockConn) Close() error {
d.mu.Lock()
d.closeCount++
closedTimes := d.closeCount
d.mu.Unlock()
if closedTimes > 1 {
panic(fmt.Sprintf("Connection closed %d times - double close detected!", closedTimes))
}
return d.mockConn.Close()
}
func (d *doubleCloseMockConn) getCloseCount() int32 {
d.mu.Lock()
defer d.mu.Unlock()
return d.closeCount
}
func TestDoubleClosePrevention(t *testing.T) {
mockL := newMockListener()
defer mockL.Close()
limited := NewLimitedListener(mockL, WithMaxConnections(5))
// Create special mock connections that track close calls
doubleCloseMocks := make([]*doubleCloseMockConn, 3)
for i := 0; i < 3; i++ {
doubleCloseMocks[i] = &doubleCloseMockConn{
mockConn: &mockConn{},
}
mockL.sendConn(doubleCloseMocks[i])
}
// Accept the connections
var connections []net.Conn
for i := 0; i < 3; i++ {
conn, err := limited.Accept()
if err != nil {
t.Fatalf("Failed to accept connection %d: %v", i, err)
}
connections = append(connections, conn)
}
// Close the first connection manually (user close)
err := connections[0].Close()
if err != nil {
t.Fatalf("Failed to close connection manually: %v", err)
}
// Verify it was closed once
if doubleCloseMocks[0].getCloseCount() != 1 {
t.Errorf("Expected connection 0 to be closed once, got %d times", doubleCloseMocks[0].getCloseCount())
}
// Now close the listener - this should close remaining connections but NOT double-close the first one
err = limited.Close()
if err != nil {
t.Fatalf("Failed to close listener: %v", err)
}
// Verify close counts
for i, mock := range doubleCloseMocks {
closeCount := mock.getCloseCount()
if closeCount != 1 {
t.Errorf("Connection %d should be closed exactly once, but was closed %d times", i, closeCount)
}
}
// Test idempotent close - closing already closed connections should be safe
for i, conn := range connections {
err := conn.Close()
if err != nil {
t.Errorf("Idempotent close of connection %d should not return error: %v", i, err)
}
// Close count should still be 1 (idempotent)
if doubleCloseMocks[i].getCloseCount() != 1 {
t.Errorf("After idempotent close, connection %d close count should still be 1, got %d", i, doubleCloseMocks[i].getCloseCount())
}
}
t.Log("Double close prevention test passed - all connections closed exactly once")
}
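
The race-oriented tests above (TestConnectionLimitRaceCondition, TestActiveSetDataRace, TestSlotReservationRaceCondition) only prove their point when the Go race detector is instrumenting memory accesses:

    go test -race ./...

A functional pass without -race does not rule out the races these tests target; the detector reports the conflicting access pattern even when the results happen to come out correct.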

race_condition_test.go Normal file (131 lines)
View File

@@ -0,0 +1,131 @@
package limitedlistener
import (
"net"
"sync"
"sync/atomic"
"testing"
)
// controlledMockListener allows us to control when Accept() returns
type controlledMockListener struct {
connChan chan net.Conn
errChan chan error
closed bool
mu sync.Mutex
acceptCh chan struct{} // Signal to allow Accept() to proceed
acceptWait bool // Whether Accept() should wait
}
func newControlledMockListener() *controlledMockListener {
return &controlledMockListener{
connChan: make(chan net.Conn, 100),
errChan: make(chan error, 10),
acceptCh: make(chan struct{}, 100), // Buffered channel
}
}
func (m *controlledMockListener) Accept() (net.Conn, error) {
// Wait for signal if control is enabled
if m.acceptWait {
<-m.acceptCh
}
select {
case conn := <-m.connChan:
return conn, nil
case err := <-m.errChan:
return nil, err
default:
// Return a mock connection if no specific conn/err was queued
return &mockConn{}, nil
}
}
func (m *controlledMockListener) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.closed {
m.closed = true
close(m.connChan)
close(m.errChan)
}
return nil
}
func (m *controlledMockListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}
func (m *controlledMockListener) enableAcceptControl() {
m.acceptWait = true
}
func (m *controlledMockListener) releaseAllAccepts(count int) {
for i := 0; i < count; i++ {
m.acceptCh <- struct{}{}
}
}
// TestRaceConditionReproduction creates a scenario that reliably reproduces
// the race condition where connection limits can be temporarily exceeded
func TestRaceConditionReproduction(t *testing.T) {
// Use a simple approach: stress test with many concurrent accepts
// and check if we can exceed the limit
for attempt := 0; attempt < 10; attempt++ {
mockL := newControlledMockListener()
limited := NewLimitedListener(mockL, WithMaxConnections(2))
// Pre-signal the accepts to avoid blocking
mockL.enableAcceptControl()
mockL.releaseAllAccepts(10)
var wg sync.WaitGroup
var successfulConns int64
var connections []net.Conn
var connMutex sync.Mutex
// Start multiple goroutines trying to accept connections simultaneously
numGoroutines := 5
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := limited.Accept()
if err == nil && conn != nil {
atomic.AddInt64(&successfulConns, 1)
connMutex.Lock()
connections = append(connections, conn)
connMutex.Unlock()
}
}()
}
// Wait for completion
wg.Wait()
// Clean up connections
connMutex.Lock()
for _, conn := range connections {
conn.Close()
}
connMutex.Unlock()
mockL.Close()
limited.Close()
successful := atomic.LoadInt64(&successfulConns)
// Check if we reproduced the race condition
if successful > 2 {
t.Logf("RACE CONDITION REPRODUCED on attempt %d: %d connections accepted (limit was 2)", attempt+1, successful)
t.Errorf("Connection limit exceeded due to race condition: expected at most 2, got %d", successful)
return
}
t.Logf("Attempt %d: %d connections accepted (within limit)", attempt+1, successful)
}
t.Log("Race condition was not reproduced in 10 attempts, but the theoretical issue still exists")
}
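
As the final log line concedes, this reproduction test is probabilistic: hitting the old check-then-act window depends on how the scheduler interleaves the goroutines. Repeating the run and varying parallelism raises the odds, for example:

    go test -race -run TestRaceConditionReproduction -count=100 -cpu=1,2,4

With the slot-reservation fix in place, the test's role shifts from reproducing the bug to pinning the invariant that the limit is never exceeded.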

View File

@@ -1,5 +1,7 @@
 package limitedlistener
+
+import "net"
 // Stats provides current listener statistics
 type Stats struct {
     ActiveConnections int64
@@ -27,13 +29,17 @@ func (l *LimitedListener) GetStats() Stats {
 // Close implements graceful shutdown
 func (l *LimitedListener) Close() error {
     l.mu.Lock()
-    defer l.mu.Unlock()
-    // Close all active connections
+    // Create a slice of connections to close, so we can release the lock
+    connections := make([]net.Conn, 0, len(l.activeSet))
     for conn := range l.activeSet {
-        conn.Close()
+        connections = append(connections, conn)
     }
+    l.mu.Unlock()
+    // Close all active connections outside the lock to avoid deadlock
+    for _, conn := range connections {
+        conn.Close()
+    }
-    close(l.done)
     return l.Listener.Close()
 }
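
Replacing the defer with an explicit unlock and snapshotting activeSet is what makes this safe: trackedConn.Close() (next file) re-acquires l.mu to remove the connection from the set, and sync.Mutex is not reentrant, so holding l.mu while calling conn.Close() would leave the goroutine blocked on a lock it already owns. A minimal sketch of the copy-then-operate pattern, with hypothetical names:

    package sketch

    import (
        "net"
        "sync"
    )

    // closeAll assumes each item's cleanup re-acquires the registry lock,
    // so the registry is copied and the lock dropped before any Close call.
    func closeAll(mu *sync.Mutex, registry map[net.Conn]struct{}) {
        mu.Lock()
        snapshot := make([]net.Conn, 0, len(registry))
        for c := range registry {
            snapshot = append(snapshot, c)
        }
        mu.Unlock() // release before callbacks that may re-lock

        for _, c := range snapshot {
            c.Close() // free to take mu again without deadlocking
        }
    }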

View File

@@ -10,15 +10,30 @@ type trackedConn struct {
     net.Conn
     listener *LimitedListener
     once     sync.Once
+    closed   bool
+    mu       sync.Mutex
 }
 // Close implements net.Conn Close with connection tracking
 func (c *trackedConn) Close() error {
+    c.mu.Lock()
+    if c.closed {
+        c.mu.Unlock()
+        return nil // Already closed, return without error
+    }
+    c.closed = true
+    c.mu.Unlock()
+    var closeErr error
     c.once.Do(func() {
         // Remove from active set and decrement counter atomically
         c.listener.mu.Lock()
         delete(c.listener.activeSet, c)
         c.listener.activeConns--
         c.listener.mu.Unlock()
+        // Close the underlying connection
+        closeErr = c.Conn.Close()
     })
-    return c.Conn.Close()
+    return closeErr
 }
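
The result is an idempotent Close(): the first call untracks the connection and closes the underlying net.Conn inside once.Do, and every later call short-circuits on the closed flag and returns nil. Strictly, sync.Once already guarantees the body runs exactly once, so the flag's main contribution is the explicit nil fast path for repeat callers. For comparison, a minimal variant relying on sync.Once alone; this is a sketch, not this package's trackedConn, and it replays the first Close error instead of returning nil, the other common convention:

    package sketch

    import (
        "net"
        "sync"
    )

    // idempotentConn closes the wrapped connection at most once; repeat
    // callers get the result of the first (and only) real Close.
    type idempotentConn struct {
        net.Conn
        once     sync.Once
        closeErr error
    }

    func (c *idempotentConn) Close() error {
        c.once.Do(func() {
            c.closeErr = c.Conn.Close() // executes exactly once
        })
        return c.closeErr // once.Do orders this read after the write
    }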