github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/doq.go (about) 1 package dns 2 3 import ( 4 "context" 5 "crypto/tls" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "sync" 12 "time" 13 14 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 15 pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns" 16 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 17 "github.com/Asutorufa/yuhaiin/pkg/utils/id" 18 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 19 "github.com/quic-go/quic-go" 20 "golang.org/x/net/http2" 21 ) 22 23 func init() { 24 Register(pdns.Type_doq, NewDoQ) 25 } 26 27 type doq struct { 28 conn net.PacketConn 29 connection quic.Connection 30 host netapi.Address 31 servername string 32 dialer netapi.PacketProxy 33 34 mu sync.RWMutex 35 36 *client 37 } 38 39 func NewDoQ(config Config) (netapi.Resolver, error) { 40 addr, err := ParseAddr(statistic.Type_udp, config.Host, "784") 41 if err != nil { 42 return nil, fmt.Errorf("parse addr failed: %w", err) 43 } 44 45 if config.Servername == "" { 46 config.Servername = addr.Hostname() 47 } 48 49 d := &doq{ 50 dialer: config.Dialer, 51 host: addr, 52 servername: config.Servername, 53 } 54 55 d.client = NewClient(config, func(ctx context.Context, b []byte) (*pool.Bytes, error) { 56 session, err := d.initSession(ctx) 57 if err != nil { 58 return nil, fmt.Errorf("init session failed: %w", err) 59 } 60 61 d.mu.RLock() 62 con, err := session.OpenStream() 63 if err != nil { 64 return nil, fmt.Errorf("open stream failed: %w", err) 65 } 66 defer con.Close() 67 defer d.mu.RUnlock() 68 69 err = con.SetWriteDeadline(time.Now().Add(time.Second * 4)) 70 if err != nil { 71 con.Close() 72 return nil, fmt.Errorf("set write deadline failed: %w", err) 73 } 74 75 buf := pool.GetBytesWriter(2 + len(b)) 76 defer buf.Free() 77 78 buf.WriteUint16(uint16(len(b))) 79 _, _ = buf.Write(b) 80 81 if _, err = con.Write(buf.Bytes()); err != nil { 82 con.Close() 83 return nil, fmt.Errorf("write dns req failed: %w", err) 84 } 85 86 // close to make server io.EOF 87 if err = con.Close(); err != nil { 88 return nil, fmt.Errorf("close stream failed: %w", err) 89 } 90 91 err = con.SetReadDeadline(time.Now().Add(time.Second * 4)) 92 if err != nil { 93 return nil, fmt.Errorf("set read deadline failed: %w", err) 94 } 95 96 var length uint16 97 err = binary.Read(con, binary.BigEndian, &length) 98 if err != nil { 99 return nil, fmt.Errorf("read dns response length failed: %w", err) 100 } 101 102 data := pool.GetBytesBuffer(int(length)) 103 104 _, err = io.ReadFull(con, data.Bytes()) 105 if err != nil { 106 return nil, fmt.Errorf("read dns server response failed: %w", err) 107 } 108 109 return data, nil 110 }) 111 return d, nil 112 } 113 114 func (d *doq) Close() error { 115 var err error 116 if d.connection != nil { 117 er := d.connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") 118 if er != nil { 119 err = errors.Join(err, er) 120 } 121 } 122 123 if d.conn != nil { 124 er := d.conn.Close() 125 if er != nil { 126 err = errors.Join(err, er) 127 } 128 } 129 130 return err 131 } 132 133 type DOQWrapConn struct { 134 net.PacketConn 135 localAddrSalt string 136 } 137 138 func (d *DOQWrapConn) LocalAddr() net.Addr { 139 return &doqWrapLocalAddr{d.PacketConn.LocalAddr(), d.localAddrSalt} 140 } 141 142 // doqWrapLocalAddr make doq packetConn local addr is different, otherwise the quic-go will panic 143 // see: https://github.com/quic-go/quic-go/issues/3727 144 type doqWrapLocalAddr struct { 145 net.Addr 146 salt string 147 } 148 149 func (a *doqWrapLocalAddr) String() string { 150 return fmt.Sprintf("doq://%s-%s", a.Addr.String(), a.salt) 151 } 152 153 var doqIgGenerate = id.IDGenerator{} 154 155 func (d *doq) initSession(ctx context.Context) (quic.Connection, error) { 156 connection := d.connection 157 158 if connection != nil { 159 select { 160 case <-connection.Context().Done(): 161 _ = connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") 162 default: 163 return connection, nil 164 } 165 } 166 167 d.mu.Lock() 168 defer d.mu.Unlock() 169 170 if d.connection != nil { 171 select { 172 case <-d.connection.Context().Done(): 173 _ = d.connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") 174 175 default: 176 return d.connection, nil 177 } 178 } 179 180 if d.conn != nil { 181 d.conn.Close() 182 d.conn = nil 183 } 184 185 if d.conn == nil { 186 conn, err := d.dialer.PacketConn(ctx, d.host) 187 if err != nil { 188 return nil, err 189 } 190 d.conn = conn 191 } 192 193 session, err := quic.Dial( 194 ctx, 195 &DOQWrapConn{d.conn, fmt.Sprint(doqIgGenerate.Generate())}, 196 d.host, 197 &tls.Config{ 198 NextProtos: []string{"http/1.1", "doq-i02", "doq-i01", "doq-i00", "doq", "dq", http2.NextProtoTLS}, 199 ServerName: d.servername, 200 }, &quic.Config{}) 201 if err != nil { 202 _ = d.conn.Close() 203 return nil, fmt.Errorf("quic dial failed: %w", err) 204 } 205 206 d.connection = session 207 return session, nil 208 }