github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/server_fork.go (about)

     1  package fastdns
     2  
     3  import (
     4  	"errors"
     5  	"log"
     6  	"os"
     7  	"os/exec"
     8  	"runtime"
     9  	"strconv"
    10  )
    11  
    12  // ForkServer implements a prefork DNS server.
    13  type ForkServer struct {
    14  	// handler to invoke
    15  	Handler Handler
    16  
    17  	// stats to invoke
    18  	Stats Stats
    19  
    20  	// ErrorLog specifies an optional logger for errors accepting
    21  	// connections, unexpected behavior from handlers, and
    22  	// underlying FileSystem errors.
    23  	// If nil, logging is done via the log package's standard logger.
    24  	ErrorLog *log.Logger
    25  
    26  	// The maximum number of procs the server may spawn. use runtime.NumCPU() if empty
    27  	MaxProcs int
    28  
    29  	// SetAffinity sets the CPU affinity mask of current process.
    30  	SetAffinity bool
    31  
    32  	// The maximum number of concurrent clients the server may serve.
    33  	Concurrency int
    34  }
    35  
    36  // ListenAndServe serves DNS requests from the given UDP addr.
    37  func (s *ForkServer) ListenAndServe(addr string) error {
    38  	if s.Index() == 0 {
    39  		return s.fork(addr, s.MaxProcs)
    40  	}
    41  
    42  	if s.ErrorLog == nil {
    43  		s.ErrorLog = log.Default()
    44  	}
    45  
    46  	if s.SetAffinity {
    47  		// set cpu affinity for performance
    48  		err := taskset((s.Index() - 1) % runtime.NumCPU())
    49  		if err != nil {
    50  			s.ErrorLog.Printf("forkserver-%d set cpu_affinity=%d failed: %+v", s.Index(), s.Index()-1, err)
    51  		}
    52  	}
    53  
    54  	// so_reuseport listen for performance
    55  	conn, err := listen("udp", addr)
    56  	if err != nil {
    57  		s.ErrorLog.Printf("forkserver-%d listen on addr=%s failed: %+v", s.Index(), addr, err)
    58  		return err
    59  	}
    60  
    61  	// s.ErrorLog.Printf("forkserver-%d pid-%d serving dns on %s", s.Index(), os.Getpid(), conn.LocalAddr())
    62  
    63  	return serve(conn, s.Handler, s.Stats, s.ErrorLog, s.Concurrency)
    64  }
    65  
    66  // Index indicates the index of Server instances.
    67  func (s *ForkServer) Index() (index int) {
    68  	index, _ = strconv.Atoi(os.Getenv("FASTDNS_CHILD_INDEX"))
    69  	return
    70  }
    71  
    72  func fork(index int) (*exec.Cmd, error) {
    73  	/* #nosec G204 */
    74  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
    75  	cmd.Stdout = os.Stdout
    76  	cmd.Stderr = os.Stderr
    77  	cmd.Env = append([]string{"FASTDNS_CHILD_INDEX=" + strconv.Itoa(index)}, os.Environ()...)
    78  	return cmd, cmd.Start()
    79  }
    80  
    81  func (s *ForkServer) fork(addr string, maxProcs int) (err error) {
    82  	type racer struct {
    83  		index int
    84  		pid   int
    85  		err   error
    86  	}
    87  
    88  	if maxProcs == 0 {
    89  		maxProcs = runtime.NumCPU()
    90  	}
    91  	if runtime.GOOS != "linux" {
    92  		maxProcs = 1
    93  	}
    94  
    95  	ch := make(chan racer, maxProcs)
    96  	childs := make(map[int]*exec.Cmd)
    97  
    98  	defer func() {
    99  		for _, proc := range childs {
   100  			_ = proc.Process.Kill()
   101  		}
   102  	}()
   103  
   104  	for i := 1; i <= maxProcs; i++ {
   105  		var cmd *exec.Cmd
   106  		if cmd, err = fork(i); err != nil {
   107  			s.ErrorLog.Printf("forkserver failed to start a child process, error: %v\n", err)
   108  			return
   109  		}
   110  
   111  		childs[cmd.Process.Pid] = cmd
   112  		go func(index int) {
   113  			ch <- racer{index, cmd.Process.Pid, cmd.Wait()}
   114  		}(i)
   115  	}
   116  
   117  	var exited int
   118  	for sig := range ch {
   119  		delete(childs, sig.pid)
   120  
   121  		s.ErrorLog.Printf("forkserver one of the child processes exited with error: %v", sig.err)
   122  
   123  		if exited++; exited > 200 {
   124  			s.ErrorLog.Printf("forkserver child workers exit too many times(%d)", exited)
   125  			err = errors.New("forkserver child workers exit too many times")
   126  			break
   127  		}
   128  
   129  		var cmd *exec.Cmd
   130  		if cmd, err = fork(sig.index); err != nil {
   131  			break
   132  		}
   133  		childs[cmd.Process.Pid] = cmd
   134  		go func(index int) {
   135  			ch <- racer{index, cmd.Process.Pid, cmd.Wait()}
   136  		}(sig.index)
   137  	}
   138  
   139  	return
   140  }