github.com/ConsenSys/Quorum@v20.10.0+incompatible/rpc/server_test.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"context"
    23  	"errors"
    24  	"io"
    25  	"io/ioutil"
    26  	"net"
    27  	"net/http"
    28  	"path/filepath"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/golang/protobuf/ptypes"
    34  	"github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
    35  	"github.com/stretchr/testify/assert"
    36  )
    37  
    38  func TestServerRegisterName(t *testing.T) {
    39  	server := NewServer()
    40  	service := new(testService)
    41  
    42  	if err := server.RegisterName("test", service); err != nil {
    43  		t.Fatalf("%v", err)
    44  	}
    45  
    46  	if len(server.services.services) != 2 {
    47  		t.Fatalf("Expected 2 service entries, got %d", len(server.services.services))
    48  	}
    49  
    50  	svc, ok := server.services.services["test"]
    51  	if !ok {
    52  		t.Fatalf("Expected service calc to be registered")
    53  	}
    54  
    55  	wantCallbacks := 7
    56  	if len(svc.callbacks) != wantCallbacks {
    57  		t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
    58  	}
    59  }
    60  
    61  func TestServer(t *testing.T) {
    62  	files, err := ioutil.ReadDir("testdata")
    63  	if err != nil {
    64  		t.Fatal("where'd my testdata go?")
    65  	}
    66  	for _, f := range files {
    67  		if f.IsDir() || strings.HasPrefix(f.Name(), ".") {
    68  			continue
    69  		}
    70  		path := filepath.Join("testdata", f.Name())
    71  		name := strings.TrimSuffix(f.Name(), filepath.Ext(f.Name()))
    72  		t.Run(name, func(t *testing.T) {
    73  			runTestScript(t, path)
    74  		})
    75  	}
    76  }
    77  
    78  func runTestScript(t *testing.T, file string) {
    79  	server := newTestServer()
    80  	content, err := ioutil.ReadFile(file)
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  
    85  	clientConn, serverConn := net.Pipe()
    86  	defer clientConn.Close()
    87  	go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
    88  	readbuf := bufio.NewReader(clientConn)
    89  	for _, line := range strings.Split(string(content), "\n") {
    90  		line = strings.TrimSpace(line)
    91  		switch {
    92  		case len(line) == 0 || strings.HasPrefix(line, "//"):
    93  			// skip comments, blank lines
    94  			continue
    95  		case strings.HasPrefix(line, "--> "):
    96  			t.Log(line)
    97  			// write to connection
    98  			clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
    99  			if _, err := io.WriteString(clientConn, line[4:]+"\n"); err != nil {
   100  				t.Fatalf("write error: %v", err)
   101  			}
   102  		case strings.HasPrefix(line, "<-- "):
   103  			t.Log(line)
   104  			want := line[4:]
   105  			// read line from connection and compare text
   106  			clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
   107  			sent, err := readbuf.ReadString('\n')
   108  			if err != nil {
   109  				t.Fatalf("read error: %v", err)
   110  			}
   111  			sent = strings.TrimRight(sent, "\r\n")
   112  			if sent != want {
   113  				t.Errorf("wrong line from server\ngot:  %s\nwant: %s", sent, want)
   114  			}
   115  		default:
   116  			panic("invalid line in test script: " + line)
   117  		}
   118  	}
   119  }
   120  
   121  // This test checks that responses are delivered for very short-lived connections that
   122  // only carry a single request.
   123  func TestServerShortLivedConn(t *testing.T) {
   124  	server := newTestServer()
   125  	defer server.Stop()
   126  
   127  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   128  	if err != nil {
   129  		t.Fatal("can't listen:", err)
   130  	}
   131  	defer listener.Close()
   132  	go server.ServeListener(listener)
   133  
   134  	var (
   135  		request  = `{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}` + "\n"
   136  		wantResp = `{"jsonrpc":"2.0","id":1,"result":{"nftest":"1.0","rpc":"1.0","test":"1.0"}}` + "\n"
   137  		deadline = time.Now().Add(10 * time.Second)
   138  	)
   139  	for i := 0; i < 20; i++ {
   140  		conn, err := net.Dial("tcp", listener.Addr().String())
   141  		if err != nil {
   142  			t.Fatal("can't dial:", err)
   143  		}
   144  		defer conn.Close()
   145  		conn.SetDeadline(deadline)
   146  		// Write the request, then half-close the connection so the server stops reading.
   147  		conn.Write([]byte(request))
   148  		conn.(*net.TCPConn).CloseWrite()
   149  		// Now try to get the response.
   150  		buf := make([]byte, 2000)
   151  		n, err := conn.Read(buf)
   152  		if err != nil {
   153  			t.Fatal("read error:", err)
   154  		}
   155  		if !bytes.Equal(buf[:n], []byte(wantResp)) {
   156  			t.Fatalf("wrong response: %s", buf[:n])
   157  		}
   158  	}
   159  }
   160  
   161  func TestAuthenticateHttpRequest_whenAuthenticationManagerFails(t *testing.T) {
   162  	protectedServer := NewProtectedServer(&stubAuthenticationManager{false, errors.New("arbitrary error")})
   163  	arbitraryRequest, _ := http.NewRequest("POST", "https://arbitraryUrl", nil)
   164  	captor := &securityContextConfigurerCaptor{}
   165  
   166  	protectedServer.authenticateHttpRequest(arbitraryRequest, captor)
   167  
   168  	actualErr, hasError := captor.context.Value(ctxAuthenticationError).(error)
   169  	assert.True(t, hasError, "must have error")
   170  	assert.EqualError(t, actualErr, "internal error")
   171  	_, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
   172  	assert.False(t, hasAuthToken, "must not be preauthenticated")
   173  }
   174  
   175  func TestAuthenticateHttpRequest_whenTypical(t *testing.T) {
   176  	protectedServer := NewProtectedServer(&stubAuthenticationManager{true, nil})
   177  	arbitraryRequest, _ := http.NewRequest("POST", "https://arbitraryUrl", nil)
   178  	arbitraryRequest.Header.Set(HttpAuthorizationHeader, "arbitrary value")
   179  	captor := &securityContextConfigurerCaptor{}
   180  
   181  	protectedServer.authenticateHttpRequest(arbitraryRequest, captor)
   182  
   183  	_, hasError := captor.context.Value(ctxAuthenticationError).(error)
   184  	assert.False(t, hasError, "must not have error")
   185  	_, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
   186  	assert.True(t, hasAuthToken, "must be preauthenticated")
   187  }
   188  
   189  func TestAuthenticateHttpRequest_whenAuthenticationManagerIsDisabled(t *testing.T) {
   190  	protectedServer := NewProtectedServer(&stubAuthenticationManager{false, nil})
   191  	arbitraryRequest, _ := http.NewRequest("POST", "https://arbitraryUrl", nil)
   192  	captor := &securityContextConfigurerCaptor{}
   193  
   194  	protectedServer.authenticateHttpRequest(arbitraryRequest, captor)
   195  
   196  	_, hasError := captor.context.Value(ctxAuthenticationError).(error)
   197  	assert.False(t, hasError, "must not have error")
   198  	_, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
   199  	assert.False(t, hasAuthToken, "must not be preauthenticated")
   200  }
   201  
   202  func TestAuthenticateHttpRequest_whenMissingAccessToken(t *testing.T) {
   203  	protectedServer := NewProtectedServer(&stubAuthenticationManager{true, nil})
   204  	arbitraryRequest, _ := http.NewRequest("POST", "https://arbitraryUrl", nil)
   205  	captor := &securityContextConfigurerCaptor{}
   206  
   207  	protectedServer.authenticateHttpRequest(arbitraryRequest, captor)
   208  
   209  	actualErr, hasError := captor.context.Value(ctxAuthenticationError).(error)
   210  	assert.True(t, hasError, "must have error")
   211  	assert.EqualError(t, actualErr, "missing access token")
   212  	_, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
   213  	assert.False(t, hasAuthToken, "must not be preauthenticated")
   214  }
   215  
   216  type securityContextConfigurerCaptor struct {
   217  	context securityContext
   218  }
   219  
   220  func (sc *securityContextConfigurerCaptor) Configure(secCtx securityContext) {
   221  	sc.context = secCtx
   222  }
   223  
   224  type stubAuthenticationManager struct {
   225  	isEnabled bool
   226  	stubErr   error
   227  }
   228  
   229  func (s *stubAuthenticationManager) Authenticate(_ context.Context, _ string) (*proto.PreAuthenticatedAuthenticationToken, error) {
   230  	expiredAt, err := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	return &proto.PreAuthenticatedAuthenticationToken{
   235  		ExpiredAt: expiredAt,
   236  	}, nil
   237  }
   238  
   239  func (s *stubAuthenticationManager) IsEnabled(_ context.Context) (bool, error) {
   240  	return s.isEnabled, s.stubErr
   241  }