github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/common/protocol/dns/io.go (about)

     1  package dns
     2  
import (
	"encoding/binary"
	"errors"
	"sync"

	"golang.org/x/net/dns/dnsmessage"

	"github.com/v2fly/v2ray-core/v5/common"
	"github.com/v2fly/v2ray-core/v5/common/buf"
	"github.com/v2fly/v2ray-core/v5/common/serial"
)
    13  
    14  func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
    15  	buffer := buf.New()
    16  	rawBytes := buffer.Extend(buf.Size)
    17  	packed, err := msg.AppendPack(rawBytes[:0])
    18  	if err != nil {
    19  		buffer.Release()
    20  		return nil, err
    21  	}
    22  	buffer.Resize(0, int32(len(packed)))
    23  	return buffer, nil
    24  }
    25  
    26  type MessageReader interface {
    27  	ReadMessage() (*buf.Buffer, error)
    28  }
    29  
    30  type UDPReader struct {
    31  	buf.Reader
    32  
    33  	access sync.Mutex
    34  	cache  buf.MultiBuffer
    35  }
    36  
    37  func (r *UDPReader) readCache() *buf.Buffer {
    38  	r.access.Lock()
    39  	defer r.access.Unlock()
    40  
    41  	mb, b := buf.SplitFirst(r.cache)
    42  	r.cache = mb
    43  	return b
    44  }
    45  
    46  func (r *UDPReader) refill() error {
    47  	mb, err := r.Reader.ReadMultiBuffer()
    48  	if err != nil {
    49  		return err
    50  	}
    51  	r.access.Lock()
    52  	r.cache = mb
    53  	r.access.Unlock()
    54  	return nil
    55  }
    56  
    57  // ReadMessage implements MessageReader.
    58  func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
    59  	for {
    60  		b := r.readCache()
    61  		if b != nil {
    62  			return b, nil
    63  		}
    64  		if err := r.refill(); err != nil {
    65  			return nil, err
    66  		}
    67  	}
    68  }
    69  
    70  // Close implements common.Closable.
    71  func (r *UDPReader) Close() error {
    72  	defer func() {
    73  		r.access.Lock()
    74  		buf.ReleaseMulti(r.cache)
    75  		r.cache = nil
    76  		r.access.Unlock()
    77  	}()
    78  
    79  	return common.Close(r.Reader)
    80  }
    81  
    82  type TCPReader struct {
    83  	reader *buf.BufferedReader
    84  }
    85  
    86  func NewTCPReader(reader buf.Reader) *TCPReader {
    87  	return &TCPReader{
    88  		reader: &buf.BufferedReader{
    89  			Reader: reader,
    90  		},
    91  	}
    92  }
    93  
    94  func (r *TCPReader) ReadMessage() (*buf.Buffer, error) {
    95  	size, err := serial.ReadUint16(r.reader)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	if size > buf.Size {
   100  		return nil, newError("message size too large: ", size)
   101  	}
   102  	b := buf.New()
   103  	if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
   104  		return nil, err
   105  	}
   106  	return b, nil
   107  }
   108  
   109  func (r *TCPReader) Interrupt() {
   110  	common.Interrupt(r.reader)
   111  }
   112  
   113  func (r *TCPReader) Close() error {
   114  	return common.Close(r.reader)
   115  }
   116  
   117  type MessageWriter interface {
   118  	WriteMessage(msg *buf.Buffer) error
   119  }
   120  
   121  type UDPWriter struct {
   122  	buf.Writer
   123  }
   124  
   125  func (w *UDPWriter) WriteMessage(b *buf.Buffer) error {
   126  	return w.WriteMultiBuffer(buf.MultiBuffer{b})
   127  }
   128  
   129  type TCPWriter struct {
   130  	buf.Writer
   131  }
   132  
   133  func (w *TCPWriter) WriteMessage(b *buf.Buffer) error {
   134  	if b.IsEmpty() {
   135  		return nil
   136  	}
   137  
   138  	mb := make(buf.MultiBuffer, 0, 2)
   139  
   140  	size := buf.New()
   141  	binary.BigEndian.PutUint16(size.Extend(2), uint16(b.Len()))
   142  	mb = append(mb, size, b)
   143  	return w.WriteMultiBuffer(mb)
   144  }