From 6bd37c4d51985d154abdec3aff12b15221ada0ca Mon Sep 17 00:00:00 2001 From: idk Date: Sun, 29 Nov 2020 16:53:46 -0500 Subject: [PATCH] kill connections when they die in context --- dial.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dial.go b/dial.go index 7e3440b..4820a6e 100644 --- a/dial.go +++ b/dial.go @@ -2,7 +2,6 @@ package goSam import ( "context" - "fmt" "log" "net" "strings" @@ -15,7 +14,7 @@ func (c *Client) DialContext(ctx context.Context, network, addr string) (net.Con errCh := make(chan error, 1) connCh := make(chan net.Conn, 1) go func() { - if conn, err := c.Dial(network, addr); err != nil { + if conn, err := c.DialContextFree(network, addr); err != nil { errCh <- err } else if ctx.Err() != nil { log.Println(ctx) @@ -26,10 +25,12 @@ func (c *Client) DialContext(ctx context.Context, network, addr string) (net.Con }() select { case err := <-errCh: - return c.SamConn, err + c.Close() + return nil, err case conn := <-connCh: return conn, nil case <-ctx.Done(): + c.Close() return nil, ctx.Err() } } @@ -45,8 +46,12 @@ func (c *Client) dialCheck(addr string) (int32, bool) { return c.id, false } -// Dial implements the net.Dial function and can be used for http.Transport func (c *Client) Dial(network, addr string) (net.Conn, error) { + return c.DialContext(context.Background(), network, addr) +} + +// Dial implements the net.Dial function and can be used for http.Transport +func (c *Client) DialContextFree(network, addr string) (net.Conn, error) { c.ml.Lock() defer c.ml.Unlock() portIdx := strings.Index(addr, ":")