gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/grpc_test/grpc_test.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // gmgo is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  grpc_test 是对`gitee.com/zhaochuninhefei/gmgo/grpc`的测试包
    11  */
    12  
    13  package grpc_test
    14  
    15  import (
    16  	"errors"
    17  	"fmt"
    18  	"io/ioutil"
    19  	"log"
    20  	"net"
    21  	"testing"
    22  	"time"
    23  
    24  	"gitee.com/zhaochuninhefei/gmgo/gmtls"
    25  	"gitee.com/zhaochuninhefei/gmgo/grpc"
    26  	"gitee.com/zhaochuninhefei/gmgo/grpc/credentials"
    27  	"gitee.com/zhaochuninhefei/gmgo/grpc/grpc_test/echo"
    28  	"gitee.com/zhaochuninhefei/gmgo/net/context"
    29  	"gitee.com/zhaochuninhefei/gmgo/x509"
    30  	"gitee.com/zhaochuninhefei/zcgolog/zclog"
    31  )
    32  
    33  //goland:noinspection GoSnakeCaseUsage
    34  const (
    35  	port    = ":50051"
    36  	address = "localhost:50051"
    37  
    38  	sm2_ca       = "testdata/sm2_ca.cert"
    39  	sm2_signCert = "testdata/sm2_sign.cert"
    40  	sm2_signKey  = "testdata/sm2_sign.key"
    41  	sm2_userCert = "testdata/sm2_user.cert"
    42  	sm2_userKey  = "testdata/sm2_user.key"
    43  
    44  	ecdsa_ca       = "testdata/ecdsa_ca.cert"
    45  	ecdsa_signCert = "testdata/ecdsa_sign.cert"
    46  	ecdsa_signKey  = "testdata/ecdsa_sign.key"
    47  	ecdsa_userCert = "testdata/ecdsa_user.cert"
    48  	ecdsa_userKey  = "testdata/ecdsa_user.key"
    49  
    50  	ecdsaext_ca       = "testdata/ecdsaext_ca.cert"
    51  	ecdsaext_signCert = "testdata/ecdsaext_sign.cert"
    52  	ecdsaext_signKey  = "testdata/ecdsaext_sign.key"
    53  	ecdsaext_userCert = "testdata/ecdsaext_user.cert"
    54  	ecdsaext_userKey  = "testdata/ecdsaext_user.key"
    55  )
    56  
    57  func TestMain(m *testing.M) {
    58  	zcgologConfig := &zclog.Config{
    59  		LogLevelGlobal: zclog.LOG_LEVEL_DEBUG,
    60  	}
    61  	zclog.InitLogger(zcgologConfig)
    62  	go serverRun()
    63  	time.Sleep(1000000)
    64  	m.Run()
    65  }
    66  
    67  var end chan bool
    68  
    69  func Test_credentials_sm2(t *testing.T) {
    70  	end = make(chan bool, 64)
    71  	go clientRun("sm2")
    72  	<-end
    73  }
    74  
    75  func Test_credentials_ecdsa(t *testing.T) {
    76  	end = make(chan bool, 64)
    77  	go clientRun("ecdsa")
    78  	<-end
    79  }
    80  
    81  func Test_credentials_ecdsaext(t *testing.T) {
    82  	end = make(chan bool, 64)
    83  	go clientRun("ecdsaext")
    84  	<-end
    85  }
    86  
    87  func serverRun() {
    88  	// 准备3份服务端证书, 分别是sm2, ecdsa, ecdsaext
    89  	var certs []gmtls.Certificate
    90  	sm2SignCert, err := gmtls.LoadX509KeyPair(sm2_signCert, sm2_signKey)
    91  	if err != nil {
    92  		log.Fatal(err)
    93  	}
    94  	certs = append(certs, sm2SignCert)
    95  	ecdsaSignCert, err := gmtls.LoadX509KeyPair(ecdsa_signCert, ecdsa_signKey)
    96  	if err != nil {
    97  		log.Fatal(err)
    98  	}
    99  	certs = append(certs, ecdsaSignCert)
   100  	ecdsaextSignCert, err := gmtls.LoadX509KeyPair(ecdsaext_signCert, ecdsaext_signKey)
   101  	if err != nil {
   102  		log.Fatal(err)
   103  	}
   104  	certs = append(certs, ecdsaextSignCert)
   105  
   106  	// 准备CA证书池,导入颁发客户端证书的CA证书
   107  	certPool := x509.NewCertPool()
   108  	sm2CaCert, err := ioutil.ReadFile(sm2_ca)
   109  	if err != nil {
   110  		log.Fatal(err)
   111  	}
   112  	certPool.AppendCertsFromPEM(sm2CaCert)
   113  	ecdsaCaCert, err := ioutil.ReadFile(ecdsa_ca)
   114  	if err != nil {
   115  		log.Fatal(err)
   116  	}
   117  	certPool.AppendCertsFromPEM(ecdsaCaCert)
   118  	ecdsaextCaCert, err := ioutil.ReadFile(ecdsaext_ca)
   119  	if err != nil {
   120  		log.Fatal(err)
   121  	}
   122  	certPool.AppendCertsFromPEM(ecdsaextCaCert)
   123  
   124  	// 创建gmtls配置
   125  	config := &gmtls.Config{
   126  		Certificates: certs,
   127  		ClientAuth:   gmtls.RequireAndVerifyClientCert,
   128  		ClientCAs:    certPool,
   129  	}
   130  
   131  	// 创建grpc服务端
   132  	creds := credentials.NewTLS(config)
   133  	s := grpc.NewServer(grpc.Creds(creds))
   134  	echo.RegisterEchoServer(s, &server{})
   135  
   136  	// 开启tcp监听端口
   137  	lis, err := net.Listen("tcp", port)
   138  	if err != nil {
   139  		log.Fatalf("fail to listen: %v", err)
   140  	}
   141  	// 启动grpc服务
   142  	err = s.Serve(lis)
   143  	if err != nil {
   144  		log.Fatalf("Serve: %v", err)
   145  	}
   146  }
   147  
   148  func clientRun(certType string) {
   149  	// 创建客户端本地的证书池
   150  	caPool := x509.NewCertPool()
   151  	// ca证书
   152  	var cacert []byte
   153  	// 客户端证书
   154  	var cert gmtls.Certificate
   155  	// 客户端优先曲线列表
   156  	var curvePreference []gmtls.CurveID
   157  	// 客户端优先密码套件列表
   158  	var cipherSuitesPrefer []uint16
   159  	// 客户端优先签名算法
   160  	var sigAlgPrefer []gmtls.SignatureScheme
   161  	var err error
   162  	switch certType {
   163  	case "sm2":
   164  		// 读取sm2 ca证书
   165  		cacert, err = ioutil.ReadFile(sm2_ca)
   166  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   167  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   168  		// 此时该证书应该由第三方ca机构颁发签名。
   169  		cert, err = gmtls.LoadX509KeyPair(sm2_userCert, sm2_userKey)
   170  		curvePreference = append(curvePreference, gmtls.Curve256Sm2)
   171  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_SM4_GCM_SM3)
   172  	case "ecdsa":
   173  		// 读取ecdsa ca证书
   174  		cacert, err = ioutil.ReadFile(ecdsa_ca)
   175  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   176  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   177  		// 此时该证书应该由第三方ca机构颁发签名。
   178  		cert, err = gmtls.LoadX509KeyPair(ecdsa_userCert, ecdsa_userKey)
   179  		curvePreference = append(curvePreference, gmtls.CurveP256)
   180  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_AES_128_GCM_SHA256)
   181  		sigAlgPrefer = append(sigAlgPrefer, gmtls.ECDSAWithP256AndSHA256)
   182  	case "ecdsaext":
   183  		// 读取ecdsaext ca证书
   184  		cacert, err = ioutil.ReadFile(ecdsaext_ca)
   185  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   186  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   187  		// 此时该证书应该由第三方ca机构颁发签名。
   188  		cert, err = gmtls.LoadX509KeyPair(ecdsaext_userCert, ecdsaext_userKey)
   189  		curvePreference = append(curvePreference, gmtls.CurveP256)
   190  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_AES_128_GCM_SHA256)
   191  		sigAlgPrefer = append(sigAlgPrefer, gmtls.ECDSAEXTWithP256AndSHA256)
   192  	default:
   193  		err = errors.New("目前只支持sm2/ecdsa/ecdsaext")
   194  	}
   195  	if err != nil {
   196  		zclog.Fatal(err)
   197  	}
   198  	// 将ca证书作为根证书加入证书池
   199  	// 即,客户端相信持有该ca颁发的证书的服务端
   200  	caPool.AppendCertsFromPEM(cacert)
   201  
   202  	// 定义gmtls配置
   203  	config := &gmtls.Config{
   204  		RootCAs:      caPool,
   205  		Certificates: []gmtls.Certificate{cert},
   206  		// 因为相关证书是由`x509/x509_test.go`的`TestCreateCertFromCA`生成的,
   207  		// 指定了SAN包含"server.test.com"
   208  		ServerName:         "server.test.com",
   209  		CurvePreferences:   curvePreference,
   210  		PreferCipherSuites: cipherSuitesPrefer,
   211  		SignAlgPrefer:      sigAlgPrefer,
   212  		ClientAuth:         gmtls.RequireAndVerifyClientCert,
   213  	}
   214  	creds := credentials.NewTLS(config)
   215  	conn, err := grpc.Dial(address, grpc.WithTransportCredentials(creds))
   216  	if err != nil {
   217  		log.Fatalf("cannot to connect: %v", err)
   218  	}
   219  	defer func(conn *grpc.ClientConn) {
   220  		_ = conn.Close()
   221  	}(conn)
   222  	c := echo.NewEchoClient(conn)
   223  	echoInClient(c)
   224  	end <- true
   225  }
   226  
   227  // 客户端echo处理
   228  func echoInClient(c echo.EchoClient) {
   229  	msgClient := "hello, this is client."
   230  	fmt.Printf("客户端发出消息: %s\n", msgClient)
   231  	r, err := c.Echo(context.Background(), &echo.EchoRequest{Req: msgClient})
   232  	if err != nil {
   233  		log.Fatalf("failed to echo: %v", err)
   234  	}
   235  	msgServer := r.Result
   236  	fmt.Printf("客户端收到消息: %s\n", msgServer)
   237  }
   238  
   239  type server struct{}
   240  
   241  // Echo 服务端echo处理
   242  //goland:noinspection GoUnusedParameter
   243  func (s *server) Echo(ctx context.Context, req *echo.EchoRequest) (*echo.EchoResponse, error) {
   244  	msgClient := req.Req
   245  	fmt.Printf("服务端接收到消息: %s\n", msgClient)
   246  	msgServer := "hello,this is server."
   247  	fmt.Printf("服务端返回消息: %s\n", msgServer)
   248  	return &echo.EchoResponse{Result: msgServer}, nil
   249  }