gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/boot/portforward/portforward_hostinet_test.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package portforward 16 17 import ( 18 "fmt" 19 "net" 20 "slices" 21 "strings" 22 "sync" 23 "testing" 24 "time" 25 26 "golang.org/x/sync/errgroup" 27 "gvisor.dev/gvisor/pkg/context" 28 "gvisor.dev/gvisor/pkg/errors/linuxerr" 29 "gvisor.dev/gvisor/pkg/sentry/contexttest" 30 ) 31 32 func TestLocalHostSocket(t *testing.T) { 33 ctx := contexttest.Context(t) 34 clientData := append( 35 []byte("do what must be done\n"), 36 []byte("do not hesitate\n")..., 37 ) 38 39 serverData := append( 40 []byte("commander cody...the time has come\n"), 41 []byte("execute order 66\n")..., 42 ) 43 44 l, err := net.Listen("tcp", ":0") 45 if err != nil { 46 t.Fatalf("net.Listen failed: %v", err) 47 } 48 defer l.Close() 49 50 port := l.Addr().(*net.TCPAddr).Port 51 var g errgroup.Group 52 53 g.Go(func() error { 54 conn, err := l.Accept() 55 if err != nil { 56 t.Fatalf("could not accept connection: %v", err) 57 } 58 defer conn.Close() 59 60 data := make([]byte, 1024) 61 recLen, err := conn.Read(data) 62 if err != nil { 63 return fmt.Errorf("could not read data: %v", err) 64 } 65 66 if !slices.Equal(data[:recLen], clientData) { 67 return fmt.Errorf("server mismatch data recieved: got: %s want: %s", data[:recLen], clientData) 68 } 69 70 sentLen, err := conn.Write(serverData) 71 if err != nil { 72 return fmt.Errorf("could not write data: %v", err) 73 } 74 75 if sentLen != len(serverData) { 76 return fmt.Errorf("server mismatch data sent: got: %d want: %d", sentLen, len(serverData)) 77 } 78 79 return nil 80 }) 81 82 g.Go(func() error { 83 sock, err := NewHostInetConn(uint16(port)) 84 if err != nil { 85 t.Fatalf("could not create local host socket: %v", err) 86 } 87 for i := 0; i < len(clientData); { 88 n, err := sock.Write(ctx, clientData[i:], nil) 89 if err != nil { 90 return fmt.Errorf("could not write to local host socket: %v", err) 91 } 92 i += n 93 } 94 95 data := make([]byte, 1024) 96 dataLen := 0 97 for dataLen < len(serverData) { 98 n, err := sock.Read(ctx, data[dataLen:], nil) 99 if err != nil { 100 t.Fatalf("could not read from local host socket: %v", err) 101 } 102 dataLen += n 103 } 104 105 if !slices.Equal(data[:dataLen], serverData) { 106 return fmt.Errorf("server mismatch data received: got: %s want: %s", data[:dataLen], clientData) 107 } 108 return nil 109 }) 110 111 if err := g.Wait(); err != nil { 112 t.Fatal(err) 113 } 114 } 115 116 type netConnMockEndpoint struct { 117 conn net.Conn 118 mu sync.Mutex 119 } 120 121 // read implements portforwarderTestHarness.read. 122 func (nc *netConnMockEndpoint) read(n int) ([]byte, error) { 123 nc.mu.Lock() 124 defer nc.mu.Unlock() 125 126 buf := make([]byte, n) 127 nc.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 500)) 128 res, err := nc.conn.Read(buf) 129 if err != nil && strings.Contains(err.Error(), "timeout") { 130 return nil, linuxerr.ErrWouldBlock 131 } 132 return buf[:res], err 133 } 134 135 // write implements portforwarderTestHarness write. 136 func (nc *netConnMockEndpoint) write(buf []byte) (int, error) { 137 nc.mu.Lock() 138 defer nc.mu.Unlock() 139 written := 0 140 for { 141 n, err := nc.conn.Write(buf[written:]) 142 if err != nil && !linuxerr.Equals(linuxerr.ErrWouldBlock, err) { 143 return n, err 144 } 145 written += n 146 if written >= len(buf) { 147 return written, nil 148 } 149 } 150 } 151 152 func TestHostInetProxy(t *testing.T) { 153 for _, tc := range []struct { 154 name string 155 requests map[string]string 156 }{ 157 { 158 name: "single", 159 requests: map[string]string{ 160 "PING": "PONG", 161 }, 162 }, 163 { 164 name: "multiple", 165 requests: map[string]string{ 166 "PING": "PONG", 167 "HELLO": "GOODBYE", 168 "IMPRESSIVE": "MOST IMPRESSIVE", 169 }, 170 }, 171 { 172 name: "empty", 173 requests: map[string]string{ 174 "EMPTY": "", 175 "NOT": "EMPTY", 176 "OTHER EMPTY": "", 177 }, 178 }, 179 } { 180 t.Run(tc.name, func(t *testing.T) { 181 doHostinetTest(t, tc.name, tc.requests) 182 }) 183 } 184 } 185 186 func doHostinetTest(t *testing.T, name string, requests map[string]string) { 187 ctx := context.Background() 188 appEndpoint := newMockApplicationFDImpl() 189 client, err := newMockFileDescription(ctx, appEndpoint) 190 if err != nil { 191 t.Fatalf("newMockFileDescription: %v", err) 192 } 193 194 l, err := net.Listen("tcp", ":0") 195 if err != nil { 196 t.Fatalf("net.Listen failed: %v", err) 197 } 198 defer l.Close() 199 port := uint16(l.Addr().(*net.TCPAddr).Port) 200 sock, err := NewHostInetConn(port) 201 if err != nil { 202 t.Fatalf("could not create local host socket: %v", err) 203 } 204 205 proxy := NewProxy(ProxyPair{To: sock, From: &fileDescriptionConn{file: client}}, name) 206 207 proxy.Start(ctx) 208 209 shim, err := l.Accept() 210 if err != nil { 211 t.Fatalf("could not accept shim connection: %v", err) 212 } 213 defer shim.Close() 214 harness := portforwarderTestHarness{ 215 app: appEndpoint, 216 shim: &netConnMockEndpoint{conn: shim}, 217 } 218 219 for req, resp := range requests { 220 if _, err := harness.shimWrite([]byte(req)); err != nil { 221 t.Fatalf("failed to write to shim: %v", err) 222 } 223 224 got, err := harness.appRead(len(req)) 225 if err != nil { 226 t.Fatalf("failed to read from app: %v", err) 227 } 228 229 if string(got) != req { 230 t.Fatalf("app mismatch: got: %s want: %s", string(got), req) 231 } 232 233 if _, err := harness.appWrite([]byte(resp)); err != nil { 234 t.Fatalf("failed to write to app: %v", err) 235 } 236 237 got, err = harness.shimRead(len(resp)) 238 if err != nil { 239 t.Fatalf("failed to read from shim: %v", err) 240 } 241 if string(got) != resp { 242 t.Fatalf("shim mismatch: got: %s want: %s", string(got), resp) 243 } 244 } 245 }