github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/server.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log/slog"
    10  	"net"
    11  	"time"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/log"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    16  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/system"
    19  	"golang.org/x/net/dns/dnsmessage"
    20  	"golang.org/x/sync/semaphore"
    21  )
    22  
// dnsServer answers DNS queries by delegating each question to a
// netapi.Resolver. It can optionally listen for queries on UDP and TCP at the
// same address.
type dnsServer struct {
	// server is the listen address for both UDP and TCP; when empty,
	// NewServer does not start any listener.
	server string
	// resolver produces the answer for every question handled by Do.
	resolver netapi.Resolver
	// listener is the UDP socket opened by startUDP (nil if not started).
	listener net.PacketConn
	// tcpListener is the TCP listener opened by startTCP (nil if not started).
	tcpListener net.Listener

	// sf bounds how many UDP query handlers run concurrently.
	sf *semaphore.Weighted
}
    31  
    32  func NewServer(server string, process netapi.Resolver) netapi.DNSServer {
    33  	d := &dnsServer{
    34  		server:   server,
    35  		resolver: process,
    36  		sf:       semaphore.NewWeighted(200),
    37  	}
    38  
    39  	if server == "" {
    40  		log.Info("dns server is empty, skip to listen tcp and udp")
    41  		return d
    42  	}
    43  
    44  	if err := d.startUDP(); err != nil {
    45  		log.Error("start udp dns server failed", slog.Any("err", err))
    46  	}
    47  
    48  	go func() {
    49  		if err := d.startTCP(); err != nil {
    50  			log.Error("start tcp dns server failed", slog.Any("err", err))
    51  		}
    52  	}()
    53  
    54  	return d
    55  }
    56  
    57  func (d *dnsServer) Close() error {
    58  	if d.listener != nil {
    59  		d.listener.Close()
    60  	}
    61  	if d.tcpListener != nil {
    62  		d.tcpListener.Close()
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  func (d *dnsServer) startUDP() (err error) {
    69  	d.listener, err = dialer.ListenPacket("udp", d.server)
    70  	if err != nil {
    71  		return fmt.Errorf("dns udp server listen failed: %w", err)
    72  	}
    73  
    74  	log.Info("new udp dns server", "host", d.server)
    75  
    76  	for i := 0; i < system.Procs; i++ {
    77  		go func() {
    78  			defer d.Close()
    79  
    80  			for {
    81  				buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
    82  				_, addr, err := buf.ReadFromPacket(d.listener)
    83  				if err != nil {
    84  					buf.Free()
    85  
    86  					if e, ok := err.(net.Error); ok && e.Temporary() {
    87  						continue
    88  					}
    89  
    90  					if !errors.Is(err, net.ErrClosed) {
    91  						log.Error("dns udp server handle failed", "err", err)
    92  					}
    93  					return
    94  				}
    95  
    96  				err = d.sf.Acquire(context.TODO(), 1)
    97  				if err != nil {
    98  					buf.Free()
    99  					continue
   100  				}
   101  
   102  				go func() {
   103  					defer d.sf.Release(1)
   104  					err := d.Do(context.TODO(), buf, func(b []byte) error {
   105  						if _, err = d.listener.WriteTo(b, addr); err != nil {
   106  							return fmt.Errorf("write dns response to client failed: %w", err)
   107  						}
   108  						return nil
   109  					})
   110  					if err != nil {
   111  						log.Error("dns server handle data failed", slog.Any("err", err))
   112  					}
   113  				}()
   114  
   115  			}
   116  		}()
   117  	}
   118  
   119  	return nil
   120  }
   121  
   122  func (d *dnsServer) startTCP() (err error) {
   123  	defer d.Close()
   124  
   125  	d.tcpListener, err = dialer.ListenContext(context.TODO(), "tcp", d.server)
   126  	if err != nil {
   127  		return fmt.Errorf("dns tcp server listen failed: %w", err)
   128  	}
   129  
   130  	log.Info("new tcp dns server", "host", d.server)
   131  
   132  	for {
   133  		conn, err := d.tcpListener.Accept()
   134  		if err != nil {
   135  			if e, ok := err.(net.Error); ok && e.Temporary() {
   136  				continue
   137  			}
   138  			return fmt.Errorf("dns server accept failed: %w", err)
   139  		}
   140  
   141  		go func() {
   142  			defer conn.Close()
   143  
   144  			if err := d.HandleTCP(context.TODO(), conn); err != nil {
   145  				log.Error("handle dns tcp failed", "err", err)
   146  			}
   147  		}()
   148  	}
   149  }
   150  
   151  func (d *dnsServer) HandleTCP(ctx context.Context, c net.Conn) error {
   152  	var length uint16
   153  	if err := binary.Read(c, binary.BigEndian, &length); err != nil {
   154  		return fmt.Errorf("read dns length failed: %w", err)
   155  	}
   156  
   157  	data := pool.GetBytesBuffer(int(length))
   158  
   159  	_, err := io.ReadFull(c, data.Bytes())
   160  	if err != nil {
   161  		return fmt.Errorf("dns server read data failed: %w", err)
   162  	}
   163  
   164  	return d.Do(ctx, data, func(b []byte) error {
   165  		if err = binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
   166  			return fmt.Errorf("dns server write length failed: %w", err)
   167  		}
   168  		_, err = c.Write(b)
   169  		return err
   170  	})
   171  }
   172  
   173  func (d *dnsServer) HandleUDP(ctx context.Context, l net.PacketConn) error {
   174  	buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
   175  
   176  	_, addr, err := buf.ReadFromPacket(l)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	return d.Do(context.TODO(), buf, func(b []byte) error {
   182  		_, err = l.WriteTo(b, addr)
   183  		return err
   184  	})
   185  }
   186  
   187  func (d *dnsServer) Do(ctx context.Context, b *pool.Bytes, writeBack func([]byte) error) error {
   188  	ctx, cancel := context.WithTimeout(ctx, time.Second*10)
   189  	defer cancel()
   190  
   191  	defer b.Free()
   192  
   193  	var parse dnsmessage.Parser
   194  	header, err := parse.Start(b.Bytes())
   195  	if err != nil {
   196  		return fmt.Errorf("dns server parse failed: %w", err)
   197  	}
   198  
   199  	question, err := parse.Question()
   200  	if err != nil {
   201  		return fmt.Errorf("dns server parse failed: %w", err)
   202  	}
   203  
   204  	msg, err := d.resolver.Raw(ctx, question)
   205  	if err != nil {
   206  		return fmt.Errorf("do raw request (%v:%v) failed: %w", question.Name, question.Type, err)
   207  	}
   208  
   209  	msg.ID = header.ID
   210  
   211  	respBuf := pool.GetBytes(pool.DefaultSize)
   212  	defer pool.PutBytes(respBuf)
   213  
   214  	bytes, err := msg.AppendPack(respBuf[:0])
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	return writeBack(bytes)
   220  }