github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/urpc/urpc_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package urpc
    16  
    17  import (
    18  	"errors"
    19  	"os"
    20  	"testing"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/unet"
    23  )
    24  
    25  type test struct {
    26  }
    27  
    28  type testArg struct {
    29  	StringArg string
    30  	IntArg    int
    31  	FilePayload
    32  }
    33  
    34  type testResult struct {
    35  	StringResult string
    36  	IntResult    int
    37  	FilePayload
    38  }
    39  
    40  func (t test) Func(a *testArg, r *testResult) error {
    41  	r.StringResult = a.StringArg
    42  	r.IntResult = a.IntArg
    43  	return nil
    44  }
    45  
    46  func (t test) Err(a *testArg, r *testResult) error {
    47  	return errors.New("test error")
    48  }
    49  
    50  func (t test) FailNoFile(a *testArg, r *testResult) error {
    51  	if a.Files == nil {
    52  		return errors.New("no file found")
    53  	}
    54  
    55  	return nil
    56  }
    57  
    58  func (t test) SendFile(a *testArg, r *testResult) error {
    59  	r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr}
    60  	return nil
    61  }
    62  
    63  func (t test) TooManyFiles(a *testArg, r *testResult) error {
    64  	for i := 0; i <= maxFiles; i++ {
    65  		r.Files = append(r.Files, os.Stdin)
    66  	}
    67  	return nil
    68  }
    69  
    70  func startServer(socket *unet.Socket) {
    71  	s := NewServer()
    72  	s.Register(test{})
    73  	s.StartHandling(socket)
    74  }
    75  
    76  func testClient() (*Client, error) {
    77  	serverSock, clientSock, err := unet.SocketPair(false)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	startServer(serverSock)
    82  
    83  	return NewClient(clientSock), nil
    84  }
    85  
    86  func TestCall(t *testing.T) {
    87  	c, err := testClient()
    88  	if err != nil {
    89  		t.Fatalf("error creating test client: %v", err)
    90  	}
    91  	defer c.Close()
    92  
    93  	var r testResult
    94  	if err := c.Call("test.Func", &testArg{}, &r); err != nil {
    95  		t.Errorf("basic call failed: %v", err)
    96  	} else if r.StringResult != "" || r.IntResult != 0 {
    97  		t.Errorf("unexpected result, got %v expected zero value", r)
    98  	}
    99  	if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil {
   100  		t.Errorf("basic call failed: %v", err)
   101  	} else if r.StringResult != "hello" {
   102  		t.Errorf("unexpected result, got %v expected hello", r.StringResult)
   103  	}
   104  	if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil {
   105  		t.Errorf("basic call failed: %v", err)
   106  	} else if r.IntResult != 1 {
   107  		t.Errorf("unexpected result, got %v expected 1", r.IntResult)
   108  	}
   109  }
   110  
   111  func TestUnknownMethod(t *testing.T) {
   112  	c, err := testClient()
   113  	if err != nil {
   114  		t.Fatalf("error creating test client: %v", err)
   115  	}
   116  	defer c.Close()
   117  
   118  	var r testResult
   119  	if err := c.Call("test.Unknown", &testArg{}, &r); err == nil {
   120  		t.Errorf("expected non-nil err, got nil")
   121  	} else if err.Error() != ErrUnknownMethod.Error() {
   122  		t.Errorf("expected test error, got %v", err)
   123  	}
   124  }
   125  
   126  func TestErr(t *testing.T) {
   127  	c, err := testClient()
   128  	if err != nil {
   129  		t.Fatalf("error creating test client: %v", err)
   130  	}
   131  	defer c.Close()
   132  
   133  	var r testResult
   134  	if err := c.Call("test.Err", &testArg{}, &r); err == nil {
   135  		t.Errorf("expected non-nil err, got nil")
   136  	} else if err.Error() != "test error" {
   137  		t.Errorf("expected test error, got %v", err)
   138  	}
   139  }
   140  
   141  func TestSendFile(t *testing.T) {
   142  	c, err := testClient()
   143  	if err != nil {
   144  		t.Fatalf("error creating test client: %v", err)
   145  	}
   146  	defer c.Close()
   147  
   148  	var r testResult
   149  	if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil {
   150  		t.Errorf("expected non-nil err, got nil")
   151  	}
   152  	if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil {
   153  		t.Errorf("expected nil err, got %v", err)
   154  	}
   155  }
   156  
   157  func TestRecvFile(t *testing.T) {
   158  	c, err := testClient()
   159  	if err != nil {
   160  		t.Fatalf("error creating test client: %v", err)
   161  	}
   162  	defer c.Close()
   163  
   164  	var r testResult
   165  	if err := c.Call("test.SendFile", &testArg{}, &r); err != nil {
   166  		t.Errorf("expected nil err, got %v", err)
   167  	}
   168  	if r.Files == nil {
   169  		t.Errorf("expected file, got nil")
   170  	}
   171  }
   172  
   173  func TestShutdown(t *testing.T) {
   174  	serverSock, clientSock, err := unet.SocketPair(false)
   175  	if err != nil {
   176  		t.Fatalf("error creating test client: %v", err)
   177  	}
   178  	clientSock.Close()
   179  
   180  	s := NewServer()
   181  	if err := s.Handle(serverSock); err == nil {
   182  		t.Errorf("expected non-nil err, got nil")
   183  	}
   184  }
   185  
   186  func TestTooManyFiles(t *testing.T) {
   187  	c, err := testClient()
   188  	if err != nil {
   189  		t.Fatalf("error creating test client: %v", err)
   190  	}
   191  	defer c.Close()
   192  
   193  	var r testResult
   194  	var a testArg
   195  	for i := 0; i <= maxFiles; i++ {
   196  		a.Files = append(a.Files, os.Stdin)
   197  	}
   198  
   199  	// Client-side error.
   200  	if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles {
   201  		t.Errorf("expected ErrTooManyFiles, got %v", err)
   202  	}
   203  
   204  	// Server-side error.
   205  	if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil {
   206  		t.Errorf("expected non-nil err, got nil")
   207  	} else if err.Error() != "too many files" {
   208  		t.Errorf("expected too many files, got %v", err.Error())
   209  	}
   210  }