gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/boot/portforward/portforward_test_util.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package portforward
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  	"sync"
    21  	"time"
    22  
    23  	"gvisor.dev/gvisor/pkg/context"
    24  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    25  	"gvisor.dev/gvisor/pkg/sentry/vfs"
    26  	"gvisor.dev/gvisor/pkg/tcpip"
    27  	"gvisor.dev/gvisor/pkg/usermem"
    28  	"gvisor.dev/gvisor/pkg/waiter"
    29  )
    30  
    31  // mockEndpoint defines an endpoint that tests can read and write for validating portforwarders.
    32  type mockEndpoint interface {
    33  	read(n int) ([]byte, error)
    34  	write(buf []byte) (int, error)
    35  }
    36  
// portforwarderTestHarness mocks both sides of the portforwarder connection so that behavior can be
// validated between them.
//
// Tests drive the app side with appWrite/appRead and the shim side with
// shimWrite/shimRead; both reads go through doRead.
type portforwarderTestHarness struct {
	app  mockEndpoint // app is the application-facing side of the connection.
	shim mockEndpoint // shim is the opposite side of the portforwarder from app.
}
    43  
    44  func (th *portforwarderTestHarness) appWrite(buf []byte) (int, error) {
    45  	return th.app.write(buf)
    46  }
    47  
    48  func (th *portforwarderTestHarness) appRead(n int) ([]byte, error) {
    49  	return th.doRead(n, th.app)
    50  }
    51  
    52  func (th *portforwarderTestHarness) shimWrite(buf []byte) (int, error) {
    53  	return th.shim.write(buf)
    54  }
    55  
    56  func (th *portforwarderTestHarness) shimRead(n int) ([]byte, error) {
    57  	return th.doRead(n, th.shim)
    58  }
    59  
    60  func (th *portforwarderTestHarness) doRead(n int, ep mockEndpoint) ([]byte, error) {
    61  	buf := make([]byte, 0, n)
    62  	for {
    63  		out, err := ep.read(n - len(buf))
    64  		if err != nil && !linuxerr.Equals(linuxerr.ErrWouldBlock, err) {
    65  			return nil, err
    66  		}
    67  		buf = append(buf, out...)
    68  		if len(buf) >= n {
    69  			return buf, nil
    70  		}
    71  	}
    72  }
    73  
// mockApplicationFDImpl mocks a VFS file description endpoint on which the sandboxed application
// and the portforwarder will communicate.
//
// Two byte buffers carry the traffic. The test's write() appends to readBuf,
// which the portforwarder drains through Read. The portforwarder's Write
// appends to writeBuf, which the test drains through read().
type mockApplicationFDImpl struct {
	vfs.FileDescriptionDefaultImpl
	vfs.NoLockFD
	vfs.DentryMetadataFileDescriptionImpl
	mu         sync.Mutex    // mu guards readBuf, writeBuf and released.
	readBuf    bytes.Buffer  // readBuf holds data queued by write() and served by Read.
	writeBuf   bytes.Buffer  // writeBuf holds data from Write, drained by read().
	released   bool          // released is set by Release; I/O then returns io.EOF.
	queue      waiter.Queue  // queue is notified by doNotify; waiters are added via EventRegister.
	notifyStop chan struct{} // notifyStop is signalled by Release to stop doNotify.
}

// Compile-time check that *mockApplicationFDImpl satisfies vfs.FileDescriptionImpl.
var _ vfs.FileDescriptionImpl = (*mockApplicationFDImpl)(nil)
    89  
    90  func newMockApplicationFDImpl() *mockApplicationFDImpl {
    91  	app := &mockApplicationFDImpl{notifyStop: make(chan struct{})}
    92  	go app.doNotify()
    93  	return app
    94  }
    95  
    96  // Read implements vfs.FileDescriptionImpl.Read details for the parent mockFileDescription.
    97  func (s *mockApplicationFDImpl) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
    98  	s.mu.Lock()
    99  	defer s.mu.Unlock()
   100  	if s.released {
   101  		return 0, io.EOF
   102  	}
   103  	if s.readBuf.Len() == 0 {
   104  		return 0, linuxerr.ErrWouldBlock
   105  	}
   106  	buf := s.readBuf.Next(s.readBuf.Len())
   107  	n, err := dst.CopyOut(ctx, buf)
   108  	return int64(n), err
   109  }
   110  
   111  // Write implements vfs.FileDescriptionImpl.Write details for the parent mockFileDescription.
   112  func (s *mockApplicationFDImpl) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   113  	s.mu.Lock()
   114  	defer s.mu.Unlock()
   115  	if s.released {
   116  		return 0, io.EOF
   117  	}
   118  
   119  	buf := make([]byte, src.NumBytes())
   120  	n, _ := src.CopyIn(ctx, buf)
   121  	res, _ := s.writeBuf.Write(buf[:n])
   122  	return int64(res), nil
   123  }
   124  
   125  // write implements mockEndpoint.write.
   126  func (s *mockApplicationFDImpl) write(buf []byte) (int, error) {
   127  	s.mu.Lock()
   128  	defer s.mu.Unlock()
   129  	if s.released {
   130  		return 0, io.EOF
   131  	}
   132  	ret, err := s.readBuf.Write(buf)
   133  	return ret, err
   134  }
   135  
   136  // read implements mockEndpoint.read.
   137  func (s *mockApplicationFDImpl) read(n int) ([]byte, error) {
   138  	s.mu.Lock()
   139  	defer s.mu.Unlock()
   140  	if s.released {
   141  		return nil, io.EOF
   142  	}
   143  	if s.writeBuf.Len() == 0 {
   144  		return nil, linuxerr.ErrWouldBlock
   145  	}
   146  	ret := s.writeBuf.Next(n)
   147  	return ret, nil
   148  }
   149  
   150  func (s *mockApplicationFDImpl) doNotify() {
   151  	for {
   152  		s.queue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp)
   153  		select {
   154  		case <-s.notifyStop:
   155  			return
   156  		default:
   157  			time.Sleep(time.Millisecond * 50)
   158  		}
   159  	}
   160  }
   161  
   162  func (s *mockApplicationFDImpl) IsReadable() bool {
   163  	s.mu.Lock()
   164  	defer s.mu.Unlock()
   165  	if s.released {
   166  		return false
   167  	}
   168  	return s.readBuf.Len() > 0
   169  }
   170  
   171  func (s *mockApplicationFDImpl) IsWritable() bool {
   172  	s.mu.Lock()
   173  	defer s.mu.Unlock()
   174  	return !s.released
   175  }
   176  
   177  // EventRegister implements vfs.FileDescriptionImpl.EventRegister details for the parent mockFileDescription.
   178  func (s *mockApplicationFDImpl) EventRegister(we *waiter.Entry) error {
   179  	s.mu.Lock()
   180  	defer s.mu.Unlock()
   181  	s.queue.EventRegister(we)
   182  	return nil
   183  }
   184  
   185  // EventUnregister implements vfs.FileDescriptionImpl.Unregister details for the parent mockFileDescription.
   186  func (s *mockApplicationFDImpl) EventUnregister(we *waiter.Entry) {
   187  	s.mu.Lock()
   188  	defer s.mu.Unlock()
   189  	s.queue.EventUnregister(we)
   190  }
   191  
   192  // Release implements vfs.FileDescriptionImpl.Release details for the parent mockFileDescription.
   193  func (s *mockApplicationFDImpl) Release(context.Context) {
   194  	s.mu.Lock()
   195  	defer s.mu.Unlock()
   196  	s.released = true
   197  	s.notifyStop <- struct{}{}
   198  }
   199  
   200  // mockTCPEndpointImpl is the subset of methods used by tests for the mockTCPEndpoint struct. This
   201  // is so we can quickly change implementations as needed.
   202  type mockTCPEndpointImpl interface {
   203  	Close()
   204  	Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error)
   205  	Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error)
   206  	Shutdown(tcpip.ShutdownFlags) tcpip.Error
   207  }
   208  
// mockTCPEndpoint mocks tcpip.Endpoint for tests.
//
// Close, Read, Write and Shutdown delegate to impl; the remaining
// tcpip.Endpoint methods panic. newMockTCPEndpoint starts a goroutine that
// periodically notifies wq until Close signals notifyDone.
type mockTCPEndpoint struct {
	impl       mockTCPEndpointImpl // impl implements the subset of methods needed for mockTCPEndpoints.
	wq         *waiter.Queue       // wq receives the periodic notifications from doNotify.
	notifyDone chan struct{}       // notifyDone is signalled by Close to stop doNotify.
}
   215  
   216  func newMockTCPEndpoint(impl mockTCPEndpointImpl, wq *waiter.Queue) *mockTCPEndpoint {
   217  	ret := &mockTCPEndpoint{
   218  		impl:       impl,
   219  		wq:         wq,
   220  		notifyDone: make(chan struct{}),
   221  	}
   222  
   223  	go ret.doNotify()
   224  	return ret
   225  }
   226  
   227  func (m *mockTCPEndpoint) doNotify() {
   228  	for {
   229  		m.wq.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp)
   230  		select {
   231  		case <-m.notifyDone:
   232  			return
   233  		default:
   234  			time.Sleep(time.Millisecond * 50)
   235  		}
   236  
   237  	}
   238  }
   239  
// The methods below are trivial stubs that make mockTCPEndpoint implement tcpip.Endpoint. Each
// one either panics or forwards to the contained impl.
   242  
   243  // Close implements tcpip.Endpoint.Close.
   244  func (m *mockTCPEndpoint) Close() {
   245  	m.impl.Close()
   246  	m.notifyDone <- struct{}{}
   247  }
   248  
   249  // Abort implements tcpip.Endpoint.Abort.
   250  func (m *mockTCPEndpoint) Abort() {
   251  	m.panicWithNotImplementedMsg()
   252  }
   253  
   254  // Read implements tcpip.Endpoint.Read.
   255  func (m *mockTCPEndpoint) Read(w io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   256  	return m.impl.Read(w, opts)
   257  }
   258  
   259  // Write implements tcpip.Endpoint.Write.
   260  func (m *mockTCPEndpoint) Write(payload tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   261  	return m.impl.Write(payload, opts)
   262  }
   263  
   264  // Connect implements tcpip.Endpoint.Connect.
   265  func (m *mockTCPEndpoint) Connect(address tcpip.FullAddress) tcpip.Error {
   266  	m.panicWithNotImplementedMsg()
   267  	return nil
   268  }
   269  
   270  // Disconnect implements tcpip.Endpoint.Disconnect.
   271  func (m *mockTCPEndpoint) Disconnect() tcpip.Error {
   272  	m.panicWithNotImplementedMsg()
   273  	return nil
   274  }
   275  
   276  // Shutdown implements tcpip.Endpoint.Shutdown.
   277  func (m *mockTCPEndpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
   278  	return m.impl.Shutdown(flags)
   279  }
   280  
   281  // Listen implements tcpip.Endpoint.Listen.
   282  func (m *mockTCPEndpoint) Listen(backlog int) tcpip.Error {
   283  	m.panicWithNotImplementedMsg()
   284  	return nil
   285  }
   286  
   287  // Accept implements tcpip.Endpoint.Accept.
   288  func (m *mockTCPEndpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
   289  	m.panicWithNotImplementedMsg()
   290  	return nil, nil, nil
   291  }
   292  
   293  // Bind implements tcpip.Endpoint.Bind.
   294  func (m *mockTCPEndpoint) Bind(address tcpip.FullAddress) tcpip.Error {
   295  	m.panicWithNotImplementedMsg()
   296  	return nil
   297  }
   298  
   299  // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
   300  func (m mockTCPEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   301  	m.panicWithNotImplementedMsg()
   302  	return tcpip.FullAddress{}, nil
   303  }
   304  
   305  // GetRemoteAddress implements tcpip.Endpoint.GetRemoreAddress.
   306  func (m *mockTCPEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   307  	m.panicWithNotImplementedMsg()
   308  	return tcpip.FullAddress{}, nil
   309  }
   310  
   311  // Readiness implements tcpip.Endpoint.Readiness.
   312  func (m *mockTCPEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   313  	m.panicWithNotImplementedMsg()
   314  	return 0
   315  }
   316  
   317  // SetSockOpt implements tcpip.Endpoint.SetSockOpt.
   318  func (m *mockTCPEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   319  	m.panicWithNotImplementedMsg()
   320  	return nil
   321  }
   322  
   323  // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
   324  func (m *mockTCPEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   325  	m.panicWithNotImplementedMsg()
   326  	return nil
   327  }
   328  
   329  // GetSockOpt implements tcpip.Endpoint.GetSockOpt.
   330  func (m *mockTCPEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
   331  	m.panicWithNotImplementedMsg()
   332  	return nil
   333  }
   334  
   335  // GetSockOptInt implements tcpip.Endpoint.GetSockOpt.
   336  func (m *mockTCPEndpoint) GetSockOptInt(tcpip.SockOptInt) (int, tcpip.Error) {
   337  	m.panicWithNotImplementedMsg()
   338  	return 0, nil
   339  }
   340  
   341  // State implements tcpip.Endpoint.State.
   342  func (m *mockTCPEndpoint) State() uint32 {
   343  	m.panicWithNotImplementedMsg()
   344  	return 0
   345  }
   346  
   347  // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf
   348  func (m *mockTCPEndpoint) ModerateRecvBuf(copied int) {
   349  	m.panicWithNotImplementedMsg()
   350  }
   351  
   352  // Info implements tcpip.Endpoint.Info.
   353  func (m *mockTCPEndpoint) Info() tcpip.EndpointInfo {
   354  	m.panicWithNotImplementedMsg()
   355  	return nil
   356  }
   357  
   358  // Stats implements tcpip.Endpoint.Stats.
   359  func (m *mockTCPEndpoint) Stats() tcpip.EndpointStats {
   360  	m.panicWithNotImplementedMsg()
   361  	return nil
   362  }
   363  
   364  // SetOwner implements tcpip.Endpoint.SetOwner.
   365  func (m *mockTCPEndpoint) SetOwner(owner tcpip.PacketOwner) {
   366  	m.panicWithNotImplementedMsg()
   367  }
   368  
   369  // LastError implements tcpip.Endpoint.LastError.
   370  func (m *mockTCPEndpoint) LastError() tcpip.Error {
   371  	m.panicWithNotImplementedMsg()
   372  	return nil
   373  }
   374  
   375  // SocketOptions implements tcpip.Endpoint.SocketOptions.
   376  func (m *mockTCPEndpoint) SocketOptions() *tcpip.SocketOptions {
   377  	m.panicWithNotImplementedMsg()
   378  	return nil
   379  }
   380  
   381  func (*mockTCPEndpoint) panicWithNotImplementedMsg() { panic("not implemented") }