github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/p9/transport_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package p9
    16  
    17  import (
    18  	"io/ioutil"
    19  	"os"
    20  	"testing"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/fd"
    23  	"github.com/SagerNet/gvisor/pkg/unet"
    24  )
    25  
    26  const (
    27  	MsgTypeBadEncode = iota + 252
    28  	MsgTypeBadDecode
    29  	MsgTypeUnregistered
    30  )
    31  
    32  func TestSendRecv(t *testing.T) {
    33  	server, client, err := unet.SocketPair(false)
    34  	if err != nil {
    35  		t.Fatalf("socketpair got err %v expected nil", err)
    36  	}
    37  	defer server.Close()
    38  	defer client.Close()
    39  
    40  	if err := send(client, Tag(1), &Tlopen{}); err != nil {
    41  		t.Fatalf("send got err %v expected nil", err)
    42  	}
    43  
    44  	tag, m, err := recv(server, maximumLength, msgRegistry.get)
    45  	if err != nil {
    46  		t.Fatalf("recv got err %v expected nil", err)
    47  	}
    48  	if tag != Tag(1) {
    49  		t.Fatalf("got tag %v expected 1", tag)
    50  	}
    51  	if _, ok := m.(*Tlopen); !ok {
    52  		t.Fatalf("got message %v expected *Tlopen", m)
    53  	}
    54  }
    55  
    56  // badDecode overruns on decode.
    57  type badDecode struct{}
    58  
    59  func (*badDecode) decode(b *buffer) { b.markOverrun() }
    60  func (*badDecode) encode(b *buffer) {}
    61  func (*badDecode) Type() MsgType    { return MsgTypeBadDecode }
    62  func (*badDecode) String() string   { return "badDecode{}" }
    63  
    64  func TestRecvOverrun(t *testing.T) {
    65  	server, client, err := unet.SocketPair(false)
    66  	if err != nil {
    67  		t.Fatalf("socketpair got err %v expected nil", err)
    68  	}
    69  	defer server.Close()
    70  	defer client.Close()
    71  
    72  	if err := send(client, Tag(1), &badDecode{}); err != nil {
    73  		t.Fatalf("send got err %v expected nil", err)
    74  	}
    75  
    76  	if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil {
    77  		t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err)
    78  	}
    79  }
    80  
    81  // unregistered is not registered on decode.
    82  type unregistered struct{}
    83  
    84  func (*unregistered) decode(b *buffer) {}
    85  func (*unregistered) encode(b *buffer) {}
    86  func (*unregistered) Type() MsgType    { return MsgTypeUnregistered }
    87  func (*unregistered) String() string   { return "unregistered{}" }
    88  
    89  func TestRecvInvalidType(t *testing.T) {
    90  	server, client, err := unet.SocketPair(false)
    91  	if err != nil {
    92  		t.Fatalf("socketpair got err %v expected nil", err)
    93  	}
    94  	defer server.Close()
    95  	defer client.Close()
    96  
    97  	if err := send(client, Tag(1), &unregistered{}); err != nil {
    98  		t.Fatalf("send got err %v expected nil", err)
    99  	}
   100  
   101  	_, _, err = recv(server, maximumLength, msgRegistry.get)
   102  	if _, ok := err.(*ErrInvalidMsgType); !ok {
   103  		t.Fatalf("recv got err %v expected ErrInvalidMsgType", err)
   104  	}
   105  }
   106  
   107  func TestSendRecvWithFile(t *testing.T) {
   108  	server, client, err := unet.SocketPair(false)
   109  	if err != nil {
   110  		t.Fatalf("socketpair got err %v expected nil", err)
   111  	}
   112  	defer server.Close()
   113  	defer client.Close()
   114  
   115  	// Create a tempfile.
   116  	osf, err := ioutil.TempFile("", "p9")
   117  	if err != nil {
   118  		t.Fatalf("tempfile got err %v expected nil", err)
   119  	}
   120  	os.Remove(osf.Name())
   121  	f, err := fd.NewFromFile(osf)
   122  	osf.Close()
   123  	if err != nil {
   124  		t.Fatalf("unable to create file: %v", err)
   125  	}
   126  
   127  	rlopen := &Rlopen{}
   128  	rlopen.SetFilePayload(f)
   129  	if err := send(client, Tag(1), rlopen); err != nil {
   130  		t.Fatalf("send got err %v expected nil", err)
   131  	}
   132  
   133  	// Enable withFile.
   134  	tag, m, err := recv(server, maximumLength, msgRegistry.get)
   135  	if err != nil {
   136  		t.Fatalf("recv got err %v expected nil", err)
   137  	}
   138  	if tag != Tag(1) {
   139  		t.Fatalf("got tag %v expected 1", tag)
   140  	}
   141  	rlopen, ok := m.(*Rlopen)
   142  	if !ok {
   143  		t.Fatalf("got m %v expected *Rlopen", m)
   144  	}
   145  	if rlopen.File == nil {
   146  		t.Fatalf("got nil file expected non-nil")
   147  	}
   148  }
   149  
   150  func TestRecvClosed(t *testing.T) {
   151  	server, client, err := unet.SocketPair(false)
   152  	if err != nil {
   153  		t.Fatalf("socketpair got err %v expected nil", err)
   154  	}
   155  	defer server.Close()
   156  	client.Close()
   157  
   158  	_, _, err = recv(server, maximumLength, msgRegistry.get)
   159  	if err == nil {
   160  		t.Fatalf("got err nil expected non-nil")
   161  	}
   162  	if _, ok := err.(ErrSocket); !ok {
   163  		t.Fatalf("got err %v expected ErrSocket", err)
   164  	}
   165  }
   166  
   167  func TestSendClosed(t *testing.T) {
   168  	server, client, err := unet.SocketPair(false)
   169  	if err != nil {
   170  		t.Fatalf("socketpair got err %v expected nil", err)
   171  	}
   172  	server.Close()
   173  	defer client.Close()
   174  
   175  	err = send(client, Tag(1), &Tlopen{})
   176  	if err == nil {
   177  		t.Fatalf("send got err nil expected non-nil")
   178  	}
   179  	if _, ok := err.(ErrSocket); !ok {
   180  		t.Fatalf("got err %v expected ErrSocket", err)
   181  	}
   182  }
   183  
   184  func BenchmarkSendRecv(b *testing.B) {
   185  	b.ReportAllocs()
   186  
   187  	server, client, err := unet.SocketPair(false)
   188  	if err != nil {
   189  		b.Fatalf("socketpair got err %v expected nil", err)
   190  	}
   191  	defer server.Close()
   192  	defer client.Close()
   193  
   194  	// Exchange Rflush messages since these contain no data and therefore incur
   195  	// no additional marshaling overhead.
   196  	go func() {
   197  		for i := 0; i < b.N; i++ {
   198  			tag, m, err := recv(server, maximumLength, msgRegistry.get)
   199  			if err != nil {
   200  				b.Errorf("recv got err %v expected nil", err)
   201  			}
   202  			if tag != Tag(1) {
   203  				b.Errorf("got tag %v expected 1", tag)
   204  			}
   205  			if _, ok := m.(*Rflush); !ok {
   206  				b.Errorf("got message %T expected *Rflush", m)
   207  			}
   208  			if err := send(server, Tag(2), &Rflush{}); err != nil {
   209  				b.Errorf("send got err %v expected nil", err)
   210  			}
   211  		}
   212  	}()
   213  	b.ResetTimer()
   214  	for i := 0; i < b.N; i++ {
   215  		if err := send(client, Tag(1), &Rflush{}); err != nil {
   216  			b.Errorf("send got err %v expected nil", err)
   217  		}
   218  		tag, m, err := recv(client, maximumLength, msgRegistry.get)
   219  		if err != nil {
   220  			b.Errorf("recv got err %v expected nil", err)
   221  		}
   222  		if tag != Tag(2) {
   223  			b.Errorf("got tag %v expected 2", tag)
   224  		}
   225  		if _, ok := m.(*Rflush); !ok {
   226  			b.Errorf("got message %v expected *Rflush", m)
   227  		}
   228  	}
   229  }
   230  
   231  func init() {
   232  	msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} })
   233  }