github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/p9/transport_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package p9 16 17 import ( 18 "io/ioutil" 19 "os" 20 "testing" 21 22 "github.com/SagerNet/gvisor/pkg/fd" 23 "github.com/SagerNet/gvisor/pkg/unet" 24 ) 25 26 const ( 27 MsgTypeBadEncode = iota + 252 28 MsgTypeBadDecode 29 MsgTypeUnregistered 30 ) 31 32 func TestSendRecv(t *testing.T) { 33 server, client, err := unet.SocketPair(false) 34 if err != nil { 35 t.Fatalf("socketpair got err %v expected nil", err) 36 } 37 defer server.Close() 38 defer client.Close() 39 40 if err := send(client, Tag(1), &Tlopen{}); err != nil { 41 t.Fatalf("send got err %v expected nil", err) 42 } 43 44 tag, m, err := recv(server, maximumLength, msgRegistry.get) 45 if err != nil { 46 t.Fatalf("recv got err %v expected nil", err) 47 } 48 if tag != Tag(1) { 49 t.Fatalf("got tag %v expected 1", tag) 50 } 51 if _, ok := m.(*Tlopen); !ok { 52 t.Fatalf("got message %v expected *Tlopen", m) 53 } 54 } 55 56 // badDecode overruns on decode. 57 type badDecode struct{} 58 59 func (*badDecode) decode(b *buffer) { b.markOverrun() } 60 func (*badDecode) encode(b *buffer) {} 61 func (*badDecode) Type() MsgType { return MsgTypeBadDecode } 62 func (*badDecode) String() string { return "badDecode{}" } 63 64 func TestRecvOverrun(t *testing.T) { 65 server, client, err := unet.SocketPair(false) 66 if err != nil { 67 t.Fatalf("socketpair got err %v expected nil", err) 68 } 69 defer server.Close() 70 defer client.Close() 71 72 if err := send(client, Tag(1), &badDecode{}); err != nil { 73 t.Fatalf("send got err %v expected nil", err) 74 } 75 76 if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { 77 t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) 78 } 79 } 80 81 // unregistered is not registered on decode. 82 type unregistered struct{} 83 84 func (*unregistered) decode(b *buffer) {} 85 func (*unregistered) encode(b *buffer) {} 86 func (*unregistered) Type() MsgType { return MsgTypeUnregistered } 87 func (*unregistered) String() string { return "unregistered{}" } 88 89 func TestRecvInvalidType(t *testing.T) { 90 server, client, err := unet.SocketPair(false) 91 if err != nil { 92 t.Fatalf("socketpair got err %v expected nil", err) 93 } 94 defer server.Close() 95 defer client.Close() 96 97 if err := send(client, Tag(1), &unregistered{}); err != nil { 98 t.Fatalf("send got err %v expected nil", err) 99 } 100 101 _, _, err = recv(server, maximumLength, msgRegistry.get) 102 if _, ok := err.(*ErrInvalidMsgType); !ok { 103 t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) 104 } 105 } 106 107 func TestSendRecvWithFile(t *testing.T) { 108 server, client, err := unet.SocketPair(false) 109 if err != nil { 110 t.Fatalf("socketpair got err %v expected nil", err) 111 } 112 defer server.Close() 113 defer client.Close() 114 115 // Create a tempfile. 116 osf, err := ioutil.TempFile("", "p9") 117 if err != nil { 118 t.Fatalf("tempfile got err %v expected nil", err) 119 } 120 os.Remove(osf.Name()) 121 f, err := fd.NewFromFile(osf) 122 osf.Close() 123 if err != nil { 124 t.Fatalf("unable to create file: %v", err) 125 } 126 127 rlopen := &Rlopen{} 128 rlopen.SetFilePayload(f) 129 if err := send(client, Tag(1), rlopen); err != nil { 130 t.Fatalf("send got err %v expected nil", err) 131 } 132 133 // Enable withFile. 134 tag, m, err := recv(server, maximumLength, msgRegistry.get) 135 if err != nil { 136 t.Fatalf("recv got err %v expected nil", err) 137 } 138 if tag != Tag(1) { 139 t.Fatalf("got tag %v expected 1", tag) 140 } 141 rlopen, ok := m.(*Rlopen) 142 if !ok { 143 t.Fatalf("got m %v expected *Rlopen", m) 144 } 145 if rlopen.File == nil { 146 t.Fatalf("got nil file expected non-nil") 147 } 148 } 149 150 func TestRecvClosed(t *testing.T) { 151 server, client, err := unet.SocketPair(false) 152 if err != nil { 153 t.Fatalf("socketpair got err %v expected nil", err) 154 } 155 defer server.Close() 156 client.Close() 157 158 _, _, err = recv(server, maximumLength, msgRegistry.get) 159 if err == nil { 160 t.Fatalf("got err nil expected non-nil") 161 } 162 if _, ok := err.(ErrSocket); !ok { 163 t.Fatalf("got err %v expected ErrSocket", err) 164 } 165 } 166 167 func TestSendClosed(t *testing.T) { 168 server, client, err := unet.SocketPair(false) 169 if err != nil { 170 t.Fatalf("socketpair got err %v expected nil", err) 171 } 172 server.Close() 173 defer client.Close() 174 175 err = send(client, Tag(1), &Tlopen{}) 176 if err == nil { 177 t.Fatalf("send got err nil expected non-nil") 178 } 179 if _, ok := err.(ErrSocket); !ok { 180 t.Fatalf("got err %v expected ErrSocket", err) 181 } 182 } 183 184 func BenchmarkSendRecv(b *testing.B) { 185 b.ReportAllocs() 186 187 server, client, err := unet.SocketPair(false) 188 if err != nil { 189 b.Fatalf("socketpair got err %v expected nil", err) 190 } 191 defer server.Close() 192 defer client.Close() 193 194 // Exchange Rflush messages since these contain no data and therefore incur 195 // no additional marshaling overhead. 196 go func() { 197 for i := 0; i < b.N; i++ { 198 tag, m, err := recv(server, maximumLength, msgRegistry.get) 199 if err != nil { 200 b.Errorf("recv got err %v expected nil", err) 201 } 202 if tag != Tag(1) { 203 b.Errorf("got tag %v expected 1", tag) 204 } 205 if _, ok := m.(*Rflush); !ok { 206 b.Errorf("got message %T expected *Rflush", m) 207 } 208 if err := send(server, Tag(2), &Rflush{}); err != nil { 209 b.Errorf("send got err %v expected nil", err) 210 } 211 } 212 }() 213 b.ResetTimer() 214 for i := 0; i < b.N; i++ { 215 if err := send(client, Tag(1), &Rflush{}); err != nil { 216 b.Errorf("send got err %v expected nil", err) 217 } 218 tag, m, err := recv(client, maximumLength, msgRegistry.get) 219 if err != nil { 220 b.Errorf("recv got err %v expected nil", err) 221 } 222 if tag != Tag(2) { 223 b.Errorf("got tag %v expected 2", tag) 224 } 225 if _, ok := m.(*Rflush); !ok { 226 b.Errorf("got message %v expected *Rflush", m) 227 } 228 } 229 } 230 231 func init() { 232 msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) 233 }