storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/rpc/server_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Copyright 2012 The Gorilla Authors. All rights reserved.
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  // Copyright 2020 MinIO, Inc. All rights reserved.
     7  // forked from https://github.com/gorilla/rpc/v2
     8  // modified to be used with MinIO under Apache
     9  // 2.0 license that can be found in the LICENSE file.
    10  
    11  package rpc
    12  
    13  import (
    14  	"errors"
    15  	"net/http"
    16  	"strconv"
    17  	"testing"
    18  )
    19  
    20  type Service1Request struct {
    21  	A int
    22  	B int
    23  }
    24  
    25  type Service1Response struct {
    26  	Result int
    27  }
    28  
    29  type Service1 struct {
    30  }
    31  
    32  func (t *Service1) Multiply(r *http.Request, req *Service1Request, res *Service1Response) error {
    33  	res.Result = req.A * req.B
    34  	return nil
    35  }
    36  
    37  type Service2 struct {
    38  }
    39  
    40  func TestRegisterService(t *testing.T) {
    41  	var err error
    42  	s := NewServer()
    43  	service1 := new(Service1)
    44  	service2 := new(Service2)
    45  
    46  	// Inferred name.
    47  	err = s.RegisterService(service1, "")
    48  	if err != nil || !s.HasMethod("Service1.Multiply") {
    49  		t.Errorf("Expected to be registered: Service1.Multiply")
    50  	}
    51  	// Provided name.
    52  	err = s.RegisterService(service1, "Foo")
    53  	if err != nil || !s.HasMethod("Foo.Multiply") {
    54  		t.Errorf("Expected to be registered: Foo.Multiply")
    55  	}
    56  	// No methods.
    57  	err = s.RegisterService(service2, "")
    58  	if err == nil {
    59  		t.Errorf("Expected error on service2")
    60  	}
    61  }
    62  
    63  // MockCodec decodes to Service1.Multiply.
    64  type MockCodec struct {
    65  	A, B int
    66  }
    67  
    68  func (c MockCodec) NewRequest(*http.Request) CodecRequest {
    69  	return MockCodecRequest{c.A, c.B}
    70  }
    71  
    72  type MockCodecRequest struct {
    73  	A, B int
    74  }
    75  
    76  func (r MockCodecRequest) Method() (string, error) {
    77  	return "Service1.Multiply", nil
    78  }
    79  
    80  func (r MockCodecRequest) ReadRequest(args interface{}) error {
    81  	req := args.(*Service1Request)
    82  	req.A, req.B = r.A, r.B
    83  	return nil
    84  }
    85  
    86  func (r MockCodecRequest) WriteResponse(w http.ResponseWriter, reply interface{}) {
    87  	res := reply.(*Service1Response)
    88  	w.Write([]byte(strconv.Itoa(res.Result)))
    89  }
    90  
    91  func (r MockCodecRequest) WriteError(w http.ResponseWriter, status int, err error) {
    92  	w.WriteHeader(status)
    93  	w.Write([]byte(err.Error()))
    94  }
    95  
    96  type MockResponseWriter struct {
    97  	header http.Header
    98  	Status int
    99  	Body   string
   100  }
   101  
   102  func NewMockResponseWriter() *MockResponseWriter {
   103  	header := make(http.Header)
   104  	return &MockResponseWriter{header: header}
   105  }
   106  
   107  func (w *MockResponseWriter) Header() http.Header {
   108  	return w.header
   109  }
   110  
   111  func (w *MockResponseWriter) Write(p []byte) (int, error) {
   112  	w.Body = string(p)
   113  	if w.Status == 0 {
   114  		w.Status = 200
   115  	}
   116  	return len(p), nil
   117  }
   118  
   119  func (w *MockResponseWriter) WriteHeader(status int) {
   120  	w.Status = status
   121  }
   122  
   123  func TestServeHTTP(t *testing.T) {
   124  	const (
   125  		A = 2
   126  		B = 3
   127  	)
   128  	expected := A * B
   129  
   130  	s := NewServer()
   131  	s.RegisterService(new(Service1), "")
   132  	s.RegisterCodec(MockCodec{A, B}, "mock")
   133  	r, err := http.NewRequest("POST", "", nil)
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	r.Header.Set("Content-Type", "mock; dummy")
   138  	w := NewMockResponseWriter()
   139  	s.ServeHTTP(w, r)
   140  	if w.Status != 200 {
   141  		t.Errorf("Status was %d, should be 200.", w.Status)
   142  	}
   143  	if w.Body != strconv.Itoa(expected) {
   144  		t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
   145  	}
   146  
   147  	// Test wrong Content-Type
   148  	r.Header.Set("Content-Type", "invalid")
   149  	w = NewMockResponseWriter()
   150  	s.ServeHTTP(w, r)
   151  	if w.Status != 415 {
   152  		t.Errorf("Status was %d, should be 415.", w.Status)
   153  	}
   154  	if w.Body != "rpc: unrecognized Content-Type: invalid" {
   155  		t.Errorf("Wrong response body.")
   156  	}
   157  
   158  	// Test omitted Content-Type; codec should default to the sole registered one.
   159  	r.Header.Del("Content-Type")
   160  	w = NewMockResponseWriter()
   161  	s.ServeHTTP(w, r)
   162  	if w.Status != 200 {
   163  		t.Errorf("Status was %d, should be 200.", w.Status)
   164  	}
   165  	if w.Body != strconv.Itoa(expected) {
   166  		t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
   167  	}
   168  }
   169  
   170  func TestInterception(t *testing.T) {
   171  	const (
   172  		A = 2
   173  		B = 3
   174  	)
   175  	expected := A * B
   176  
   177  	r2, err := http.NewRequest("POST", "mocked/request", nil)
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  
   182  	s := NewServer()
   183  	s.RegisterService(new(Service1), "")
   184  	s.RegisterCodec(MockCodec{A, B}, "mock")
   185  	s.RegisterInterceptFunc(func(i *RequestInfo) *http.Request {
   186  		return r2
   187  	})
   188  	s.RegisterValidateRequestFunc(func(info *RequestInfo, v interface{}) error { return nil })
   189  	s.RegisterAfterFunc(func(i *RequestInfo) {
   190  		if i.Request != r2 {
   191  			t.Errorf("Request was %v, should be %v.", i.Request, r2)
   192  		}
   193  	})
   194  
   195  	r, err := http.NewRequest("POST", "", nil)
   196  	if err != nil {
   197  		t.Fatal(err)
   198  	}
   199  	r.Header.Set("Content-Type", "mock; dummy")
   200  	w := NewMockResponseWriter()
   201  	s.ServeHTTP(w, r)
   202  	if w.Status != 200 {
   203  		t.Errorf("Status was %d, should be 200.", w.Status)
   204  	}
   205  	if w.Body != strconv.Itoa(expected) {
   206  		t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
   207  	}
   208  }
   209  func TestValidationSuccessful(t *testing.T) {
   210  	const (
   211  		A = 2
   212  		B = 3
   213  
   214  		expected = A * B
   215  	)
   216  
   217  	validate := func(info *RequestInfo, v interface{}) error { return nil }
   218  
   219  	s := NewServer()
   220  	s.RegisterService(new(Service1), "")
   221  	s.RegisterCodec(MockCodec{A, B}, "mock")
   222  	s.RegisterValidateRequestFunc(validate)
   223  
   224  	r, err := http.NewRequest("POST", "", nil)
   225  	if err != nil {
   226  		t.Fatal(err)
   227  	}
   228  	r.Header.Set("Content-Type", "mock; dummy")
   229  	w := NewMockResponseWriter()
   230  	s.ServeHTTP(w, r)
   231  	if w.Status != 200 {
   232  		t.Errorf("Status was %d, should be 200.", w.Status)
   233  	}
   234  	if w.Body != strconv.Itoa(expected) {
   235  		t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
   236  	}
   237  }
   238  
   239  func TestValidationFails(t *testing.T) {
   240  	const expected = "this instance only supports zero values"
   241  
   242  	validate := func(r *RequestInfo, v interface{}) error {
   243  		req := v.(*Service1Request)
   244  		if req.A != 0 || req.B != 0 {
   245  			return errors.New(expected)
   246  		}
   247  		return nil
   248  	}
   249  
   250  	s := NewServer()
   251  	s.RegisterService(new(Service1), "")
   252  	s.RegisterCodec(MockCodec{1, 2}, "mock")
   253  	s.RegisterValidateRequestFunc(validate)
   254  
   255  	r, err := http.NewRequest("POST", "", nil)
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	r.Header.Set("Content-Type", "mock; dummy")
   260  	w := NewMockResponseWriter()
   261  	s.ServeHTTP(w, r)
   262  	if w.Status != 400 {
   263  		t.Errorf("Status was %d, should be 200.", w.Status)
   264  	}
   265  	if w.Body != expected {
   266  		t.Errorf("Response body was %s, should be %s.", w.Body, expected)
   267  	}
   268  }