gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/boot/portforward/portforward_hostinet_test.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package portforward
    16  
    17  import (
    18  	"fmt"
    19  	"net"
    20  	"slices"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"golang.org/x/sync/errgroup"
    27  	"gvisor.dev/gvisor/pkg/context"
    28  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    29  	"gvisor.dev/gvisor/pkg/sentry/contexttest"
    30  )
    31  
    32  func TestLocalHostSocket(t *testing.T) {
    33  	ctx := contexttest.Context(t)
    34  	clientData := append(
    35  		[]byte("do what must be done\n"),
    36  		[]byte("do not hesitate\n")...,
    37  	)
    38  
    39  	serverData := append(
    40  		[]byte("commander cody...the time has come\n"),
    41  		[]byte("execute order 66\n")...,
    42  	)
    43  
    44  	l, err := net.Listen("tcp", ":0")
    45  	if err != nil {
    46  		t.Fatalf("net.Listen failed: %v", err)
    47  	}
    48  	defer l.Close()
    49  
    50  	port := l.Addr().(*net.TCPAddr).Port
    51  	var g errgroup.Group
    52  
    53  	g.Go(func() error {
    54  		conn, err := l.Accept()
    55  		if err != nil {
    56  			t.Fatalf("could not accept connection: %v", err)
    57  		}
    58  		defer conn.Close()
    59  
    60  		data := make([]byte, 1024)
    61  		recLen, err := conn.Read(data)
    62  		if err != nil {
    63  			return fmt.Errorf("could not read data: %v", err)
    64  		}
    65  
    66  		if !slices.Equal(data[:recLen], clientData) {
    67  			return fmt.Errorf("server mismatch data recieved: got: %s want: %s", data[:recLen], clientData)
    68  		}
    69  
    70  		sentLen, err := conn.Write(serverData)
    71  		if err != nil {
    72  			return fmt.Errorf("could not write data: %v", err)
    73  		}
    74  
    75  		if sentLen != len(serverData) {
    76  			return fmt.Errorf("server mismatch data sent: got: %d want: %d", sentLen, len(serverData))
    77  		}
    78  
    79  		return nil
    80  	})
    81  
    82  	g.Go(func() error {
    83  		sock, err := NewHostInetConn(uint16(port))
    84  		if err != nil {
    85  			t.Fatalf("could not create local host socket: %v", err)
    86  		}
    87  		for i := 0; i < len(clientData); {
    88  			n, err := sock.Write(ctx, clientData[i:], nil)
    89  			if err != nil {
    90  				return fmt.Errorf("could not write to local host socket: %v", err)
    91  			}
    92  			i += n
    93  		}
    94  
    95  		data := make([]byte, 1024)
    96  		dataLen := 0
    97  		for dataLen < len(serverData) {
    98  			n, err := sock.Read(ctx, data[dataLen:], nil)
    99  			if err != nil {
   100  				t.Fatalf("could not read from local host socket: %v", err)
   101  			}
   102  			dataLen += n
   103  		}
   104  
   105  		if !slices.Equal(data[:dataLen], serverData) {
   106  			return fmt.Errorf("server mismatch data received: got: %s want: %s", data[:dataLen], clientData)
   107  		}
   108  		return nil
   109  	})
   110  
   111  	if err := g.Wait(); err != nil {
   112  		t.Fatal(err)
   113  	}
   114  }
   115  
   116  type netConnMockEndpoint struct {
   117  	conn net.Conn
   118  	mu   sync.Mutex
   119  }
   120  
   121  // read implements portforwarderTestHarness.read.
   122  func (nc *netConnMockEndpoint) read(n int) ([]byte, error) {
   123  	nc.mu.Lock()
   124  	defer nc.mu.Unlock()
   125  
   126  	buf := make([]byte, n)
   127  	nc.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 500))
   128  	res, err := nc.conn.Read(buf)
   129  	if err != nil && strings.Contains(err.Error(), "timeout") {
   130  		return nil, linuxerr.ErrWouldBlock
   131  	}
   132  	return buf[:res], err
   133  }
   134  
   135  // write implements portforwarderTestHarness write.
   136  func (nc *netConnMockEndpoint) write(buf []byte) (int, error) {
   137  	nc.mu.Lock()
   138  	defer nc.mu.Unlock()
   139  	written := 0
   140  	for {
   141  		n, err := nc.conn.Write(buf[written:])
   142  		if err != nil && !linuxerr.Equals(linuxerr.ErrWouldBlock, err) {
   143  			return n, err
   144  		}
   145  		written += n
   146  		if written >= len(buf) {
   147  			return written, nil
   148  		}
   149  	}
   150  }
   151  
   152  func TestHostInetProxy(t *testing.T) {
   153  	for _, tc := range []struct {
   154  		name     string
   155  		requests map[string]string
   156  	}{
   157  		{
   158  			name: "single",
   159  			requests: map[string]string{
   160  				"PING": "PONG",
   161  			},
   162  		},
   163  		{
   164  			name: "multiple",
   165  			requests: map[string]string{
   166  				"PING":       "PONG",
   167  				"HELLO":      "GOODBYE",
   168  				"IMPRESSIVE": "MOST IMPRESSIVE",
   169  			},
   170  		},
   171  		{
   172  			name: "empty",
   173  			requests: map[string]string{
   174  				"EMPTY":       "",
   175  				"NOT":         "EMPTY",
   176  				"OTHER EMPTY": "",
   177  			},
   178  		},
   179  	} {
   180  		t.Run(tc.name, func(t *testing.T) {
   181  			doHostinetTest(t, tc.name, tc.requests)
   182  		})
   183  	}
   184  }
   185  
   186  func doHostinetTest(t *testing.T, name string, requests map[string]string) {
   187  	ctx := context.Background()
   188  	appEndpoint := newMockApplicationFDImpl()
   189  	client, err := newMockFileDescription(ctx, appEndpoint)
   190  	if err != nil {
   191  		t.Fatalf("newMockFileDescription: %v", err)
   192  	}
   193  
   194  	l, err := net.Listen("tcp", ":0")
   195  	if err != nil {
   196  		t.Fatalf("net.Listen failed: %v", err)
   197  	}
   198  	defer l.Close()
   199  	port := uint16(l.Addr().(*net.TCPAddr).Port)
   200  	sock, err := NewHostInetConn(port)
   201  	if err != nil {
   202  		t.Fatalf("could not create local host socket: %v", err)
   203  	}
   204  
   205  	proxy := NewProxy(ProxyPair{To: sock, From: &fileDescriptionConn{file: client}}, name)
   206  
   207  	proxy.Start(ctx)
   208  
   209  	shim, err := l.Accept()
   210  	if err != nil {
   211  		t.Fatalf("could not accept shim connection: %v", err)
   212  	}
   213  	defer shim.Close()
   214  	harness := portforwarderTestHarness{
   215  		app:  appEndpoint,
   216  		shim: &netConnMockEndpoint{conn: shim},
   217  	}
   218  
   219  	for req, resp := range requests {
   220  		if _, err := harness.shimWrite([]byte(req)); err != nil {
   221  			t.Fatalf("failed to write to shim: %v", err)
   222  		}
   223  
   224  		got, err := harness.appRead(len(req))
   225  		if err != nil {
   226  			t.Fatalf("failed to read from app: %v", err)
   227  		}
   228  
   229  		if string(got) != req {
   230  			t.Fatalf("app mismatch: got: %s want: %s", string(got), req)
   231  		}
   232  
   233  		if _, err := harness.appWrite([]byte(resp)); err != nil {
   234  			t.Fatalf("failed to write to app: %v", err)
   235  		}
   236  
   237  		got, err = harness.shimRead(len(resp))
   238  		if err != nil {
   239  			t.Fatalf("failed to read from shim: %v", err)
   240  		}
   241  		if string(got) != resp {
   242  			t.Fatalf("shim mismatch: got: %s want: %s", string(got), resp)
   243  		}
   244  	}
   245  }