github.com/xmplusdev/xray-core@v1.8.10/common/protocol/dns/io.go (about)

     1  package dns
     2  
import (
	"encoding/binary"
	"fmt"
	"sync"

	"github.com/xmplusdev/xray-core/common"
	"github.com/xmplusdev/xray-core/common/buf"
	"github.com/xmplusdev/xray-core/common/serial"
	"golang.org/x/net/dns/dnsmessage"
)
    12  
    13  func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
    14  	buffer := buf.New()
    15  	rawBytes := buffer.Extend(buf.Size)
    16  	packed, err := msg.AppendPack(rawBytes[:0])
    17  	if err != nil {
    18  		buffer.Release()
    19  		return nil, err
    20  	}
    21  	buffer.Resize(0, int32(len(packed)))
    22  	return buffer, nil
    23  }
    24  
    25  type MessageReader interface {
    26  	ReadMessage() (*buf.Buffer, error)
    27  }
    28  
    29  type UDPReader struct {
    30  	buf.Reader
    31  
    32  	access sync.Mutex
    33  	cache  buf.MultiBuffer
    34  }
    35  
    36  func (r *UDPReader) readCache() *buf.Buffer {
    37  	r.access.Lock()
    38  	defer r.access.Unlock()
    39  
    40  	mb, b := buf.SplitFirst(r.cache)
    41  	r.cache = mb
    42  	return b
    43  }
    44  
    45  func (r *UDPReader) refill() error {
    46  	mb, err := r.Reader.ReadMultiBuffer()
    47  	if err != nil {
    48  		return err
    49  	}
    50  	r.access.Lock()
    51  	r.cache = mb
    52  	r.access.Unlock()
    53  	return nil
    54  }
    55  
    56  // ReadMessage implements MessageReader.
    57  func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
    58  	for {
    59  		b := r.readCache()
    60  		if b != nil {
    61  			return b, nil
    62  		}
    63  		if err := r.refill(); err != nil {
    64  			return nil, err
    65  		}
    66  	}
    67  }
    68  
    69  // Close implements common.Closable.
    70  func (r *UDPReader) Close() error {
    71  	defer func() {
    72  		r.access.Lock()
    73  		buf.ReleaseMulti(r.cache)
    74  		r.cache = nil
    75  		r.access.Unlock()
    76  	}()
    77  
    78  	return common.Close(r.Reader)
    79  }
    80  
    81  type TCPReader struct {
    82  	reader *buf.BufferedReader
    83  }
    84  
    85  func NewTCPReader(reader buf.Reader) *TCPReader {
    86  	return &TCPReader{
    87  		reader: &buf.BufferedReader{
    88  			Reader: reader,
    89  		},
    90  	}
    91  }
    92  
    93  func (r *TCPReader) ReadMessage() (*buf.Buffer, error) {
    94  	size, err := serial.ReadUint16(r.reader)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	if size > buf.Size {
    99  		return nil, newError("message size too large: ", size)
   100  	}
   101  	b := buf.New()
   102  	if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
   103  		return nil, err
   104  	}
   105  	return b, nil
   106  }
   107  
   108  func (r *TCPReader) Interrupt() {
   109  	common.Interrupt(r.reader)
   110  }
   111  
   112  func (r *TCPReader) Close() error {
   113  	return common.Close(r.reader)
   114  }
   115  
   116  type MessageWriter interface {
   117  	WriteMessage(msg *buf.Buffer) error
   118  }
   119  
   120  type UDPWriter struct {
   121  	buf.Writer
   122  }
   123  
   124  func (w *UDPWriter) WriteMessage(b *buf.Buffer) error {
   125  	return w.WriteMultiBuffer(buf.MultiBuffer{b})
   126  }
   127  
   128  type TCPWriter struct {
   129  	buf.Writer
   130  }
   131  
   132  func (w *TCPWriter) WriteMessage(b *buf.Buffer) error {
   133  	if b.IsEmpty() {
   134  		return nil
   135  	}
   136  
   137  	mb := make(buf.MultiBuffer, 0, 2)
   138  
   139  	size := buf.New()
   140  	binary.BigEndian.PutUint16(size.Extend(2), uint16(b.Len()))
   141  	mb = append(mb, size, b)
   142  	return w.WriteMultiBuffer(mb)
   143  }