github.com/xraypb/Xray-core@v1.8.1/common/xudp/xudp.go (about)

     1  package xudp
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"strings"
    11  
    12  	"github.com/xraypb/Xray-core/common/buf"
    13  	"github.com/xraypb/Xray-core/common/net"
    14  	"github.com/xraypb/Xray-core/common/protocol"
    15  	"github.com/xraypb/Xray-core/common/session"
    16  	"lukechampine.com/blake3"
    17  )
    18  
// AddrParser reads and writes the address/port fields of XUDP frames.
// The port is encoded first (PortThenAddress). It is followed by a one-byte
// address type (IPv4, domain or IPv6) and then the address itself.
var AddrParser = protocol.NewAddressParser(
	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
	protocol.PortThenAddress(),
)
    25  
    26  var (
    27  	Show    bool
    28  	BaseKey [32]byte
    29  )
    30  
    31  const (
    32  	EnvShow    = "XRAY_XUDP_SHOW"
    33  	EnvBaseKey = "XRAY_XUDP_BASEKEY"
    34  )
    35  
    36  func init() {
    37  	if strings.ToLower(os.Getenv(EnvShow)) == "true" {
    38  		Show = true
    39  	}
    40  	if raw := os.Getenv(EnvBaseKey); raw != "" {
    41  		if key, _ := base64.RawURLEncoding.DecodeString(raw); len(key) == len(BaseKey) {
    42  			copy(BaseKey[:], key)
    43  			return
    44  		} else {
    45  			panic(EnvBaseKey + ": invalid value: " + raw)
    46  		}
    47  	}
    48  	rand.Read(BaseKey[:])
    49  }
    50  
    51  func GetGlobalID(ctx context.Context) (globalID [8]byte) {
    52  	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
    53  		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
    54  		h := blake3.New(8, BaseKey[:])
    55  		h.Write([]byte(inbound.Source.String()))
    56  		copy(globalID[:], h.Sum(nil))
    57  		fmt.Printf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)
    58  	}
    59  	return
    60  }
    61  
    62  func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
    63  	return &PacketWriter{
    64  		Writer:   writer,
    65  		Dest:     dest,
    66  		GlobalID: globalID,
    67  	}
    68  }
    69  
    70  type PacketWriter struct {
    71  	Writer   buf.Writer
    72  	Dest     net.Destination
    73  	GlobalID [8]byte
    74  }
    75  
    76  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    77  	defer buf.ReleaseMulti(mb)
    78  	mb2Write := make(buf.MultiBuffer, 0, len(mb))
    79  	for _, b := range mb {
    80  		length := b.Len()
    81  		if length == 0 || length+666 > buf.Size {
    82  			continue
    83  		}
    84  
    85  		eb := buf.New()
    86  		eb.Write([]byte{0, 0, 0, 0})
    87  		if w.Dest.Network == net.Network_UDP {
    88  			eb.WriteByte(1) // New
    89  			eb.WriteByte(1) // Opt
    90  			eb.WriteByte(2) // UDP
    91  			AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
    92  			if b.UDP != nil { // make sure it's user's proxy request
    93  				eb.Write(w.GlobalID[:])
    94  			}
    95  			w.Dest.Network = net.Network_Unknown
    96  		} else {
    97  			eb.WriteByte(2) // Keep
    98  			eb.WriteByte(1)
    99  			if b.UDP != nil {
   100  				eb.WriteByte(2)
   101  				AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
   102  			}
   103  		}
   104  		l := eb.Len() - 2
   105  		eb.SetByte(0, byte(l>>8))
   106  		eb.SetByte(1, byte(l))
   107  		eb.WriteByte(byte(length >> 8))
   108  		eb.WriteByte(byte(length))
   109  		eb.Write(b.Bytes())
   110  
   111  		mb2Write = append(mb2Write, eb)
   112  	}
   113  	if mb2Write.IsEmpty() {
   114  		return nil
   115  	}
   116  	return w.Writer.WriteMultiBuffer(mb2Write)
   117  }
   118  
   119  func NewPacketReader(reader io.Reader) *PacketReader {
   120  	return &PacketReader{
   121  		Reader: reader,
   122  		cache:  make([]byte, 2),
   123  	}
   124  }
   125  
   126  type PacketReader struct {
   127  	Reader io.Reader
   128  	cache  []byte
   129  }
   130  
   131  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   132  	for {
   133  		if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
   134  			return nil, err
   135  		}
   136  		l := int32(r.cache[0])<<8 | int32(r.cache[1])
   137  		if l < 4 {
   138  			return nil, io.EOF
   139  		}
   140  		b := buf.New()
   141  		if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
   142  			b.Release()
   143  			return nil, err
   144  		}
   145  		discard := false
   146  		switch b.Byte(2) {
   147  		case 2:
   148  			if l != 4 {
   149  				b.Advance(5)
   150  				addr, port, err := AddrParser.ReadAddressPort(nil, b)
   151  				if err != nil {
   152  					b.Release()
   153  					return nil, err
   154  				}
   155  				b.UDP = &net.Destination{
   156  					Network: net.Network_UDP,
   157  					Address: addr,
   158  					Port:    port,
   159  				}
   160  			}
   161  		case 4:
   162  			discard = true
   163  		default:
   164  			b.Release()
   165  			return nil, io.EOF
   166  		}
   167  		if b.Byte(3) == 1 {
   168  			if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
   169  				b.Release()
   170  				return nil, err
   171  			}
   172  			length := int32(r.cache[0])<<8 | int32(r.cache[1])
   173  			if length > 0 {
   174  				b.Clear()
   175  				if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
   176  					b.Release()
   177  					return nil, err
   178  				}
   179  				if !discard {
   180  					return buf.MultiBuffer{b}, nil
   181  				}
   182  			}
   183  		}
   184  		b.Release()
   185  	}
   186  }