github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/common/xudp/xudp.go (about) 1 package xudp 2 3 import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "fmt" 8 "io" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/xmplusdev/xmcore/common/buf" 14 "github.com/xmplusdev/xmcore/common/net" 15 "github.com/xmplusdev/xmcore/common/platform" 16 "github.com/xmplusdev/xmcore/common/protocol" 17 "github.com/xmplusdev/xmcore/common/session" 18 "lukechampine.com/blake3" 19 ) 20 21 var AddrParser = protocol.NewAddressParser( 22 protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4), 23 protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain), 24 protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6), 25 protocol.PortThenAddress(), 26 ) 27 28 var ( 29 Show bool 30 BaseKey []byte 31 ) 32 33 func init() { 34 if strings.ToLower(platform.NewEnvFlag(platform.XUDPLog).GetValue(func() string { return "" })) == "true" { 35 Show = true 36 } 37 rand.Read(BaseKey) 38 go func() { 39 time.Sleep(100 * time.Millisecond) // this is not nice, but need to give some time for Android to setup ENV 40 if raw := platform.NewEnvFlag(platform.XUDPBaseKey).GetValue(func() string { return "" }); raw != "" { 41 if BaseKey, _ = base64.RawURLEncoding.DecodeString(raw); len(BaseKey) == 32 { 42 return 43 } 44 panic(platform.XUDPBaseKey + ": invalid value (BaseKey must be 32 bytes): " + raw + " len " + strconv.Itoa(len(BaseKey))) 45 } 46 }() 47 } 48 49 func GetGlobalID(ctx context.Context) (globalID [8]byte) { 50 if cone := ctx.Value("cone"); cone == nil || !cone.(bool) { // cone is nil only in some unit tests 51 return 52 } 53 if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP && 54 (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") { 55 h := blake3.New(8, BaseKey) 56 h.Write([]byte(inbound.Source.String())) 57 copy(globalID[:], h.Sum(nil)) 58 if Show { 59 newError(fmt.Sprintf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)).WriteToLog(session.ExportIDToError(ctx)) 60 } 61 } 62 return 63 } 64 65 func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter { 66 return &PacketWriter{ 67 Writer: writer, 68 Dest: dest, 69 GlobalID: globalID, 70 } 71 } 72 73 type PacketWriter struct { 74 Writer buf.Writer 75 Dest net.Destination 76 GlobalID [8]byte 77 } 78 79 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { 80 defer buf.ReleaseMulti(mb) 81 mb2Write := make(buf.MultiBuffer, 0, len(mb)) 82 for _, b := range mb { 83 length := b.Len() 84 if length == 0 || length+666 > buf.Size { 85 continue 86 } 87 88 eb := buf.New() 89 eb.Write([]byte{0, 0, 0, 0}) // Meta data length; Mux Session ID 90 if w.Dest.Network == net.Network_UDP { 91 eb.WriteByte(1) // New 92 eb.WriteByte(1) // Opt 93 eb.WriteByte(2) // UDP 94 AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port) 95 if b.UDP != nil { // make sure it's user's proxy request 96 eb.Write(w.GlobalID[:]) // no need to check whether it's empty 97 } 98 w.Dest.Network = net.Network_Unknown 99 } else { 100 eb.WriteByte(2) // Keep 101 eb.WriteByte(1) // Opt 102 if b.UDP != nil { 103 eb.WriteByte(2) // UDP 104 AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port) 105 } 106 } 107 l := eb.Len() - 2 108 eb.SetByte(0, byte(l>>8)) 109 eb.SetByte(1, byte(l)) 110 eb.WriteByte(byte(length >> 8)) 111 eb.WriteByte(byte(length)) 112 eb.Write(b.Bytes()) 113 114 mb2Write = append(mb2Write, eb) 115 } 116 if mb2Write.IsEmpty() { 117 return nil 118 } 119 return w.Writer.WriteMultiBuffer(mb2Write) 120 } 121 122 func NewPacketReader(reader io.Reader) *PacketReader { 123 return &PacketReader{ 124 Reader: reader, 125 cache: make([]byte, 2), 126 } 127 } 128 129 type PacketReader struct { 130 Reader io.Reader 131 cache []byte 132 } 133 134 func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 135 for { 136 if _, err := io.ReadFull(r.Reader, r.cache); err != nil { 137 return nil, err 138 } 139 l := int32(r.cache[0])<<8 | int32(r.cache[1]) 140 if l < 4 { 141 return nil, io.EOF 142 } 143 b := buf.New() 144 if _, err := b.ReadFullFrom(r.Reader, l); err != nil { 145 b.Release() 146 return nil, err 147 } 148 discard := false 149 switch b.Byte(2) { 150 case 2: 151 if l > 4 && b.Byte(4) == 2 { // MUST check the flag first 152 b.Advance(5) 153 // b.Clear() will be called automatically if all data had been read. 154 addr, port, err := AddrParser.ReadAddressPort(nil, b) 155 if err != nil { 156 b.Release() 157 return nil, err 158 } 159 b.UDP = &net.Destination{ 160 Network: net.Network_UDP, 161 Address: addr, 162 Port: port, 163 } 164 } 165 case 4: 166 discard = true 167 default: 168 b.Release() 169 return nil, io.EOF 170 } 171 b.Clear() // in case there is padding (empty bytes) attached 172 if b.Byte(3) == 1 { 173 if _, err := io.ReadFull(r.Reader, r.cache); err != nil { 174 b.Release() 175 return nil, err 176 } 177 length := int32(r.cache[0])<<8 | int32(r.cache[1]) 178 if length > 0 { 179 if _, err := b.ReadFullFrom(r.Reader, length); err != nil { 180 b.Release() 181 return nil, err 182 } 183 if !discard { 184 return buf.MultiBuffer{b}, nil 185 } 186 } 187 } 188 b.Release() 189 } 190 }