github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/common/xudp/xudp.go (about)

     1  package xudp
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/xmplusdev/xmcore/common/buf"
    14  	"github.com/xmplusdev/xmcore/common/net"
    15  	"github.com/xmplusdev/xmcore/common/platform"
    16  	"github.com/xmplusdev/xmcore/common/protocol"
    17  	"github.com/xmplusdev/xmcore/common/session"
    18  	"lukechampine.com/blake3"
    19  )
    20  
    21  var AddrParser = protocol.NewAddressParser(
    22  	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
    23  	protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
    24  	protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
    25  	protocol.PortThenAddress(),
    26  )
    27  
    28  var (
    29  	Show    bool
    30  	BaseKey []byte
    31  )
    32  
    33  func init() {
    34  	if strings.ToLower(platform.NewEnvFlag(platform.XUDPLog).GetValue(func() string { return "" })) == "true" {
    35  		Show = true
    36  	}
    37  	rand.Read(BaseKey)
    38  	go func() {
    39  		time.Sleep(100 * time.Millisecond) // this is not nice, but need to give some time for Android to setup ENV 
    40  		if raw := platform.NewEnvFlag(platform.XUDPBaseKey).GetValue(func() string { return "" }); raw != "" {
    41  			if BaseKey, _ = base64.RawURLEncoding.DecodeString(raw); len(BaseKey) == 32 {
    42  				return
    43  			}
    44  			panic(platform.XUDPBaseKey + ": invalid value (BaseKey must be 32 bytes): " + raw + " len " + strconv.Itoa(len(BaseKey)))
    45  		}
    46  	}()
    47  }
    48  
    49  func GetGlobalID(ctx context.Context) (globalID [8]byte) {
    50  	if cone := ctx.Value("cone"); cone == nil || !cone.(bool) { // cone is nil only in some unit tests
    51  		return
    52  	}
    53  	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
    54  		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
    55  		h := blake3.New(8, BaseKey)
    56  		h.Write([]byte(inbound.Source.String()))
    57  		copy(globalID[:], h.Sum(nil))
    58  		if Show {
    59  			newError(fmt.Sprintf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)).WriteToLog(session.ExportIDToError(ctx))
    60  		}
    61  	}
    62  	return
    63  }
    64  
    65  func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
    66  	return &PacketWriter{
    67  		Writer:   writer,
    68  		Dest:     dest,
    69  		GlobalID: globalID,
    70  	}
    71  }
    72  
    73  type PacketWriter struct {
    74  	Writer   buf.Writer
    75  	Dest     net.Destination
    76  	GlobalID [8]byte
    77  }
    78  
    79  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    80  	defer buf.ReleaseMulti(mb)
    81  	mb2Write := make(buf.MultiBuffer, 0, len(mb))
    82  	for _, b := range mb {
    83  		length := b.Len()
    84  		if length == 0 || length+666 > buf.Size {
    85  			continue
    86  		}
    87  
    88  		eb := buf.New()
    89  		eb.Write([]byte{0, 0, 0, 0}) // Meta data length; Mux Session ID
    90  		if w.Dest.Network == net.Network_UDP {
    91  			eb.WriteByte(1) // New
    92  			eb.WriteByte(1) // Opt
    93  			eb.WriteByte(2) // UDP
    94  			AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
    95  			if b.UDP != nil { // make sure it's user's proxy request
    96  				eb.Write(w.GlobalID[:]) // no need to check whether it's empty
    97  			}
    98  			w.Dest.Network = net.Network_Unknown
    99  		} else {
   100  			eb.WriteByte(2) // Keep
   101  			eb.WriteByte(1) // Opt
   102  			if b.UDP != nil {
   103  				eb.WriteByte(2) // UDP
   104  				AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
   105  			}
   106  		}
   107  		l := eb.Len() - 2
   108  		eb.SetByte(0, byte(l>>8))
   109  		eb.SetByte(1, byte(l))
   110  		eb.WriteByte(byte(length >> 8))
   111  		eb.WriteByte(byte(length))
   112  		eb.Write(b.Bytes())
   113  
   114  		mb2Write = append(mb2Write, eb)
   115  	}
   116  	if mb2Write.IsEmpty() {
   117  		return nil
   118  	}
   119  	return w.Writer.WriteMultiBuffer(mb2Write)
   120  }
   121  
   122  func NewPacketReader(reader io.Reader) *PacketReader {
   123  	return &PacketReader{
   124  		Reader: reader,
   125  		cache:  make([]byte, 2),
   126  	}
   127  }
   128  
   129  type PacketReader struct {
   130  	Reader io.Reader
   131  	cache  []byte
   132  }
   133  
// ReadMultiBuffer returns the payload of the next frame on r.Reader that carries one.
//
// Each frame starts with a 2-byte big-endian metadata length l, followed by l bytes of
// metadata. Within the metadata, byte 2 is the frame status and byte 3 is the option byte.
// When the option byte is 1, the metadata is followed by a 2-byte big-endian payload length
// and then the payload. Frames are handled by status:
//   - 2: a non-empty payload is returned as a one-buffer MultiBuffer. If l > 4 and metadata
//     byte 4 is 2, a port+address is parsed with AddrParser starting at metadata byte 5 and
//     attached to the returned buffer as b.UDP (Network_UDP).
//   - 4: the payload, if any, is read and discarded.
//   - any other value: io.EOF is returned.
//
// Frames without a payload (option byte != 1, or payload length 0) are skipped and the next
// frame is read. io.EOF is also returned when l < 4. Read errors and address parse errors
// are returned as-is. Extra metadata bytes (e.g. padding) are ignored.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
	for {
		// Read the 2-byte metadata length prefix.
		if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
			return nil, err
		}
		l := int32(r.cache[0])<<8 | int32(r.cache[1])
		// At least 4 metadata bytes are needed to read the status (byte 2) and option (byte 3).
		if l < 4 {
			return nil, io.EOF
		}
		b := buf.New()
		if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
			b.Release()
			return nil, err
		}
		discard := false
		switch b.Byte(2) { // frame status; PacketWriter writes 2 for "Keep"
		case 2:
			if l > 4 && b.Byte(4) == 2 { // MUST check the flag first
				b.Advance(5)
				// b.Clear() will be called automatically if all data had been read.
				addr, port, err := AddrParser.ReadAddressPort(nil, b)
				if err != nil {
					b.Release()
					return nil, err
				}
				b.UDP = &net.Destination{
					Network: net.Network_UDP,
					Address: addr,
					Port:    port,
				}
			}
		case 4:
			// The payload below is consumed but not returned.
			discard = true
		default:
			b.Release()
			return nil, io.EOF
		}
		// NOTE(review): Clear drops the remaining metadata so the payload is read into b from
		// the start. b.Byte(3) below presumably still sees the raw option byte because Clear
		// resets b's indices without zeroing its storage; confirm against buf.Buffer.Clear.
		b.Clear() // in case there is padding (empty bytes) attached
		if b.Byte(3) == 1 { // option byte: 1 means a length-prefixed payload follows
			if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
				b.Release()
				return nil, err
			}
			length := int32(r.cache[0])<<8 | int32(r.cache[1])
			if length > 0 {
				if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
					b.Release()
					return nil, err
				}
				if !discard {
					return buf.MultiBuffer{b}, nil
				}
			}
		}
		// No payload, or a discarded one: drop this frame and read the next.
		b.Release()
	}
}