github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/common/protocol/dns/io.go (about) 1 package dns 2 3 import ( 4 "encoding/binary" 5 "sync" 6 7 "github.com/xmplusdev/xmcore/common" 8 "github.com/xmplusdev/xmcore/common/buf" 9 "github.com/xmplusdev/xmcore/common/serial" 10 "golang.org/x/net/dns/dnsmessage" 11 ) 12 13 func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) { 14 buffer := buf.New() 15 rawBytes := buffer.Extend(buf.Size) 16 packed, err := msg.AppendPack(rawBytes[:0]) 17 if err != nil { 18 buffer.Release() 19 return nil, err 20 } 21 buffer.Resize(0, int32(len(packed))) 22 return buffer, nil 23 } 24 25 type MessageReader interface { 26 ReadMessage() (*buf.Buffer, error) 27 } 28 29 type UDPReader struct { 30 buf.Reader 31 32 access sync.Mutex 33 cache buf.MultiBuffer 34 } 35 36 func (r *UDPReader) readCache() *buf.Buffer { 37 r.access.Lock() 38 defer r.access.Unlock() 39 40 mb, b := buf.SplitFirst(r.cache) 41 r.cache = mb 42 return b 43 } 44 45 func (r *UDPReader) refill() error { 46 mb, err := r.Reader.ReadMultiBuffer() 47 if err != nil { 48 return err 49 } 50 r.access.Lock() 51 r.cache = mb 52 r.access.Unlock() 53 return nil 54 } 55 56 // ReadMessage implements MessageReader. 57 func (r *UDPReader) ReadMessage() (*buf.Buffer, error) { 58 for { 59 b := r.readCache() 60 if b != nil { 61 return b, nil 62 } 63 if err := r.refill(); err != nil { 64 return nil, err 65 } 66 } 67 } 68 69 // Close implements common.Closable. 70 func (r *UDPReader) Close() error { 71 defer func() { 72 r.access.Lock() 73 buf.ReleaseMulti(r.cache) 74 r.cache = nil 75 r.access.Unlock() 76 }() 77 78 return common.Close(r.Reader) 79 } 80 81 type TCPReader struct { 82 reader *buf.BufferedReader 83 } 84 85 func NewTCPReader(reader buf.Reader) *TCPReader { 86 return &TCPReader{ 87 reader: &buf.BufferedReader{ 88 Reader: reader, 89 }, 90 } 91 } 92 93 func (r *TCPReader) ReadMessage() (*buf.Buffer, error) { 94 size, err := serial.ReadUint16(r.reader) 95 if err != nil { 96 return nil, err 97 } 98 if size > buf.Size { 99 return nil, newError("message size too large: ", size) 100 } 101 b := buf.New() 102 if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil { 103 return nil, err 104 } 105 return b, nil 106 } 107 108 func (r *TCPReader) Interrupt() { 109 common.Interrupt(r.reader) 110 } 111 112 func (r *TCPReader) Close() error { 113 return common.Close(r.reader) 114 } 115 116 type MessageWriter interface { 117 WriteMessage(msg *buf.Buffer) error 118 } 119 120 type UDPWriter struct { 121 buf.Writer 122 } 123 124 func (w *UDPWriter) WriteMessage(b *buf.Buffer) error { 125 return w.WriteMultiBuffer(buf.MultiBuffer{b}) 126 } 127 128 type TCPWriter struct { 129 buf.Writer 130 } 131 132 func (w *TCPWriter) WriteMessage(b *buf.Buffer) error { 133 if b.IsEmpty() { 134 return nil 135 } 136 137 mb := make(buf.MultiBuffer, 0, 2) 138 139 size := buf.New() 140 binary.BigEndian.PutUint16(size.Extend(2), uint16(b.Len())) 141 mb = append(mb, size, b) 142 return w.WriteMultiBuffer(mb) 143 }