github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/server.go (about)

     1  package fastdns
     2  
     3  import (
     4  	"errors"
     5  	"log"
     6  	"net"
     7  	"runtime"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  // Server implements a mutli-listener DNS server.
    13  type Server struct {
    14  	// handler to invoke
    15  	Handler Handler
    16  
    17  	// Stats to invoke
    18  	Stats Stats
    19  
    20  	// ErrorLog specifies an optional logger for errors accepting
    21  	// connections, unexpected behavior from handlers, and
    22  	// underlying FileSystem errors.
    23  	// If nil, logging is done via the log package's standard logger.
    24  	ErrorLog *log.Logger
    25  
    26  	// The maximum number of procs the server may spawn. use runtime.NumCPU() if empty
    27  	MaxProcs int
    28  
    29  	// The maximum number of concurrent clients the server may serve.
    30  	Concurrency int
    31  
    32  	// Index indicates the index of Server instances.
    33  	index int
    34  }
    35  
    36  // ListenAndServe serves DNS requests from the given UDP addr.
    37  func (s *Server) ListenAndServe(addr string) error {
    38  	if s.Index() == 0 {
    39  		// only prefork for linux(reuse_port)
    40  		return s.spawn(addr, s.MaxProcs)
    41  	}
    42  
    43  	if s.ErrorLog == nil {
    44  		s.ErrorLog = log.Default()
    45  	}
    46  
    47  	conn, err := listen("udp", addr)
    48  	if err != nil {
    49  		s.ErrorLog.Printf("server-%d listen on addr=%s failed: %+v", s.Index(), addr, err)
    50  		return err
    51  	}
    52  
    53  	// s.ErrorLog.Printf("server-%d pid-%d serving dns on %s", s.Index(), os.Getpid(), conn.LocalAddr())
    54  
    55  	return serve(conn, s.Handler, s.Stats, s.ErrorLog, s.Concurrency)
    56  }
    57  
    58  // Index indicates the index of Server instances.
    59  func (s *Server) Index() (index int) {
    60  	index = s.index
    61  	return
    62  }
    63  
    64  func (s *Server) spawn(addr string, maxProcs int) (err error) {
    65  	type racer struct {
    66  		index int
    67  		err   error
    68  	}
    69  
    70  	if maxProcs == 0 {
    71  		maxProcs = runtime.NumCPU()
    72  	}
    73  	if runtime.GOOS != "linux" {
    74  		maxProcs = 1
    75  	}
    76  
    77  	ch := make(chan racer, maxProcs)
    78  
    79  	// create multiple receive worker for performance
    80  	for i := 1; i <= maxProcs; i++ {
    81  		go func(index int) {
    82  			server := &Server{
    83  				Handler:     s.Handler,
    84  				Stats:       s.Stats,
    85  				ErrorLog:    s.ErrorLog,
    86  				MaxProcs:    s.MaxProcs,
    87  				Concurrency: s.Concurrency,
    88  				index:       index,
    89  			}
    90  			err := server.ListenAndServe(addr)
    91  			ch <- racer{index, err}
    92  		}(i)
    93  	}
    94  
    95  	var exited int
    96  	for sig := range ch {
    97  		s.ErrorLog.Printf("server one of the child workers exited with error: %v", sig.err)
    98  
    99  		if exited++; exited > 200 {
   100  			s.ErrorLog.Printf("server child workers exit too many times(%d)", exited)
   101  			err = errors.New("server child workers exit too many times")
   102  			break
   103  		}
   104  
   105  		go func(index int) {
   106  			server := &Server{
   107  				Handler:     s.Handler,
   108  				Stats:       s.Stats,
   109  				ErrorLog:    s.ErrorLog,
   110  				MaxProcs:    s.MaxProcs,
   111  				Concurrency: s.Concurrency,
   112  				index:       index,
   113  			}
   114  			err := server.ListenAndServe(addr)
   115  			ch <- racer{index, err}
   116  		}(sig.index)
   117  	}
   118  
   119  	return
   120  }
   121  
   122  type udpCtx struct {
   123  	rw      *udpResponseWriter
   124  	req     *Message
   125  	handler Handler
   126  	stats   Stats
   127  }
   128  
   129  var udpCtxPool = &sync.Pool{
   130  	New: func() interface{} {
   131  		ctx := new(udpCtx)
   132  		ctx.rw = new(udpResponseWriter)
   133  		ctx.req = new(Message)
   134  		ctx.req.Raw = make([]byte, 0, 1024)
   135  		ctx.req.Domain = make([]byte, 0, 256)
   136  		return ctx
   137  	},
   138  }
   139  
   140  func serve(conn *net.UDPConn, handler Handler, stats Stats, logger *log.Logger, concurrency int) error {
   141  	if concurrency == 0 {
   142  		concurrency = 256 * 1024
   143  	}
   144  
   145  	pool := &workerPool{
   146  		WorkerFunc:            serveCtx,
   147  		MaxWorkersCount:       concurrency,
   148  		LogAllErrors:          false,
   149  		MaxIdleWorkerDuration: 2 * time.Minute,
   150  		Logger:                logger,
   151  	}
   152  	pool.Start()
   153  
   154  	for {
   155  		ctx := udpCtxPool.Get().(*udpCtx)
   156  
   157  		ctx.req.Raw = ctx.req.Raw[:cap(ctx.req.Raw)]
   158  		n, addrPort, err := conn.ReadFromUDPAddrPort(ctx.req.Raw)
   159  		if err != nil {
   160  			udpCtxPool.Put(ctx)
   161  			time.Sleep(10 * time.Millisecond)
   162  
   163  			continue
   164  		}
   165  
   166  		ctx.req.Raw = ctx.req.Raw[:n]
   167  		ctx.rw.Conn = conn
   168  		ctx.rw.AddrPort = addrPort
   169  
   170  		ctx.handler = handler
   171  		ctx.stats = stats
   172  
   173  		pool.Serve(ctx)
   174  	}
   175  }
   176  
   177  func serveCtx(ctx *udpCtx) error {
   178  	var start time.Time
   179  	if ctx.stats != nil {
   180  		start = time.Now()
   181  	}
   182  
   183  	rw, req := ctx.rw, ctx.req
   184  
   185  	err := ParseMessage(req, req.Raw, false)
   186  	if err != nil {
   187  		Error(rw, req, RcodeFormErr)
   188  	} else {
   189  		ctx.handler.ServeDNS(rw, req)
   190  	}
   191  
   192  	if ctx.stats != nil {
   193  		ctx.stats.UpdateStats(rw.RemoteAddr(), req, time.Since(start))
   194  	}
   195  
   196  	udpCtxPool.Put(ctx)
   197  
   198  	return err
   199  }