github.com/EagleQL/Xray-core@v1.4.3/proxy/vless/encoding/addons.go (about)

     1  package encoding
     2  
     3  import (
     4  	"io"
     5  
     6  	"github.com/golang/protobuf/proto"
     7  
     8  	"github.com/xtls/xray-core/common/buf"
     9  	"github.com/xtls/xray-core/common/protocol"
    10  	"github.com/xtls/xray-core/proxy/vless"
    11  )
    12  
    13  func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error {
    14  	switch addons.Flow {
    15  	case vless.XRO, vless.XRD:
    16  		bytes, err := proto.Marshal(addons)
    17  		if err != nil {
    18  			return newError("failed to marshal addons protobuf value").Base(err)
    19  		}
    20  		if err := buffer.WriteByte(byte(len(bytes))); err != nil {
    21  			return newError("failed to write addons protobuf length").Base(err)
    22  		}
    23  		if _, err := buffer.Write(bytes); err != nil {
    24  			return newError("failed to write addons protobuf value").Base(err)
    25  		}
    26  	default:
    27  		if err := buffer.WriteByte(0); err != nil {
    28  			return newError("failed to write addons protobuf length").Base(err)
    29  		}
    30  	}
    31  
    32  	return nil
    33  }
    34  
    35  func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
    36  	addons := new(Addons)
    37  	buffer.Clear()
    38  	if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
    39  		return nil, newError("failed to read addons protobuf length").Base(err)
    40  	}
    41  
    42  	if length := int32(buffer.Byte(0)); length != 0 {
    43  		buffer.Clear()
    44  		if _, err := buffer.ReadFullFrom(reader, length); err != nil {
    45  			return nil, newError("failed to read addons protobuf value").Base(err)
    46  		}
    47  
    48  		if err := proto.Unmarshal(buffer.Bytes(), addons); err != nil {
    49  			return nil, newError("failed to unmarshal addons protobuf value").Base(err)
    50  		}
    51  
    52  		// Verification.
    53  		switch addons.Flow {
    54  		default:
    55  		}
    56  	}
    57  
    58  	return addons, nil
    59  }
    60  
    61  // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
    62  func EncodeBodyAddons(writer io.Writer, request *protocol.RequestHeader, addons *Addons) buf.Writer {
    63  	switch addons.Flow {
    64  	default:
    65  		if request.Command == protocol.RequestCommandUDP {
    66  			return NewMultiLengthPacketWriter(writer.(buf.Writer))
    67  		}
    68  	}
    69  	return buf.NewWriter(writer)
    70  }
    71  
    72  // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body.
    73  func DecodeBodyAddons(reader io.Reader, request *protocol.RequestHeader, addons *Addons) buf.Reader {
    74  	switch addons.Flow {
    75  	default:
    76  		if request.Command == protocol.RequestCommandUDP {
    77  			return NewLengthPacketReader(reader)
    78  		}
    79  	}
    80  	return buf.NewReader(reader)
    81  }
    82  
    83  func NewMultiLengthPacketWriter(writer buf.Writer) *MultiLengthPacketWriter {
    84  	return &MultiLengthPacketWriter{
    85  		Writer: writer,
    86  	}
    87  }
    88  
    89  type MultiLengthPacketWriter struct {
    90  	buf.Writer
    91  }
    92  
    93  func (w *MultiLengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    94  	defer buf.ReleaseMulti(mb)
    95  	mb2Write := make(buf.MultiBuffer, 0, len(mb)+1)
    96  	for _, b := range mb {
    97  		length := b.Len()
    98  		if length == 0 || length+2 > buf.Size {
    99  			continue
   100  		}
   101  		eb := buf.New()
   102  		if err := eb.WriteByte(byte(length >> 8)); err != nil {
   103  			eb.Release()
   104  			continue
   105  		}
   106  		if err := eb.WriteByte(byte(length)); err != nil {
   107  			eb.Release()
   108  			continue
   109  		}
   110  		if _, err := eb.Write(b.Bytes()); err != nil {
   111  			eb.Release()
   112  			continue
   113  		}
   114  		mb2Write = append(mb2Write, eb)
   115  	}
   116  	if mb2Write.IsEmpty() {
   117  		return nil
   118  	}
   119  	return w.Writer.WriteMultiBuffer(mb2Write)
   120  }
   121  
   122  func NewLengthPacketWriter(writer io.Writer) *LengthPacketWriter {
   123  	return &LengthPacketWriter{
   124  		Writer: writer,
   125  		cache:  make([]byte, 0, 65536),
   126  	}
   127  }
   128  
   129  type LengthPacketWriter struct {
   130  	io.Writer
   131  	cache []byte
   132  }
   133  
   134  func (w *LengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   135  	length := mb.Len() // none of mb is nil
   136  	// fmt.Println("Write", length)
   137  	if length == 0 {
   138  		return nil
   139  	}
   140  	defer func() {
   141  		w.cache = w.cache[:0]
   142  	}()
   143  	w.cache = append(w.cache, byte(length>>8), byte(length))
   144  	for i, b := range mb {
   145  		w.cache = append(w.cache, b.Bytes()...)
   146  		b.Release()
   147  		mb[i] = nil
   148  	}
   149  	if _, err := w.Write(w.cache); err != nil {
   150  		return newError("failed to write a packet").Base(err)
   151  	}
   152  	return nil
   153  }
   154  
   155  func NewLengthPacketReader(reader io.Reader) *LengthPacketReader {
   156  	return &LengthPacketReader{
   157  		Reader: reader,
   158  		cache:  make([]byte, 2),
   159  	}
   160  }
   161  
   162  type LengthPacketReader struct {
   163  	io.Reader
   164  	cache []byte
   165  }
   166  
   167  func (r *LengthPacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   168  	if _, err := io.ReadFull(r.Reader, r.cache); err != nil { // maybe EOF
   169  		return nil, newError("failed to read packet length").Base(err)
   170  	}
   171  	length := int32(r.cache[0])<<8 | int32(r.cache[1])
   172  	// fmt.Println("Read", length)
   173  	mb := make(buf.MultiBuffer, 0, length/buf.Size+1)
   174  	for length > 0 {
   175  		size := length
   176  		if size > buf.Size {
   177  			size = buf.Size
   178  		}
   179  		length -= size
   180  		b := buf.New()
   181  		if _, err := b.ReadFullFrom(r.Reader, size); err != nil {
   182  			return nil, newError("failed to read packet payload").Base(err)
   183  		}
   184  		mb = append(mb, b)
   185  	}
   186  	return mb, nil
   187  }