Mirror of https://github.com/go-i2p/go-limit.git, synced 2025-08-19 11:45:29 -04:00

Compare commits: 210616857c...master (8 commits)
Commits (SHA1):

- 52e5c6fec5
- c0dcd59f9c
- 4734f182b0
- 1bee6672fe
- bbe23ddcab
- 894e2a0789
- defcb3d497
- f30214ed29
@@ -15,7 +15,6 @@ type LimitedListener struct {
 	activeConns int64
 	limiter     *rate.Limiter
 	mu          sync.Mutex
 	done        chan struct{}
 	activeSet   map[net.Conn]struct{}
 }
listener.go (71 changed lines)
@@ -7,7 +7,6 @@ func NewLimitedListener(listener net.Listener, opts ...Option) *LimitedListener
 	l := &LimitedListener{
 		Listener:  listener,
 		maxConns:  1000, // default limit
 		done:      make(chan struct{}),
 		activeSet: make(map[net.Conn]struct{}),
 	}

@@ -20,39 +19,43 @@ func NewLimitedListener(listener net.Listener, opts ...Option) *LimitedListener
 // Accept accepts a connection with limiting logic
 func (l *LimitedListener) Accept() (net.Conn, error) {
-	for {
-		// Check if rate limit is exceeded
-		if l.limiter != nil {
-			if !l.limiter.Allow() {
-				return nil, ErrRateLimitExceeded
-			}
-		}
+	// Check if rate limit is exceeded
+	if l.limiter != nil {
+		if !l.limiter.Allow() {
+			return nil, ErrRateLimitExceeded
+		}
+	}
 
-		// Check concurrent connection limit
-		l.mu.Lock()
-		if l.maxConns > 0 && l.activeConns >= int64(l.maxConns) {
-			l.mu.Unlock()
-			return nil, ErrMaxConnsReached
-		}
-		l.mu.Unlock()
+	// Atomically reserve a connection slot before calling Accept()
+	l.mu.Lock()
+	if l.maxConns > 0 && l.activeConns >= int64(l.maxConns) {
+		l.mu.Unlock()
+		return nil, ErrMaxConnsReached
+	}
+	// Reserve the slot by incrementing the counter
+	l.activeConns++
+	l.mu.Unlock()
 
-		// Accept the connection
-		conn, err := l.Listener.Accept()
-		if err != nil {
-			return nil, err
-		}
+	// Now call Accept() with the slot already reserved
+	conn, err := l.Listener.Accept()
+	if err != nil {
+		// Accept failed, release the reserved slot
+		l.mu.Lock()
+		l.activeConns--
+		l.mu.Unlock()
+		return nil, err
+	}
 
-		// Wrap the connection for tracking
-		tracked := &trackedConn{
-			Conn:     conn,
-			listener: l,
-		}
+	// Wrap the connection for tracking
+	tracked := &trackedConn{
+		Conn:     conn,
+		listener: l,
+	}
 
-		l.mu.Lock()
-		l.activeConns++
-		l.activeSet[tracked] = struct{}{}
-		l.mu.Unlock()
+	// Add to active set (connection count already incremented above)
+	l.mu.Lock()
+	l.activeSet[tracked] = struct{}{}
+	l.mu.Unlock()
 
-		return tracked, nil
-	}
+	return tracked, nil
 }
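For orientation alongside the patched Accept path, a minimal usage sketch. Only NewLimitedListener, WithMaxConnections, Accept, Close, and the two sentinel errors come from the repository itself; the import path is assumed from the repository URL, and the back-off strategy on a full listener is illustrative, not part of the library:

package main

import (
    "log"
    "net"
    "time"

    // Import path assumed from the repo URL; package name taken from the files above.
    limitedlistener "github.com/go-i2p/go-limit"
)

func main() {
    inner, err := net.Listen("tcp", "127.0.0.1:8080")
    if err != nil {
        log.Fatal(err)
    }

    // Wrap the raw listener; WithMaxConnections is the option used by the tests below.
    l := limitedlistener.NewLimitedListener(inner, limitedlistener.WithMaxConnections(100))
    defer l.Close()

    for {
        conn, err := l.Accept()
        switch err {
        case nil:
            go func(c net.Conn) {
                defer c.Close() // trackedConn.Close is idempotent, so a later double close is harmless
                // ... handle the connection ...
            }(conn)
        case limitedlistener.ErrMaxConnsReached, limitedlistener.ErrRateLimitExceeded:
            // Accept fails fast when at capacity; back off briefly instead of spinning.
            time.Sleep(50 * time.Millisecond)
        default:
            log.Printf("accept failed: %v", err)
            return
        }
    }
}

Note that with the slot-reservation change, Accept returns ErrMaxConnsReached immediately rather than blocking, so a caller that wants to keep serving must retry, as the loop above does.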
listener_test.go (new file, 407 lines)
@@ -0,0 +1,407 @@
package limitedlistener

import (
    "fmt"
    "net"
    "sync"
    "testing"
    "time"
)
// mockListener implements net.Listener for testing
type mockListener struct {
    connChan chan net.Conn
    errChan  chan error
    closed   bool
    mu       sync.Mutex
}

func newMockListener() *mockListener {
    return &mockListener{
        connChan: make(chan net.Conn, 100),
        errChan:  make(chan error, 10),
    }
}

func (m *mockListener) Accept() (net.Conn, error) {
    select {
    case conn := <-m.connChan:
        return conn, nil
    case err := <-m.errChan:
        return nil, err
    }
}

func (m *mockListener) Close() error {
    m.mu.Lock()
    defer m.mu.Unlock()
    if !m.closed {
        m.closed = true
        close(m.connChan)
        close(m.errChan)
    }
    return nil
}

func (m *mockListener) Addr() net.Addr {
    return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}

func (m *mockListener) sendConn(conn net.Conn) {
    m.connChan <- conn
}
// mockConn implements net.Conn for testing
type mockConn struct {
    net.Conn
    closed bool
}

func (m *mockConn) Close() error {
    m.closed = true
    return nil
}

func (m *mockConn) Read(b []byte) (int, error)  { return 0, nil }
func (m *mockConn) Write(b []byte) (int, error) { return len(b), nil }
func (m *mockConn) LocalAddr() net.Addr {
    return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}
func (m *mockConn) RemoteAddr() net.Addr {
    return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8081}
}
func (m *mockConn) SetDeadline(t time.Time) error      { return nil }
func (m *mockConn) SetReadDeadline(t time.Time) error  { return nil }
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestConnectionLimitRaceCondition(t *testing.T) {
    mockL := newMockListener()
    defer mockL.Close()

    // Create a limited listener with a small connection limit
    limited := NewLimitedListener(mockL, WithMaxConnections(2))

    // Pre-populate the mock with connections
    for i := 0; i < 10; i++ {
        mockL.sendConn(&mockConn{})
    }

    var wg sync.WaitGroup
    results := make(chan struct {
        conn net.Conn
        err  error
    }, 10)

    // Start multiple goroutines trying to accept connections concurrently
    numGoroutines := 10
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            conn, err := limited.Accept()
            results <- struct {
                conn net.Conn
                err  error
            }{conn, err}
        }()
    }

    // Wait for all goroutines to complete with a timeout
    done := make(chan struct{})
    go func() {
        wg.Wait()
        close(done)
    }()

    select {
    case <-done:
        // All completed
    case <-time.After(2 * time.Second):
        t.Fatal("Test timed out - possible deadlock")
    }
    close(results)

    // Count results
    var successful, failed int
    var connections []net.Conn

    for result := range results {
        if result.err == nil {
            successful++
            connections = append(connections, result.conn)
        } else if result.err == ErrMaxConnsReached {
            failed++
        } else {
            t.Errorf("Unexpected error: %v", result.err)
        }
    }

    // Close all successful connections
    for _, conn := range connections {
        conn.Close()
    }

    // The key test: successful connections should not exceed the limit
    if successful > 2 {
        t.Errorf("Race condition detected: Expected at most 2 successful connections, but %d were accepted", successful)
    }

    t.Logf("Successful: %d, Failed: %d", successful, failed)
}
func TestConnectionLimitEnforcement(t *testing.T) {
    mockL := newMockListener()
    defer mockL.Close()

    limited := NewLimitedListener(mockL, WithMaxConnections(1))

    // Send mock connections
    for i := 0; i < 3; i++ {
        mockL.sendConn(&mockConn{})
    }

    // Accept first connection - should succeed
    conn1, err := limited.Accept()
    if err != nil {
        t.Fatalf("First connection should be accepted: %v", err)
    }
    defer conn1.Close()

    // Try to accept second connection - should fail with limit error
    conn2, err := limited.Accept()
    if err != ErrMaxConnsReached {
        t.Errorf("Expected ErrMaxConnsReached, got: %v", err)
    }
    if conn2 != nil {
        t.Error("Connection should be nil when limit is reached")
        conn2.Close()
    }
}
func TestActiveSetDataRace(t *testing.T) {
    mockL := newMockListener()
    defer mockL.Close()

    limited := NewLimitedListener(mockL, WithMaxConnections(10))

    // Pre-populate with exactly enough connections to avoid blocking
    for i := 0; i < 15; i++ {
        mockL.sendConn(&mockConn{})
    }

    var wg sync.WaitGroup
    connections := make([]net.Conn, 0, 10)
    connMutex := sync.Mutex{}

    // Accept 10 connections concurrently
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            conn, err := limited.Accept()
            if err == nil && conn != nil {
                connMutex.Lock()
                connections = append(connections, conn)
                connMutex.Unlock()
            }
        }()
    }

    // Wait for accepts to complete
    wg.Wait()

    // Now close some connections concurrently to test the race condition
    wg.Add(5)
    for i := 0; i < 5; i++ {
        go func() {
            defer wg.Done()
            connMutex.Lock()
            if len(connections) > 0 {
                conn := connections[len(connections)-1]
                connections = connections[:len(connections)-1]
                connMutex.Unlock()
                conn.Close()
            } else {
                connMutex.Unlock()
            }
        }()
    }

    // Also test listener close concurrent with connection operations
    wg.Add(1)
    go func() {
        defer wg.Done()
        time.Sleep(10 * time.Millisecond) // Brief delay to let some closes happen
        limited.Close()
    }()

    // Wait for all operations to complete
    done := make(chan struct{})
    go func() {
        wg.Wait()
        close(done)
    }()

    select {
    case <-done:
        // Success - no race detected
    case <-time.After(1 * time.Second):
        t.Fatal("Test timed out - possible deadlock or race")
    }

    // Clean up any remaining connections
    connMutex.Lock()
    for _, conn := range connections {
        conn.Close()
    }
    connMutex.Unlock()
}
func TestSlotReservationRaceCondition(t *testing.T) {
    mockL := newMockListener()
    defer mockL.Close()

    // Create a listener with a very small limit
    limited := NewLimitedListener(mockL, WithMaxConnections(3))

    // Pre-populate with enough connections to stress test the limit
    for i := 0; i < 20; i++ {
        mockL.sendConn(&mockConn{})
    }

    var wg sync.WaitGroup
    var successfulConnections []net.Conn
    var connMutex sync.Mutex

    // Start many goroutines trying to accept connections concurrently
    numGoroutines := 15
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            conn, err := limited.Accept()
            if err == nil && conn != nil {
                connMutex.Lock()
                successfulConnections = append(successfulConnections, conn)
                connMutex.Unlock()
            }
        }()
    }

    // Wait for all accepts to complete
    wg.Wait()

    // Verify that we didn't exceed the connection limit
    connMutex.Lock()
    actualCount := len(successfulConnections)
    connMutex.Unlock()

    // Check that activeConns matches the actual number of successful connections
    stats := limited.GetStats()

    if actualCount != int(stats.ActiveConnections) {
        t.Errorf("Slot reservation race detected: actualCount=%d, activeConns=%d",
            actualCount, stats.ActiveConnections)
    }

    if actualCount > 3 {
        t.Errorf("Connection limit exceeded: expected max 3, got %d successful connections", actualCount)
    }

    // Clean up
    connMutex.Lock()
    for _, conn := range successfulConnections {
        conn.Close()
    }
    connMutex.Unlock()

    t.Logf("Successfully limited to %d connections (limit: 3)", actualCount)
}
// doubleCloseMockConn tracks how many times Close() is called
type doubleCloseMockConn struct {
    *mockConn
    closeCount int32
    mu         sync.Mutex
}

func (d *doubleCloseMockConn) Close() error {
    d.mu.Lock()
    d.closeCount++
    closedTimes := d.closeCount
    d.mu.Unlock()

    if closedTimes > 1 {
        panic(fmt.Sprintf("Connection closed %d times - double close detected!", closedTimes))
    }

    return d.mockConn.Close()
}

func (d *doubleCloseMockConn) getCloseCount() int32 {
    d.mu.Lock()
    defer d.mu.Unlock()
    return d.closeCount
}
func TestDoubleClosePrevention(t *testing.T) {
    mockL := newMockListener()
    defer mockL.Close()

    limited := NewLimitedListener(mockL, WithMaxConnections(5))

    // Create special mock connections that track close calls
    doubleCloseMocks := make([]*doubleCloseMockConn, 3)
    for i := 0; i < 3; i++ {
        doubleCloseMocks[i] = &doubleCloseMockConn{
            mockConn: &mockConn{},
        }
        mockL.sendConn(doubleCloseMocks[i])
    }

    // Accept the connections
    var connections []net.Conn
    for i := 0; i < 3; i++ {
        conn, err := limited.Accept()
        if err != nil {
            t.Fatalf("Failed to accept connection %d: %v", i, err)
        }
        connections = append(connections, conn)
    }

    // Close the first connection manually (user close)
    err := connections[0].Close()
    if err != nil {
        t.Fatalf("Failed to close connection manually: %v", err)
    }

    // Verify it was closed once
    if doubleCloseMocks[0].getCloseCount() != 1 {
        t.Errorf("Expected connection 0 to be closed once, got %d times", doubleCloseMocks[0].getCloseCount())
    }

    // Now close the listener - this should close remaining connections but NOT double-close the first one
    err = limited.Close()
    if err != nil {
        t.Fatalf("Failed to close listener: %v", err)
    }

    // Verify close counts
    for i, mock := range doubleCloseMocks {
        closeCount := mock.getCloseCount()
        if closeCount != 1 {
            t.Errorf("Connection %d should be closed exactly once, but was closed %d times", i, closeCount)
        }
    }

    // Test idempotent close - closing already closed connections should be safe
    for i, conn := range connections {
        err := conn.Close()
        if err != nil {
            t.Errorf("Idempotent close of connection %d should not return error: %v", i, err)
        }

        // Close count should still be 1 (idempotent)
        if doubleCloseMocks[i].getCloseCount() != 1 {
            t.Errorf("After idempotent close, connection %d close count should still be 1, got %d", i, doubleCloseMocks[i].getCloseCount())
        }
    }

    t.Log("Double close prevention test passed - all connections closed exactly once")
}
race_condition_test.go (new file, 131 lines)
@@ -0,0 +1,131 @@
package limitedlistener

import (
    "net"
    "sync"
    "sync/atomic"
    "testing"
)
// controlledMockListener allows us to control when Accept() returns
type controlledMockListener struct {
    connChan   chan net.Conn
    errChan    chan error
    closed     bool
    mu         sync.Mutex
    acceptCh   chan struct{} // Signal to allow Accept() to proceed
    acceptWait bool          // Whether Accept() should wait
}

func newControlledMockListener() *controlledMockListener {
    return &controlledMockListener{
        connChan: make(chan net.Conn, 100),
        errChan:  make(chan error, 10),
        acceptCh: make(chan struct{}, 100), // Buffered channel
    }
}

func (m *controlledMockListener) Accept() (net.Conn, error) {
    // Wait for signal if control is enabled
    if m.acceptWait {
        <-m.acceptCh
    }

    select {
    case conn := <-m.connChan:
        return conn, nil
    case err := <-m.errChan:
        return nil, err
    default:
        // Return a mock connection if no specific conn/err was queued
        return &mockConn{}, nil
    }
}

func (m *controlledMockListener) Close() error {
    m.mu.Lock()
    defer m.mu.Unlock()
    if !m.closed {
        m.closed = true
        close(m.connChan)
        close(m.errChan)
    }
    return nil
}

func (m *controlledMockListener) Addr() net.Addr {
    return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}

func (m *controlledMockListener) enableAcceptControl() {
    m.acceptWait = true
}

func (m *controlledMockListener) releaseAllAccepts(count int) {
    for i := 0; i < count; i++ {
        m.acceptCh <- struct{}{}
    }
}
// TestRaceConditionReproduction creates a scenario that reliably reproduces
// the race condition where connection limits can be temporarily exceeded
func TestRaceConditionReproduction(t *testing.T) {
    // Use a simple approach: stress test with many concurrent accepts
    // and check if we can exceed the limit
    for attempt := 0; attempt < 10; attempt++ {
        mockL := newControlledMockListener()
        limited := NewLimitedListener(mockL, WithMaxConnections(2))

        // Pre-signal the accepts to avoid blocking
        mockL.enableAcceptControl()
        mockL.releaseAllAccepts(10)

        var wg sync.WaitGroup
        var successfulConns int64
        var connections []net.Conn
        var connMutex sync.Mutex

        // Start multiple goroutines trying to accept connections simultaneously
        numGoroutines := 5
        for i := 0; i < numGoroutines; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()

                conn, err := limited.Accept()
                if err == nil && conn != nil {
                    atomic.AddInt64(&successfulConns, 1)
                    connMutex.Lock()
                    connections = append(connections, conn)
                    connMutex.Unlock()
                }
            }()
        }

        // Wait for completion
        wg.Wait()

        // Clean up connections
        connMutex.Lock()
        for _, conn := range connections {
            conn.Close()
        }
        connMutex.Unlock()

        mockL.Close()
        limited.Close()

        successful := atomic.LoadInt64(&successfulConns)

        // Check if we reproduced the race condition
        if successful > 2 {
            t.Logf("RACE CONDITION REPRODUCED on attempt %d: %d connections accepted (limit was 2)", attempt+1, successful)
            t.Errorf("Connection limit exceeded due to race condition: expected at most 2, got %d", successful)
            return
        }

        t.Logf("Attempt %d: %d connections accepted (within limit)", attempt+1, successful)
    }

    t.Log("Race condition was not reproduced in 10 attempts, but the theoretical issue still exists")
}
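A note on running these: the stress loop above is probabilistic, and its final log line concedes that ten attempts may never trip the limit. Such tests are most useful under Go's race detector, e.g. `go test -race ./...`, which flags unsynchronized memory access even on runs where the connection limit is never actually exceeded.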
stats.go (14 changed lines)
@@ -1,5 +1,7 @@
 package limitedlistener
 
+import "net"
+
 // Stats provides current listener statistics
 type Stats struct {
 	ActiveConnections int64
@@ -27,13 +29,17 @@ func (l *LimitedListener) GetStats() Stats {
 // Close implements graceful shutdown
 func (l *LimitedListener) Close() error {
 	l.mu.Lock()
-	defer l.mu.Unlock()
 
-	// Close all active connections
+	// Create a slice of connections to close, so we can release the lock
+	connections := make([]net.Conn, 0, len(l.activeSet))
 	for conn := range l.activeSet {
-		conn.Close()
+		connections = append(connections, conn)
 	}
+	l.mu.Unlock()
+
+	// Close all active connections outside the lock to avoid deadlock
+	for _, conn := range connections {
+		conn.Close()
+	}
 
 	close(l.done)
 	return l.Listener.Close()
 }
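To show how the stats side of this file might be consumed, a small sketch. GetStats and Stats.ActiveConnections appear in the diffs and tests above; the logStats helper, its package name, and the polling interval are hypothetical:

package monitoring

import (
    "log"
    "time"

    limitedlistener "github.com/go-i2p/go-limit" // assumed import path, as above
)

// logStats is a hypothetical helper that periodically reports the
// active-connection count via the listener's GetStats accessor.
func logStats(l *limitedlistener.LimitedListener, every time.Duration, done <-chan struct{}) {
    ticker := time.NewTicker(every)
    defer ticker.Stop()
    for {
        select {
        case <-ticker.C:
            stats := l.GetStats()
            log.Printf("active connections: %d", stats.ActiveConnections)
        case <-done:
            return
        }
    }
}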
@@ -10,15 +10,30 @@ type trackedConn struct {
 	net.Conn
 	listener *LimitedListener
 	once     sync.Once
+	closed   bool
+	mu       sync.Mutex
 }
 
 // Close implements net.Conn Close with connection tracking
 func (c *trackedConn) Close() error {
+	c.mu.Lock()
+	if c.closed {
+		c.mu.Unlock()
+		return nil // Already closed, return without error
+	}
+	c.closed = true
+	c.mu.Unlock()
+
+	var closeErr error
 	c.once.Do(func() {
+		// Remove from active set and decrement counter atomically
 		c.listener.mu.Lock()
 		delete(c.listener.activeSet, c)
 		c.listener.activeConns--
 		c.listener.mu.Unlock()
+
+		// Close the underlying connection
+		closeErr = c.Conn.Close()
 	})
-	return c.Conn.Close()
+	return closeErr
 }
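For comparison, sync.Once on its own is already enough to make Close idempotent. A standalone sketch of that minimal pattern, where the onceConn type is hypothetical and not part of this repository:

package connutil

import (
    "net"
    "sync"
)

// onceConn wraps a net.Conn so that Close is safe to call multiple times:
// the underlying Close runs exactly once and its error is cached for
// every subsequent call.
type onceConn struct {
    net.Conn
    once     sync.Once
    closeErr error
}

func (c *onceConn) Close() error {
    c.once.Do(func() {
        c.closeErr = c.Conn.Close()
    })
    return c.closeErr
}

The trackedConn version above differs in two ways: repeat calls return nil rather than the cached error, and the listener's accounting (activeSet removal and counter decrement) runs inside the Once alongside the real close, which is exactly what TestDoubleClosePrevention exercises.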