gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/local/local_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package local
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"net"
    25  	"runtime"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"gitee.com/ks-custle/core-gm/grpc/credentials"
    31  	"gitee.com/ks-custle/core-gm/grpc/internal/grpctest"
    32  )
    33  
    34  const defaultTestTimeout = 10 * time.Second
    35  
    36  type s struct {
    37  	grpctest.Tester
    38  }
    39  
    40  func Test(t *testing.T) {
    41  	grpctest.RunSubTests(t, s{})
    42  }
    43  
    44  func (s) TestGetSecurityLevel(t *testing.T) {
    45  	testCases := []struct {
    46  		testNetwork string
    47  		testAddr    string
    48  		want        credentials.SecurityLevel
    49  	}{
    50  		{
    51  			testNetwork: "tcp",
    52  			testAddr:    "127.0.0.1:10000",
    53  			want:        credentials.NoSecurity,
    54  		},
    55  		{
    56  			testNetwork: "tcp",
    57  			testAddr:    "[::1]:10000",
    58  			want:        credentials.NoSecurity,
    59  		},
    60  		{
    61  			testNetwork: "unix",
    62  			testAddr:    "/tmp/grpc_fullstack_test",
    63  			want:        credentials.PrivacyAndIntegrity,
    64  		},
    65  		{
    66  			testNetwork: "tcp",
    67  			testAddr:    "192.168.0.1:10000",
    68  			want:        credentials.InvalidSecurityLevel,
    69  		},
    70  	}
    71  	for _, tc := range testCases {
    72  		got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr)
    73  		if got != tc.want {
    74  			t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String())
    75  		}
    76  	}
    77  }
    78  
    79  type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
    80  
    81  func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
    82  	if c, ok := ai.(interface {
    83  		GetCommonAuthInfo() credentials.CommonAuthInfo
    84  	}); ok {
    85  		return c.GetCommonAuthInfo().SecurityLevel
    86  	}
    87  	return credentials.InvalidSecurityLevel
    88  }
    89  
    90  // Server local handshake implementation.
    91  func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
    92  	cred := NewCredentials()
    93  	_, authInfo, err := cred.ServerHandshake(conn)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return authInfo, nil
    98  }
    99  
   100  // Client local handshake implementation.
   101  func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
   102  	cred := NewCredentials()
   103  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   104  	defer cancel()
   105  
   106  	_, authInfo, err := cred.ClientHandshake(ctx, lisAddr, conn)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	return authInfo, nil
   111  }
   112  
   113  // Client connects to a server with local credentials.
   114  func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) {
   115  	conn, _ := net.Dial(network, lisAddr)
   116  	defer conn.Close()
   117  	clientAuthInfo, err := hs(conn, lisAddr)
   118  	if err != nil {
   119  		return nil, fmt.Errorf("Error on client while handshake")
   120  	}
   121  	return clientAuthInfo, nil
   122  }
   123  
   124  type testServerHandleResult struct {
   125  	authInfo credentials.AuthInfo
   126  	err      error
   127  }
   128  
   129  // Server accepts a client's connection with local credentials.
   130  func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) {
   131  	serverRawConn, err := lis.Accept()
   132  	if err != nil {
   133  		done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)}
   134  		return
   135  	}
   136  	serverAuthInfo, err := hs(serverRawConn)
   137  	if err != nil {
   138  		serverRawConn.Close()
   139  		done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)}
   140  		return
   141  	}
   142  	done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil}
   143  }
   144  
   145  func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) {
   146  	done := make(chan testServerHandleResult, 1)
   147  	const timeout = 5 * time.Second
   148  	timer := time.NewTimer(timeout)
   149  	defer timer.Stop()
   150  	go serverHandle(serverLocalHandshake, done, lis)
   151  	defer lis.Close()
   152  	clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
   153  	if err != nil {
   154  		return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
   155  	}
   156  	select {
   157  	case <-timer.C:
   158  		return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
   159  	case serverHandleResult := <-done:
   160  		if serverHandleResult.err != nil {
   161  			return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
   162  		}
   163  		clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
   164  		serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)
   165  
   166  		if clientSecLevel == credentials.InvalidSecurityLevel {
   167  			return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
   168  		}
   169  		if serverSecLevel == credentials.InvalidSecurityLevel {
   170  			return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
   171  		}
   172  		if clientSecLevel != serverSecLevel {
   173  			return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
   174  		}
   175  		return clientSecLevel, nil
   176  	}
   177  }
   178  
   179  func (s) TestServerAndClientHandshake(t *testing.T) {
   180  	testCases := []struct {
   181  		testNetwork string
   182  		testAddr    string
   183  		want        credentials.SecurityLevel
   184  	}{
   185  		{
   186  			testNetwork: "tcp",
   187  			testAddr:    "127.0.0.1:0",
   188  			want:        credentials.NoSecurity,
   189  		},
   190  		{
   191  			testNetwork: "tcp",
   192  			testAddr:    "[::1]:0",
   193  			want:        credentials.NoSecurity,
   194  		},
   195  		{
   196  			testNetwork: "tcp",
   197  			testAddr:    "localhost:0",
   198  			want:        credentials.NoSecurity,
   199  		},
   200  		{
   201  			testNetwork: "unix",
   202  			testAddr:    fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()),
   203  			want:        credentials.PrivacyAndIntegrity,
   204  		},
   205  	}
   206  	for _, tc := range testCases {
   207  		if runtime.GOOS == "windows" && tc.testNetwork == "unix" {
   208  			t.Skip("skipping tests for unix connections on Windows")
   209  		}
   210  		t.Run("serverAndClientHandshakeResult", func(t *testing.T) {
   211  			lis, err := net.Listen(tc.testNetwork, tc.testAddr)
   212  			if err != nil {
   213  				if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
   214  					strings.Contains(err.Error(), "socket: address family not supported by protocol") {
   215  					t.Skipf("no support for address %v", tc.testAddr)
   216  				}
   217  				t.Fatalf("Failed to listen: %v", err)
   218  			}
   219  			got, err := serverAndClientHandshake(lis)
   220  			if got != tc.want {
   221  				t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want)
   222  			}
   223  		})
   224  	}
   225  }