github.com/xmplusdev/xray-core@v1.8.10/proxy/vmess/encoding/encoding_test.go (about) 1 package encoding_test 2 3 import ( 4 "context" 5 "testing" 6 7 "github.com/google/go-cmp/cmp" 8 "github.com/xmplusdev/xray-core/common" 9 "github.com/xmplusdev/xray-core/common/buf" 10 "github.com/xmplusdev/xray-core/common/net" 11 "github.com/xmplusdev/xray-core/common/protocol" 12 "github.com/xmplusdev/xray-core/common/uuid" 13 "github.com/xmplusdev/xray-core/proxy/vmess" 14 . "github.com/xmplusdev/xray-core/proxy/vmess/encoding" 15 ) 16 17 func toAccount(a *vmess.Account) protocol.Account { 18 account, err := a.AsAccount() 19 common.Must(err) 20 return account 21 } 22 23 func TestRequestSerialization(t *testing.T) { 24 user := &protocol.MemoryUser{ 25 Level: 0, 26 Email: "test@example.com", 27 } 28 id := uuid.New() 29 account := &vmess.Account{ 30 Id: id.String(), 31 } 32 user.Account = toAccount(account) 33 34 expectedRequest := &protocol.RequestHeader{ 35 Version: 1, 36 User: user, 37 Command: protocol.RequestCommandTCP, 38 Address: net.DomainAddress("www.example.com"), 39 Port: net.Port(443), 40 Security: protocol.SecurityType_AES128_GCM, 41 } 42 43 buffer := buf.New() 44 client := NewClientSession(context.TODO(), 0) 45 common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) 46 47 buffer2 := buf.New() 48 buffer2.Write(buffer.Bytes()) 49 50 sessionHistory := NewSessionHistory() 51 defer common.Close(sessionHistory) 52 53 userValidator := vmess.NewTimedUserValidator() 54 userValidator.Add(user) 55 defer common.Close(userValidator) 56 57 server := NewServerSession(userValidator, sessionHistory) 58 actualRequest, err := server.DecodeRequestHeader(buffer, false) 59 common.Must(err) 60 61 if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { 62 t.Error(r) 63 } 64 65 _, err = server.DecodeRequestHeader(buffer2, false) 66 // anti replay attack 67 if err == nil { 68 t.Error("nil error") 69 } 70 } 71 72 func TestInvalidRequest(t *testing.T) { 73 user := &protocol.MemoryUser{ 74 Level: 0, 75 Email: "test@example.com", 76 } 77 id := uuid.New() 78 account := &vmess.Account{ 79 Id: id.String(), 80 } 81 user.Account = toAccount(account) 82 83 expectedRequest := &protocol.RequestHeader{ 84 Version: 1, 85 User: user, 86 Command: protocol.RequestCommand(100), 87 Address: net.DomainAddress("www.example.com"), 88 Port: net.Port(443), 89 Security: protocol.SecurityType_AES128_GCM, 90 } 91 92 buffer := buf.New() 93 client := NewClientSession(context.TODO(), 0) 94 common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) 95 96 buffer2 := buf.New() 97 buffer2.Write(buffer.Bytes()) 98 99 sessionHistory := NewSessionHistory() 100 defer common.Close(sessionHistory) 101 102 userValidator := vmess.NewTimedUserValidator() 103 userValidator.Add(user) 104 defer common.Close(userValidator) 105 106 server := NewServerSession(userValidator, sessionHistory) 107 _, err := server.DecodeRequestHeader(buffer, false) 108 if err == nil { 109 t.Error("nil error") 110 } 111 } 112 113 func TestMuxRequest(t *testing.T) { 114 user := &protocol.MemoryUser{ 115 Level: 0, 116 Email: "test@example.com", 117 } 118 id := uuid.New() 119 account := &vmess.Account{ 120 Id: id.String(), 121 } 122 user.Account = toAccount(account) 123 124 expectedRequest := &protocol.RequestHeader{ 125 Version: 1, 126 User: user, 127 Command: protocol.RequestCommandMux, 128 Security: protocol.SecurityType_AES128_GCM, 129 Address: net.DomainAddress("v1.mux.cool"), 130 } 131 132 buffer := buf.New() 133 client := NewClientSession(context.TODO(), 0) 134 common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) 135 136 buffer2 := buf.New() 137 buffer2.Write(buffer.Bytes()) 138 139 sessionHistory := NewSessionHistory() 140 defer common.Close(sessionHistory) 141 142 userValidator := vmess.NewTimedUserValidator() 143 userValidator.Add(user) 144 defer common.Close(userValidator) 145 146 server := NewServerSession(userValidator, sessionHistory) 147 actualRequest, err := server.DecodeRequestHeader(buffer, false) 148 common.Must(err) 149 150 if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" { 151 t.Error(r) 152 } 153 }