gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/boot/portforward/portforward_netstack_test.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package portforward
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  	"sync"
    21  	"testing"
    22  
    23  	"gvisor.dev/gvisor/pkg/sentry/contexttest"
    24  	"gvisor.dev/gvisor/pkg/tcpip"
    25  	"gvisor.dev/gvisor/pkg/waiter"
    26  )
    27  
    28  type baseTCPEndpointImpl struct {
    29  	closed   bool
    30  	readBuf  bytes.Buffer
    31  	writeBuf bytes.Buffer
    32  	mu       sync.Mutex
    33  }
    34  
    35  // read reads data from the buffer that "Write" writes to.
    36  func (b *baseTCPEndpointImpl) read(n int) ([]byte, error) {
    37  	b.mu.Lock()
    38  	defer b.mu.Unlock()
    39  	if b.closed {
    40  		return nil, io.EOF
    41  	}
    42  	ret := b.writeBuf.Next(n)
    43  	return ret, nil
    44  }
    45  
    46  // write writes data to the read buffer that "Read" reads from.
    47  func (b *baseTCPEndpointImpl) write(buf []byte) (int, error) {
    48  	b.mu.Lock()
    49  	defer b.mu.Unlock()
    50  	if b.closed {
    51  		return 0, io.EOF
    52  	}
    53  	n, err := b.readBuf.Write(buf)
    54  	return n, err
    55  }
    56  
    57  func (b *baseTCPEndpointImpl) Close() {
    58  	b.mu.Lock()
    59  	defer b.mu.Unlock()
    60  	b.closed = true
    61  }
    62  
    63  func (b *baseTCPEndpointImpl) Read(w io.Writer, _ tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
    64  	b.mu.Lock()
    65  	defer b.mu.Unlock()
    66  	if b.closed {
    67  		return tcpip.ReadResult{}, &tcpip.ErrClosedForReceive{}
    68  	}
    69  	buf := b.readBuf.Next(b.readBuf.Len())
    70  	n, err := w.Write(buf)
    71  	if err != nil {
    72  		return tcpip.ReadResult{}, &tcpip.ErrInvalidEndpointState{}
    73  	}
    74  	return tcpip.ReadResult{
    75  		Count: n,
    76  		Total: n,
    77  	}, nil
    78  }
    79  
    80  func (b *baseTCPEndpointImpl) Write(payload tcpip.Payloader, _ tcpip.WriteOptions) (int64, tcpip.Error) {
    81  	b.mu.Lock()
    82  	defer b.mu.Unlock()
    83  	if b.closed {
    84  		return 0, &tcpip.ErrClosedForSend{}
    85  	}
    86  	buf := make([]byte, payload.Len())
    87  	n, err := payload.Read(buf)
    88  	if err != nil {
    89  		return 0, &tcpip.ErrInvalidEndpointState{}
    90  	}
    91  	n, err = b.writeBuf.Write(buf[:n])
    92  	if err != nil {
    93  		return int64(n), &tcpip.ErrConnectionRefused{}
    94  	}
    95  	return int64(n), nil
    96  }
    97  
    98  func (b *baseTCPEndpointImpl) Shutdown(shutdown tcpip.ShutdownFlags) tcpip.Error {
    99  	b.mu.Lock()
   100  	defer b.mu.Unlock()
   101  	b.closed = true
   102  	return nil
   103  }
   104  
   105  func TestNetstackProxy(t *testing.T) {
   106  	for _, tc := range []struct {
   107  		name     string
   108  		requests map[string]string
   109  	}{
   110  		{
   111  			name: "single",
   112  			requests: map[string]string{
   113  				"PING": "PONG",
   114  			},
   115  		},
   116  		{
   117  			name: "multiple",
   118  			requests: map[string]string{
   119  				"PING":       "PONG",
   120  				"HELLO":      "GOODBYE",
   121  				"IMPRESSIVE": "MOST IMPRESSIVE",
   122  			},
   123  		},
   124  		{
   125  			name: "empty",
   126  			requests: map[string]string{
   127  				"EMPTY":       "",
   128  				"NOT":         "EMPTY",
   129  				"OTHER EMPTY": "",
   130  			},
   131  		},
   132  	} {
   133  		t.Run(tc.name, func(t *testing.T) {
   134  			doNetstackTest(t, tc.name, tc.requests)
   135  		})
   136  	}
   137  }
   138  
   139  func doNetstackTest(t *testing.T, name string, responses map[string]string) {
   140  	ctx := contexttest.Context(t)
   141  	appEndpoint := newMockApplicationFDImpl()
   142  	fd, err := newMockFileDescription(ctx, appEndpoint)
   143  	if err != nil {
   144  		t.Fatalf("newMockFileDescription: %v", err)
   145  	}
   146  
   147  	wq := &waiter.Queue{}
   148  	impl := &baseTCPEndpointImpl{}
   149  	ep := newMockTCPEndpoint(impl, wq)
   150  	sock := &netstackConn{
   151  		ep: ep,
   152  		wq: wq,
   153  	}
   154  
   155  	proxy := NewProxy(ProxyPair{To: sock, From: &fileDescriptionConn{file: fd}}, name)
   156  	proxy.Start(ctx)
   157  	defer proxy.Close()
   158  
   159  	harness := portforwarderTestHarness{
   160  		app:  appEndpoint,
   161  		shim: impl,
   162  	}
   163  
   164  	for req, resp := range responses {
   165  		if _, err := harness.shimWrite([]byte(req)); err != nil {
   166  			t.Fatalf("failed to write to shim: %v", err)
   167  		}
   168  
   169  		got, err := harness.appRead(len(req))
   170  		if err != nil {
   171  			t.Fatalf("failed to read from app: %v", err)
   172  		}
   173  
   174  		if string(got) != req {
   175  			t.Fatalf("app mismatch: got: %s want: %s", string(got), req)
   176  		}
   177  
   178  		if _, err := harness.appWrite([]byte(resp)); err != nil {
   179  			t.Fatalf("failed to write to app: %v", err)
   180  		}
   181  
   182  		got, err = harness.shimRead(len(resp))
   183  		if err != nil {
   184  			t.Fatalf("failed to read from shim: %v", err)
   185  		}
   186  
   187  		if string(got) != resp {
   188  			t.Fatalf("shim mismatch: got: %s want: %s", string(got), resp)
   189  		}
   190  	}
   191  }
   192  
   193  // tcpErrImpl blocks on the first Read/Write and then throws an error afterwards.
   194  type tcpErrImpl struct {
   195  	mu     sync.Mutex
   196  	reads  bool
   197  	writes bool
   198  }
   199  
   200  // Read implements mockTCPEndpointImpl.Read.
   201  func (e *tcpErrImpl) Read(w io.Writer, _ tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   202  	e.mu.Lock()
   203  	defer e.mu.Unlock()
   204  	if e.reads {
   205  		return tcpip.ReadResult{}, &tcpip.ErrBadLocalAddress{}
   206  	}
   207  	e.reads = true
   208  	return tcpip.ReadResult{}, &tcpip.ErrWouldBlock{}
   209  }
   210  
   211  // Write implements mockTCPEndpointImpl.Write.
   212  func (e *tcpErrImpl) Write(payload tcpip.Payloader, _ tcpip.WriteOptions) (int64, tcpip.Error) {
   213  	e.mu.Lock()
   214  	defer e.mu.Unlock()
   215  	if e.writes {
   216  		return 0, &tcpip.ErrBadLocalAddress{}
   217  	}
   218  	e.writes = true
   219  	return 0, &tcpip.ErrWouldBlock{}
   220  }
   221  
   222  // Shutdown implements mockTCPEndpointImpl.Shutdown.
   223  func (e *tcpErrImpl) Shutdown(shutdown tcpip.ShutdownFlags) tcpip.Error {
   224  	return nil
   225  }
   226  
   227  // Close implements mockTCPEndpointImpl.Shutdown.
   228  func (e *tcpErrImpl) Close() {}
   229  
   230  // TestNTestNestackReadsWrites checks that reads/writes check errors from the underlying endpoint
   231  // multiple times.
   232  func TestNestackReadsWrites(t *testing.T) {
   233  	ctx := contexttest.Context(t)
   234  	wq := &waiter.Queue{}
   235  	ep := newMockTCPEndpoint(&tcpErrImpl{}, wq)
   236  	cancel := make(chan struct{})
   237  	conn := netstackConn{ep: ep, wq: wq}
   238  	defer close(cancel)
   239  	defer conn.Close(ctx)
   240  
   241  	_, err := conn.Read(ctx, []byte("something"), cancel)
   242  	if err != io.EOF {
   243  		t.Fatalf("mismatch read err: want: %v got: %v", io.EOF, err)
   244  	}
   245  
   246  	_, err = conn.Write(ctx, []byte("something"), cancel)
   247  	if err != io.EOF {
   248  		t.Fatalf("mismatch write err: want: %v got: %v", io.EOF, err)
   249  	}
   250  }