diff --git a/cmd/reseed.go b/cmd/reseed.go index 0be280d..3146938 100644 --- a/cmd/reseed.go +++ b/cmd/reseed.go @@ -2,10 +2,8 @@ package cmd import ( "fmt" - "io/ioutil" "log" "net" - "strings" "time" "github.com/MDrollette/i2p-tools/reseed" @@ -153,17 +151,15 @@ func reseedAction(c *cli.Context) { // create a server server := reseed.NewServer(c.String("prefix"), c.Bool("trustProxy")) + blacklist := reseed.NewBlacklist() + server.Blacklist = blacklist server.Reseeder = reseeder server.Addr = net.JoinHostPort(c.String("ip"), c.String("port")) // load a blacklist blacklistFile := c.String("blacklist") - if blacklistFile != "" { - if content, err := ioutil.ReadFile(blacklistFile); err == nil { - server.Blacklist = strings.Split(string(content), "\n") - } else { - log.Fatalln("Failed to load blacklist: ", err) - } + if "" != blacklistFile { + blacklist.LoadFile(blacklistFile) } if tlsHost != "" && tlsCert != "" && tlsKey != "" { diff --git a/reseed/blacklist.go b/reseed/blacklist.go new file mode 100644 index 0000000..23f0c98 --- /dev/null +++ b/reseed/blacklist.go @@ -0,0 +1,78 @@ +package reseed + +import ( + "io/ioutil" + "log" + "net" + "strings" + "sync" +) + +type Blacklist struct { + blacklist map[string]bool + m sync.RWMutex +} + +func NewBlacklist() *Blacklist { + return &Blacklist{blacklist: make(map[string]bool), m: sync.RWMutex{}} +} + +func (s *Blacklist) LoadFile(file string) error { + if file != "" { + if content, err := ioutil.ReadFile(file); err == nil { + for _, ip := range strings.Split(string(content), "\n") { + s.BlockIp(ip) + } + } else { + return err + } + } + + return nil +} + +func (s *Blacklist) BlockIp(ip string) { + s.m.Lock() + defer s.m.Unlock() + + s.blacklist[ip] = true +} + +func (s *Blacklist) isBlocked(ip string) bool { + s.m.RLock() + defer s.m.RUnlock() + + blocked, found := s.blacklist[ip] + + return found && blocked +} + +type blacklistListener struct { + *net.TCPListener + blacklist *Blacklist +} + +func (ln blacklistListener) Accept() (net.Conn, error) { + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + + ip, _, err := net.SplitHostPort(tc.RemoteAddr().String()) + if err != nil { + tc.Close() + return tc, err + } + + if ln.blacklist.isBlocked(ip) { + tc.Close() + log.Printf("blocked connection from: %s\n", ip) + return tc, nil + } + + return tc, err +} + +func newBlacklistListener(ln net.Listener, bl *Blacklist) blacklistListener { + return blacklistListener{ln.(*net.TCPListener), bl} +} diff --git a/reseed/server.go b/reseed/server.go index 1542236..0a4882c 100644 --- a/reseed/server.go +++ b/reseed/server.go @@ -20,43 +20,52 @@ const ( I2P_USER_AGENT = "Wget/1.11.4" ) -type Listener struct { - net.Listener - Blacklist []string -} - -func (nl Listener) Accept() (net.Conn, error) { - for { - c, err := nl.Listener.Accept() - if err != nil { - return nil, err - } - - host, port, err := net.SplitHostPort(c.RemoteAddr().String()) - if err != nil { - l.Printf("accept fail: %s\n", err.Error()) - go c.Close() - continue - } - - ip := net.ParseIP(host) - - for _, cidr := range nl.Blacklist { - if _, ipnet, err := net.ParseCIDR(cidr); err == nil { - if ipnet.Contains(ip) { - l.Printf("allow conn from: %s:%s\n", host, port) - return c, err - } - } - } - - go c.Close() - } -} - type Server struct { *http.Server - Reseeder Reseeder + Reseeder Reseeder + Blacklist *Blacklist +} + +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + return srv.Serve(newBlacklistListener(ln, srv.Blacklist)) +} + +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + config := &tls.Config{} + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(newBlacklistListener(ln, srv.Blacklist), config) + return srv.Serve(tlsListener) } func NewServer(prefix string, trustProxy bool) *Server { @@ -80,7 +89,7 @@ func NewServer(prefix string, trustProxy bool) *Server { }, } h := &http.Server{TLSConfig: config} - server := Server{h, nil} + server := Server{Server: h, Reseeder: nil} th := throttled.RateLimit(throttled.PerDay(4), &throttled.VaryBy{RemoteAddr: true}, store.NewMemStore(100000))