gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/boot/portforward/portforward_netstack_test.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package portforward 16 17 import ( 18 "bytes" 19 "io" 20 "sync" 21 "testing" 22 23 "gvisor.dev/gvisor/pkg/sentry/contexttest" 24 "gvisor.dev/gvisor/pkg/tcpip" 25 "gvisor.dev/gvisor/pkg/waiter" 26 ) 27 28 type baseTCPEndpointImpl struct { 29 closed bool 30 readBuf bytes.Buffer 31 writeBuf bytes.Buffer 32 mu sync.Mutex 33 } 34 35 // read reads data from the buffer that "Write" writes to. 36 func (b *baseTCPEndpointImpl) read(n int) ([]byte, error) { 37 b.mu.Lock() 38 defer b.mu.Unlock() 39 if b.closed { 40 return nil, io.EOF 41 } 42 ret := b.writeBuf.Next(n) 43 return ret, nil 44 } 45 46 // write writes data to the read buffer that "Read" reads from. 47 func (b *baseTCPEndpointImpl) write(buf []byte) (int, error) { 48 b.mu.Lock() 49 defer b.mu.Unlock() 50 if b.closed { 51 return 0, io.EOF 52 } 53 n, err := b.readBuf.Write(buf) 54 return n, err 55 } 56 57 func (b *baseTCPEndpointImpl) Close() { 58 b.mu.Lock() 59 defer b.mu.Unlock() 60 b.closed = true 61 } 62 63 func (b *baseTCPEndpointImpl) Read(w io.Writer, _ tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { 64 b.mu.Lock() 65 defer b.mu.Unlock() 66 if b.closed { 67 return tcpip.ReadResult{}, &tcpip.ErrClosedForReceive{} 68 } 69 buf := b.readBuf.Next(b.readBuf.Len()) 70 n, err := w.Write(buf) 71 if err != nil { 72 return tcpip.ReadResult{}, &tcpip.ErrInvalidEndpointState{} 73 } 74 return tcpip.ReadResult{ 75 Count: n, 76 Total: n, 77 }, nil 78 } 79 80 func (b *baseTCPEndpointImpl) Write(payload tcpip.Payloader, _ tcpip.WriteOptions) (int64, tcpip.Error) { 81 b.mu.Lock() 82 defer b.mu.Unlock() 83 if b.closed { 84 return 0, &tcpip.ErrClosedForSend{} 85 } 86 buf := make([]byte, payload.Len()) 87 n, err := payload.Read(buf) 88 if err != nil { 89 return 0, &tcpip.ErrInvalidEndpointState{} 90 } 91 n, err = b.writeBuf.Write(buf[:n]) 92 if err != nil { 93 return int64(n), &tcpip.ErrConnectionRefused{} 94 } 95 return int64(n), nil 96 } 97 98 func (b *baseTCPEndpointImpl) Shutdown(shutdown tcpip.ShutdownFlags) tcpip.Error { 99 b.mu.Lock() 100 defer b.mu.Unlock() 101 b.closed = true 102 return nil 103 } 104 105 func TestNetstackProxy(t *testing.T) { 106 for _, tc := range []struct { 107 name string 108 requests map[string]string 109 }{ 110 { 111 name: "single", 112 requests: map[string]string{ 113 "PING": "PONG", 114 }, 115 }, 116 { 117 name: "multiple", 118 requests: map[string]string{ 119 "PING": "PONG", 120 "HELLO": "GOODBYE", 121 "IMPRESSIVE": "MOST IMPRESSIVE", 122 }, 123 }, 124 { 125 name: "empty", 126 requests: map[string]string{ 127 "EMPTY": "", 128 "NOT": "EMPTY", 129 "OTHER EMPTY": "", 130 }, 131 }, 132 } { 133 t.Run(tc.name, func(t *testing.T) { 134 doNetstackTest(t, tc.name, tc.requests) 135 }) 136 } 137 } 138 139 func doNetstackTest(t *testing.T, name string, responses map[string]string) { 140 ctx := contexttest.Context(t) 141 appEndpoint := newMockApplicationFDImpl() 142 fd, err := newMockFileDescription(ctx, appEndpoint) 143 if err != nil { 144 t.Fatalf("newMockFileDescription: %v", err) 145 } 146 147 wq := &waiter.Queue{} 148 impl := &baseTCPEndpointImpl{} 149 ep := newMockTCPEndpoint(impl, wq) 150 sock := &netstackConn{ 151 ep: ep, 152 wq: wq, 153 } 154 155 proxy := NewProxy(ProxyPair{To: sock, From: &fileDescriptionConn{file: fd}}, name) 156 proxy.Start(ctx) 157 defer proxy.Close() 158 159 harness := portforwarderTestHarness{ 160 app: appEndpoint, 161 shim: impl, 162 } 163 164 for req, resp := range responses { 165 if _, err := harness.shimWrite([]byte(req)); err != nil { 166 t.Fatalf("failed to write to shim: %v", err) 167 } 168 169 got, err := harness.appRead(len(req)) 170 if err != nil { 171 t.Fatalf("failed to read from app: %v", err) 172 } 173 174 if string(got) != req { 175 t.Fatalf("app mismatch: got: %s want: %s", string(got), req) 176 } 177 178 if _, err := harness.appWrite([]byte(resp)); err != nil { 179 t.Fatalf("failed to write to app: %v", err) 180 } 181 182 got, err = harness.shimRead(len(resp)) 183 if err != nil { 184 t.Fatalf("failed to read from shim: %v", err) 185 } 186 187 if string(got) != resp { 188 t.Fatalf("shim mismatch: got: %s want: %s", string(got), resp) 189 } 190 } 191 } 192 193 // tcpErrImpl blocks on the first Read/Write and then throws an error afterwards. 194 type tcpErrImpl struct { 195 mu sync.Mutex 196 reads bool 197 writes bool 198 } 199 200 // Read implements mockTCPEndpointImpl.Read. 201 func (e *tcpErrImpl) Read(w io.Writer, _ tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { 202 e.mu.Lock() 203 defer e.mu.Unlock() 204 if e.reads { 205 return tcpip.ReadResult{}, &tcpip.ErrBadLocalAddress{} 206 } 207 e.reads = true 208 return tcpip.ReadResult{}, &tcpip.ErrWouldBlock{} 209 } 210 211 // Write implements mockTCPEndpointImpl.Write. 212 func (e *tcpErrImpl) Write(payload tcpip.Payloader, _ tcpip.WriteOptions) (int64, tcpip.Error) { 213 e.mu.Lock() 214 defer e.mu.Unlock() 215 if e.writes { 216 return 0, &tcpip.ErrBadLocalAddress{} 217 } 218 e.writes = true 219 return 0, &tcpip.ErrWouldBlock{} 220 } 221 222 // Shutdown implements mockTCPEndpointImpl.Shutdown. 223 func (e *tcpErrImpl) Shutdown(shutdown tcpip.ShutdownFlags) tcpip.Error { 224 return nil 225 } 226 227 // Close implements mockTCPEndpointImpl.Shutdown. 228 func (e *tcpErrImpl) Close() {} 229 230 // TestNTestNestackReadsWrites checks that reads/writes check errors from the underlying endpoint 231 // multiple times. 232 func TestNestackReadsWrites(t *testing.T) { 233 ctx := contexttest.Context(t) 234 wq := &waiter.Queue{} 235 ep := newMockTCPEndpoint(&tcpErrImpl{}, wq) 236 cancel := make(chan struct{}) 237 conn := netstackConn{ep: ep, wq: wq} 238 defer close(cancel) 239 defer conn.Close(ctx) 240 241 _, err := conn.Read(ctx, []byte("something"), cancel) 242 if err != io.EOF { 243 t.Fatalf("mismatch read err: want: %v got: %v", io.EOF, err) 244 } 245 246 _, err = conn.Write(ctx, []byte("something"), cancel) 247 if err != io.EOF { 248 t.Fatalf("mismatch write err: want: %v got: %v", io.EOF, err) 249 } 250 }