github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/server.go (about) 1 package dns 2 3 import ( 4 "context" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net" 11 "time" 12 13 "github.com/Asutorufa/yuhaiin/pkg/log" 14 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 15 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 16 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 17 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 18 "github.com/Asutorufa/yuhaiin/pkg/utils/system" 19 "golang.org/x/net/dns/dnsmessage" 20 "golang.org/x/sync/semaphore" 21 ) 22 23 type dnsServer struct { 24 server string 25 resolver netapi.Resolver 26 listener net.PacketConn 27 tcpListener net.Listener 28 29 sf *semaphore.Weighted 30 } 31 32 func NewServer(server string, process netapi.Resolver) netapi.DNSServer { 33 d := &dnsServer{ 34 server: server, 35 resolver: process, 36 sf: semaphore.NewWeighted(200), 37 } 38 39 if server == "" { 40 log.Info("dns server is empty, skip to listen tcp and udp") 41 return d 42 } 43 44 if err := d.startUDP(); err != nil { 45 log.Error("start udp dns server failed", slog.Any("err", err)) 46 } 47 48 go func() { 49 if err := d.startTCP(); err != nil { 50 log.Error("start tcp dns server failed", slog.Any("err", err)) 51 } 52 }() 53 54 return d 55 } 56 57 func (d *dnsServer) Close() error { 58 if d.listener != nil { 59 d.listener.Close() 60 } 61 if d.tcpListener != nil { 62 d.tcpListener.Close() 63 } 64 65 return nil 66 } 67 68 func (d *dnsServer) startUDP() (err error) { 69 d.listener, err = dialer.ListenPacket("udp", d.server) 70 if err != nil { 71 return fmt.Errorf("dns udp server listen failed: %w", err) 72 } 73 74 log.Info("new udp dns server", "host", d.server) 75 76 for i := 0; i < system.Procs; i++ { 77 go func() { 78 defer d.Close() 79 80 for { 81 buf := pool.GetBytesBuffer(nat.MaxSegmentSize) 82 _, addr, err := buf.ReadFromPacket(d.listener) 83 if err != nil { 84 buf.Free() 85 86 if e, ok := err.(net.Error); ok && e.Temporary() { 87 continue 88 } 89 90 if !errors.Is(err, net.ErrClosed) { 91 log.Error("dns udp server handle failed", "err", err) 92 } 93 return 94 } 95 96 err = d.sf.Acquire(context.TODO(), 1) 97 if err != nil { 98 buf.Free() 99 continue 100 } 101 102 go func() { 103 defer d.sf.Release(1) 104 err := d.Do(context.TODO(), buf, func(b []byte) error { 105 if _, err = d.listener.WriteTo(b, addr); err != nil { 106 return fmt.Errorf("write dns response to client failed: %w", err) 107 } 108 return nil 109 }) 110 if err != nil { 111 log.Error("dns server handle data failed", slog.Any("err", err)) 112 } 113 }() 114 115 } 116 }() 117 } 118 119 return nil 120 } 121 122 func (d *dnsServer) startTCP() (err error) { 123 defer d.Close() 124 125 d.tcpListener, err = dialer.ListenContext(context.TODO(), "tcp", d.server) 126 if err != nil { 127 return fmt.Errorf("dns tcp server listen failed: %w", err) 128 } 129 130 log.Info("new tcp dns server", "host", d.server) 131 132 for { 133 conn, err := d.tcpListener.Accept() 134 if err != nil { 135 if e, ok := err.(net.Error); ok && e.Temporary() { 136 continue 137 } 138 return fmt.Errorf("dns server accept failed: %w", err) 139 } 140 141 go func() { 142 defer conn.Close() 143 144 if err := d.HandleTCP(context.TODO(), conn); err != nil { 145 log.Error("handle dns tcp failed", "err", err) 146 } 147 }() 148 } 149 } 150 151 func (d *dnsServer) HandleTCP(ctx context.Context, c net.Conn) error { 152 var length uint16 153 if err := binary.Read(c, binary.BigEndian, &length); err != nil { 154 return fmt.Errorf("read dns length failed: %w", err) 155 } 156 157 data := pool.GetBytesBuffer(int(length)) 158 159 _, err := io.ReadFull(c, data.Bytes()) 160 if err != nil { 161 return fmt.Errorf("dns server read data failed: %w", err) 162 } 163 164 return d.Do(ctx, data, func(b []byte) error { 165 if err = binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil { 166 return fmt.Errorf("dns server write length failed: %w", err) 167 } 168 _, err = c.Write(b) 169 return err 170 }) 171 } 172 173 func (d *dnsServer) HandleUDP(ctx context.Context, l net.PacketConn) error { 174 buf := pool.GetBytesBuffer(nat.MaxSegmentSize) 175 176 _, addr, err := buf.ReadFromPacket(l) 177 if err != nil { 178 return err 179 } 180 181 return d.Do(context.TODO(), buf, func(b []byte) error { 182 _, err = l.WriteTo(b, addr) 183 return err 184 }) 185 } 186 187 func (d *dnsServer) Do(ctx context.Context, b *pool.Bytes, writeBack func([]byte) error) error { 188 ctx, cancel := context.WithTimeout(ctx, time.Second*10) 189 defer cancel() 190 191 defer b.Free() 192 193 var parse dnsmessage.Parser 194 header, err := parse.Start(b.Bytes()) 195 if err != nil { 196 return fmt.Errorf("dns server parse failed: %w", err) 197 } 198 199 question, err := parse.Question() 200 if err != nil { 201 return fmt.Errorf("dns server parse failed: %w", err) 202 } 203 204 msg, err := d.resolver.Raw(ctx, question) 205 if err != nil { 206 return fmt.Errorf("do raw request (%v:%v) failed: %w", question.Name, question.Type, err) 207 } 208 209 msg.ID = header.ID 210 211 respBuf := pool.GetBytes(pool.DefaultSize) 212 defer pool.PutBytes(respBuf) 213 214 bytes, err := msg.AppendPack(respBuf[:0]) 215 if err != nil { 216 return err 217 } 218 219 return writeBack(bytes) 220 }