gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/grpc_test/grpc_test.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
/*
Package grpc_test contains the tests for `core-gm/grpc`.
*/
    12  
    13  package grpc_test
    14  
    15  import (
    16  	"errors"
    17  	"fmt"
    18  	"io/ioutil"
    19  	"log"
    20  	"net"
    21  	"testing"
    22  	"time"
    23  
    24  	"gitee.com/ks-custle/core-gm/gmtls"
    25  	"gitee.com/ks-custle/core-gm/grpc"
    26  	"gitee.com/ks-custle/core-gm/grpc/credentials"
    27  	"gitee.com/ks-custle/core-gm/grpc/grpc_test/echo"
    28  	"gitee.com/ks-custle/core-gm/net/context"
    29  	"gitee.com/ks-custle/core-gm/x509"
    30  )
    31  
// Network endpoints and testdata file paths shared by the test server and the
// test clients in this file.
//
//goland:noinspection GoSnakeCaseUsage
const (
	// port is the listen address handed to net.Listen by serverRun.
	port    = ":50051"
	// address is the target the clients dial; it points at the same port.
	address = "localhost:50051"

	// sm2 certificate set: the CA certificate, the key pair loaded by the
	// server (sign) and the key pair loaded by the client (user).
	sm2_ca       = "testdata/sm2_ca.cert"
	sm2_signCert = "testdata/sm2_sign.cert"
	sm2_signKey  = "testdata/sm2_sign.key"
	sm2_userCert = "testdata/sm2_user.cert"
	sm2_userKey  = "testdata/sm2_user.key"

	// ecdsa certificate set, same layout as the sm2 set.
	ecdsa_ca       = "testdata/ecdsa_ca.cert"
	ecdsa_signCert = "testdata/ecdsa_sign.cert"
	ecdsa_signKey  = "testdata/ecdsa_sign.key"
	ecdsa_userCert = "testdata/ecdsa_user.cert"
	ecdsa_userKey  = "testdata/ecdsa_user.key"

	// ecdsaext certificate set, same layout as the sm2 set.
	ecdsaext_ca       = "testdata/ecdsaext_ca.cert"
	ecdsaext_signCert = "testdata/ecdsaext_sign.cert"
	ecdsaext_signKey  = "testdata/ecdsaext_sign.key"
	ecdsaext_userCert = "testdata/ecdsaext_user.cert"
	ecdsaext_userKey  = "testdata/ecdsaext_user.key"
)
    55  
    56  func TestMain(m *testing.M) {
    57  	go serverRun()
    58  	time.Sleep(1000000)
    59  	m.Run()
    60  }
    61  
// end is the completion signal between a test and its client goroutine: each
// test replaces it with a fresh channel, and clientRun sends on it after the
// echo call has returned.
var end chan bool
    63  
    64  func Test_credentials_sm2(t *testing.T) {
    65  	end = make(chan bool, 64)
    66  	go clientRun("sm2")
    67  	<-end
    68  }
    69  
    70  func Test_credentials_ecdsa(t *testing.T) {
    71  	end = make(chan bool, 64)
    72  	go clientRun("ecdsa")
    73  	<-end
    74  }
    75  
    76  func Test_credentials_ecdsaext(t *testing.T) {
    77  	end = make(chan bool, 64)
    78  	go clientRun("ecdsaext")
    79  	<-end
    80  }
    81  
    82  func serverRun() {
    83  	// 准备3份服务端证书, 分别是sm2, ecdsa, ecdsaext
    84  	var certs []gmtls.Certificate
    85  	sm2SignCert, err := gmtls.LoadX509KeyPair(sm2_signCert, sm2_signKey)
    86  	if err != nil {
    87  		log.Fatal(err)
    88  	}
    89  	certs = append(certs, sm2SignCert)
    90  	ecdsaSignCert, err := gmtls.LoadX509KeyPair(ecdsa_signCert, ecdsa_signKey)
    91  	if err != nil {
    92  		log.Fatal(err)
    93  	}
    94  	certs = append(certs, ecdsaSignCert)
    95  	ecdsaextSignCert, err := gmtls.LoadX509KeyPair(ecdsaext_signCert, ecdsaext_signKey)
    96  	if err != nil {
    97  		log.Fatal(err)
    98  	}
    99  	certs = append(certs, ecdsaextSignCert)
   100  
   101  	// 准备CA证书池,导入颁发客户端证书的CA证书
   102  	certPool := x509.NewCertPool()
   103  	sm2CaCert, err := ioutil.ReadFile(sm2_ca)
   104  	if err != nil {
   105  		log.Fatal(err)
   106  	}
   107  	certPool.AppendCertsFromPEM(sm2CaCert)
   108  	ecdsaCaCert, err := ioutil.ReadFile(ecdsa_ca)
   109  	if err != nil {
   110  		log.Fatal(err)
   111  	}
   112  	certPool.AppendCertsFromPEM(ecdsaCaCert)
   113  	ecdsaextCaCert, err := ioutil.ReadFile(ecdsaext_ca)
   114  	if err != nil {
   115  		log.Fatal(err)
   116  	}
   117  	certPool.AppendCertsFromPEM(ecdsaextCaCert)
   118  
   119  	// 创建gmtls配置
   120  	config := &gmtls.Config{
   121  		Certificates: certs,
   122  		ClientAuth:   gmtls.RequireAndVerifyClientCert,
   123  		ClientCAs:    certPool,
   124  	}
   125  
   126  	// 创建grpc服务端
   127  	creds := credentials.NewTLS(config)
   128  	s := grpc.NewServer(grpc.Creds(creds))
   129  	echo.RegisterEchoServer(s, &server{})
   130  
   131  	// 开启tcp监听端口
   132  	lis, err := net.Listen("tcp", port)
   133  	if err != nil {
   134  		log.Fatalf("fail to listen: %v", err)
   135  	}
   136  	// 启动grpc服务
   137  	err = s.Serve(lis)
   138  	if err != nil {
   139  		log.Fatalf("Serve: %v", err)
   140  	}
   141  }
   142  
   143  func clientRun(certType string) {
   144  	// 创建客户端本地的证书池
   145  	caPool := x509.NewCertPool()
   146  	// ca证书
   147  	var cacert []byte
   148  	// 客户端证书
   149  	var cert gmtls.Certificate
   150  	// 客户端优先曲线列表
   151  	var curvePreference []gmtls.CurveID
   152  	// 客户端优先密码套件列表
   153  	var cipherSuitesPrefer []uint16
   154  	// 客户端优先签名算法
   155  	var sigAlgPrefer []gmtls.SignatureScheme
   156  	var err error
   157  	switch certType {
   158  	case "sm2":
   159  		// 读取sm2 ca证书
   160  		cacert, err = ioutil.ReadFile(sm2_ca)
   161  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   162  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   163  		// 此时该证书应该由第三方ca机构颁发签名。
   164  		cert, err = gmtls.LoadX509KeyPair(sm2_userCert, sm2_userKey)
   165  		curvePreference = append(curvePreference, gmtls.Curve256Sm2)
   166  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_SM4_GCM_SM3)
   167  	case "ecdsa":
   168  		// 读取ecdsa ca证书
   169  		cacert, err = ioutil.ReadFile(ecdsa_ca)
   170  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   171  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   172  		// 此时该证书应该由第三方ca机构颁发签名。
   173  		cert, err = gmtls.LoadX509KeyPair(ecdsa_userCert, ecdsa_userKey)
   174  		curvePreference = append(curvePreference, gmtls.CurveP256)
   175  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_AES_128_GCM_SHA256)
   176  		sigAlgPrefer = append(sigAlgPrefer, gmtls.ECDSAWithP256AndSHA256)
   177  	case "ecdsaext":
   178  		// 读取ecdsaext ca证书
   179  		cacert, err = ioutil.ReadFile(ecdsaext_ca)
   180  		// 读取User证书与私钥,作为客户端的证书与私钥,一般用作密钥交换证书。
   181  		// 但如果服务端要求查看客户端证书(双向tls通信)则也作为客户端身份验证用证书,
   182  		// 此时该证书应该由第三方ca机构颁发签名。
   183  		cert, err = gmtls.LoadX509KeyPair(ecdsaext_userCert, ecdsaext_userKey)
   184  		curvePreference = append(curvePreference, gmtls.CurveP256)
   185  		cipherSuitesPrefer = append(cipherSuitesPrefer, gmtls.TLS_AES_128_GCM_SHA256)
   186  		sigAlgPrefer = append(sigAlgPrefer, gmtls.ECDSAEXTWithP256AndSHA256)
   187  	default:
   188  		err = errors.New("目前只支持sm2/ecdsa/ecdsaext")
   189  	}
   190  	if err != nil {
   191  		log.Fatal(err)
   192  	}
   193  	// 将ca证书作为根证书加入证书池
   194  	// 即,客户端相信持有该ca颁发的证书的服务端
   195  	caPool.AppendCertsFromPEM(cacert)
   196  
   197  	// 定义gmtls配置
   198  	config := &gmtls.Config{
   199  		RootCAs:      caPool,
   200  		Certificates: []gmtls.Certificate{cert},
   201  		// 因为相关证书是由`x509/x509_test.go`的`TestCreateCertFromCA`生成的,
   202  		// 指定了SAN包含"server.test.com"
   203  		ServerName:         "server.test.com",
   204  		CurvePreferences:   curvePreference,
   205  		PreferCipherSuites: cipherSuitesPrefer,
   206  		SignAlgPrefer:      sigAlgPrefer,
   207  		ClientAuth:         gmtls.RequireAndVerifyClientCert,
   208  	}
   209  	creds := credentials.NewTLS(config)
   210  	conn, err := grpc.Dial(address, grpc.WithTransportCredentials(creds))
   211  	if err != nil {
   212  		log.Fatalf("cannot to connect: %v", err)
   213  	}
   214  	defer func(conn *grpc.ClientConn) {
   215  		_ = conn.Close()
   216  	}(conn)
   217  	c := echo.NewEchoClient(conn)
   218  	echoInClient(c)
   219  	end <- true
   220  }
   221  
   222  // 客户端echo处理
   223  func echoInClient(c echo.EchoClient) {
   224  	msgClient := "hello, this is client."
   225  	fmt.Printf("客户端发出消息: %s\n", msgClient)
   226  	r, err := c.Echo(context.Background(), &echo.EchoRequest{Req: msgClient})
   227  	if err != nil {
   228  		log.Fatalf("failed to echo: %v", err)
   229  	}
   230  	msgServer := r.Result
   231  	fmt.Printf("客户端收到消息: %s\n", msgServer)
   232  }
   233  
// server is the stateless echo service implementation that serverRun
// registers with the gRPC server.
type server struct{}
   235  
   236  // Echo 服务端echo处理
   237  //
   238  //goland:noinspection GoUnusedParameter
   239  func (s *server) Echo(ctx context.Context, req *echo.EchoRequest) (*echo.EchoResponse, error) {
   240  	msgClient := req.Req
   241  	fmt.Printf("服务端接收到消息: %s\n", msgClient)
   242  	msgServer := "hello,this is server."
   243  	fmt.Printf("服务端返回消息: %s\n", msgServer)
   244  	return &echo.EchoResponse{Result: msgServer}, nil
   245  }