github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dns/doq.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    15  	pdns "github.com/Asutorufa/yuhaiin/pkg/protos/config/dns"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/id"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    19  	"github.com/quic-go/quic-go"
    20  	"golang.org/x/net/http2"
    21  )
    22  
    23  func init() {
    24  	Register(pdns.Type_doq, NewDoQ)
    25  }
    26  
    27  type doq struct {
    28  	conn       net.PacketConn
    29  	connection quic.Connection
    30  	host       netapi.Address
    31  	servername string
    32  	dialer     netapi.PacketProxy
    33  
    34  	mu sync.RWMutex
    35  
    36  	*client
    37  }
    38  
    39  func NewDoQ(config Config) (netapi.Resolver, error) {
    40  	addr, err := ParseAddr(statistic.Type_udp, config.Host, "784")
    41  	if err != nil {
    42  		return nil, fmt.Errorf("parse addr failed: %w", err)
    43  	}
    44  
    45  	if config.Servername == "" {
    46  		config.Servername = addr.Hostname()
    47  	}
    48  
    49  	d := &doq{
    50  		dialer:     config.Dialer,
    51  		host:       addr,
    52  		servername: config.Servername,
    53  	}
    54  
    55  	d.client = NewClient(config, func(ctx context.Context, b []byte) (*pool.Bytes, error) {
    56  		session, err := d.initSession(ctx)
    57  		if err != nil {
    58  			return nil, fmt.Errorf("init session failed: %w", err)
    59  		}
    60  
    61  		d.mu.RLock()
    62  		con, err := session.OpenStream()
    63  		if err != nil {
    64  			return nil, fmt.Errorf("open stream failed: %w", err)
    65  		}
    66  		defer con.Close()
    67  		defer d.mu.RUnlock()
    68  
    69  		err = con.SetWriteDeadline(time.Now().Add(time.Second * 4))
    70  		if err != nil {
    71  			con.Close()
    72  			return nil, fmt.Errorf("set write deadline failed: %w", err)
    73  		}
    74  
    75  		buf := pool.GetBytesWriter(2 + len(b))
    76  		defer buf.Free()
    77  
    78  		buf.WriteUint16(uint16(len(b)))
    79  		_, _ = buf.Write(b)
    80  
    81  		if _, err = con.Write(buf.Bytes()); err != nil {
    82  			con.Close()
    83  			return nil, fmt.Errorf("write dns req failed: %w", err)
    84  		}
    85  
    86  		// close to make server io.EOF
    87  		if err = con.Close(); err != nil {
    88  			return nil, fmt.Errorf("close stream failed: %w", err)
    89  		}
    90  
    91  		err = con.SetReadDeadline(time.Now().Add(time.Second * 4))
    92  		if err != nil {
    93  			return nil, fmt.Errorf("set read deadline failed: %w", err)
    94  		}
    95  
    96  		var length uint16
    97  		err = binary.Read(con, binary.BigEndian, &length)
    98  		if err != nil {
    99  			return nil, fmt.Errorf("read dns response length failed: %w", err)
   100  		}
   101  
   102  		data := pool.GetBytesBuffer(int(length))
   103  
   104  		_, err = io.ReadFull(con, data.Bytes())
   105  		if err != nil {
   106  			return nil, fmt.Errorf("read dns server response failed: %w", err)
   107  		}
   108  
   109  		return data, nil
   110  	})
   111  	return d, nil
   112  }
   113  
   114  func (d *doq) Close() error {
   115  	var err error
   116  	if d.connection != nil {
   117  		er := d.connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
   118  		if er != nil {
   119  			err = errors.Join(err, er)
   120  		}
   121  	}
   122  
   123  	if d.conn != nil {
   124  		er := d.conn.Close()
   125  		if er != nil {
   126  			err = errors.Join(err, er)
   127  		}
   128  	}
   129  
   130  	return err
   131  }
   132  
   133  type DOQWrapConn struct {
   134  	net.PacketConn
   135  	localAddrSalt string
   136  }
   137  
   138  func (d *DOQWrapConn) LocalAddr() net.Addr {
   139  	return &doqWrapLocalAddr{d.PacketConn.LocalAddr(), d.localAddrSalt}
   140  }
   141  
   142  // doqWrapLocalAddr make doq packetConn local addr is different, otherwise the quic-go will panic
   143  // see: https://github.com/quic-go/quic-go/issues/3727
   144  type doqWrapLocalAddr struct {
   145  	net.Addr
   146  	salt string
   147  }
   148  
   149  func (a *doqWrapLocalAddr) String() string {
   150  	return fmt.Sprintf("doq://%s-%s", a.Addr.String(), a.salt)
   151  }
   152  
// doqIgGenerate supplies the salt that initSession formats into each
// DOQWrapConn it dials through, so successive dials report different local
// address strings (presumably Generate returns a unique value per call —
// verify against id.IDGenerator).
// NOTE(review): the name presumably means "doqIDGenerate"; left unchanged in
// case other files reference it.
var doqIgGenerate = id.IDGenerator{}
   154  
   155  func (d *doq) initSession(ctx context.Context) (quic.Connection, error) {
   156  	connection := d.connection
   157  
   158  	if connection != nil {
   159  		select {
   160  		case <-connection.Context().Done():
   161  			_ = connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
   162  		default:
   163  			return connection, nil
   164  		}
   165  	}
   166  
   167  	d.mu.Lock()
   168  	defer d.mu.Unlock()
   169  
   170  	if d.connection != nil {
   171  		select {
   172  		case <-d.connection.Context().Done():
   173  			_ = d.connection.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
   174  
   175  		default:
   176  			return d.connection, nil
   177  		}
   178  	}
   179  
   180  	if d.conn != nil {
   181  		d.conn.Close()
   182  		d.conn = nil
   183  	}
   184  
   185  	if d.conn == nil {
   186  		conn, err := d.dialer.PacketConn(ctx, d.host)
   187  		if err != nil {
   188  			return nil, err
   189  		}
   190  		d.conn = conn
   191  	}
   192  
   193  	session, err := quic.Dial(
   194  		ctx,
   195  		&DOQWrapConn{d.conn, fmt.Sprint(doqIgGenerate.Generate())},
   196  		d.host,
   197  		&tls.Config{
   198  			NextProtos: []string{"http/1.1", "doq-i02", "doq-i01", "doq-i00", "doq", "dq", http2.NextProtoTLS},
   199  			ServerName: d.servername,
   200  		}, &quic.Config{})
   201  	if err != nil {
   202  		_ = d.conn.Close()
   203  		return nil, fmt.Errorf("quic dial failed: %w", err)
   204  	}
   205  
   206  	d.connection = session
   207  	return session, nil
   208  }