github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/server.go (about) 1 package fastdns 2 3 import ( 4 "errors" 5 "log" 6 "net" 7 "runtime" 8 "sync" 9 "time" 10 ) 11 12 // Server implements a mutli-listener DNS server. 13 type Server struct { 14 // handler to invoke 15 Handler Handler 16 17 // Stats to invoke 18 Stats Stats 19 20 // ErrorLog specifies an optional logger for errors accepting 21 // connections, unexpected behavior from handlers, and 22 // underlying FileSystem errors. 23 // If nil, logging is done via the log package's standard logger. 24 ErrorLog *log.Logger 25 26 // The maximum number of procs the server may spawn. use runtime.NumCPU() if empty 27 MaxProcs int 28 29 // The maximum number of concurrent clients the server may serve. 30 Concurrency int 31 32 // Index indicates the index of Server instances. 33 index int 34 } 35 36 // ListenAndServe serves DNS requests from the given UDP addr. 37 func (s *Server) ListenAndServe(addr string) error { 38 if s.Index() == 0 { 39 // only prefork for linux(reuse_port) 40 return s.spawn(addr, s.MaxProcs) 41 } 42 43 if s.ErrorLog == nil { 44 s.ErrorLog = log.Default() 45 } 46 47 conn, err := listen("udp", addr) 48 if err != nil { 49 s.ErrorLog.Printf("server-%d listen on addr=%s failed: %+v", s.Index(), addr, err) 50 return err 51 } 52 53 // s.ErrorLog.Printf("server-%d pid-%d serving dns on %s", s.Index(), os.Getpid(), conn.LocalAddr()) 54 55 return serve(conn, s.Handler, s.Stats, s.ErrorLog, s.Concurrency) 56 } 57 58 // Index indicates the index of Server instances. 59 func (s *Server) Index() (index int) { 60 index = s.index 61 return 62 } 63 64 func (s *Server) spawn(addr string, maxProcs int) (err error) { 65 type racer struct { 66 index int 67 err error 68 } 69 70 if maxProcs == 0 { 71 maxProcs = runtime.NumCPU() 72 } 73 if runtime.GOOS != "linux" { 74 maxProcs = 1 75 } 76 77 ch := make(chan racer, maxProcs) 78 79 // create multiple receive worker for performance 80 for i := 1; i <= maxProcs; i++ { 81 go func(index int) { 82 server := &Server{ 83 Handler: s.Handler, 84 Stats: s.Stats, 85 ErrorLog: s.ErrorLog, 86 MaxProcs: s.MaxProcs, 87 Concurrency: s.Concurrency, 88 index: index, 89 } 90 err := server.ListenAndServe(addr) 91 ch <- racer{index, err} 92 }(i) 93 } 94 95 var exited int 96 for sig := range ch { 97 s.ErrorLog.Printf("server one of the child workers exited with error: %v", sig.err) 98 99 if exited++; exited > 200 { 100 s.ErrorLog.Printf("server child workers exit too many times(%d)", exited) 101 err = errors.New("server child workers exit too many times") 102 break 103 } 104 105 go func(index int) { 106 server := &Server{ 107 Handler: s.Handler, 108 Stats: s.Stats, 109 ErrorLog: s.ErrorLog, 110 MaxProcs: s.MaxProcs, 111 Concurrency: s.Concurrency, 112 index: index, 113 } 114 err := server.ListenAndServe(addr) 115 ch <- racer{index, err} 116 }(sig.index) 117 } 118 119 return 120 } 121 122 type udpCtx struct { 123 rw *udpResponseWriter 124 req *Message 125 handler Handler 126 stats Stats 127 } 128 129 var udpCtxPool = &sync.Pool{ 130 New: func() interface{} { 131 ctx := new(udpCtx) 132 ctx.rw = new(udpResponseWriter) 133 ctx.req = new(Message) 134 ctx.req.Raw = make([]byte, 0, 1024) 135 ctx.req.Domain = make([]byte, 0, 256) 136 return ctx 137 }, 138 } 139 140 func serve(conn *net.UDPConn, handler Handler, stats Stats, logger *log.Logger, concurrency int) error { 141 if concurrency == 0 { 142 concurrency = 256 * 1024 143 } 144 145 pool := &workerPool{ 146 WorkerFunc: serveCtx, 147 MaxWorkersCount: concurrency, 148 LogAllErrors: false, 149 MaxIdleWorkerDuration: 2 * time.Minute, 150 Logger: logger, 151 } 152 pool.Start() 153 154 for { 155 ctx := udpCtxPool.Get().(*udpCtx) 156 157 ctx.req.Raw = ctx.req.Raw[:cap(ctx.req.Raw)] 158 n, addrPort, err := conn.ReadFromUDPAddrPort(ctx.req.Raw) 159 if err != nil { 160 udpCtxPool.Put(ctx) 161 time.Sleep(10 * time.Millisecond) 162 163 continue 164 } 165 166 ctx.req.Raw = ctx.req.Raw[:n] 167 ctx.rw.Conn = conn 168 ctx.rw.AddrPort = addrPort 169 170 ctx.handler = handler 171 ctx.stats = stats 172 173 pool.Serve(ctx) 174 } 175 } 176 177 func serveCtx(ctx *udpCtx) error { 178 var start time.Time 179 if ctx.stats != nil { 180 start = time.Now() 181 } 182 183 rw, req := ctx.rw, ctx.req 184 185 err := ParseMessage(req, req.Raw, false) 186 if err != nil { 187 Error(rw, req, RcodeFormErr) 188 } else { 189 ctx.handler.ServeDNS(rw, req) 190 } 191 192 if ctx.stats != nil { 193 ctx.stats.UpdateStats(rw.RemoteAddr(), req, time.Since(start)) 194 } 195 196 udpCtxPool.Put(ctx) 197 198 return err 199 }