github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/core/access_contoller/crypto/pkcs11/pkcs11.go (about)

     1  /*
     2  Copyright (C) BABEC. All rights reserved.
     3  Copyright (C) THL A29 Limited, a Tencent company. All rights reserved.
     4  
     5  SPDX-License-Identifier: Apache-2.0
     6  */
     7  
     8  package pkcs11
     9  
    10  import (
    11  	"encoding/hex"
    12  	"fmt"
    13  	"log"
    14  	"os"
    15  	"strconv"
    16  	"time"
    17  
    18  	"github.com/miekg/pkcs11"
    19  	"github.com/pkg/errors"
    20  )
    21  
    22  const (
    23  	defaultSessionSize = 10
    24  )
    25  
    26  type P11Handle struct {
    27  	ctx              *pkcs11.Ctx
    28  	sessions         chan pkcs11.SessionHandle
    29  	slot             uint
    30  	sessionCacheSize int
    31  	hash             string
    32  
    33  	pin string
    34  }
    35  
    36  func New(lib string, label string, password string, sessionCacheSize int, hash string) (*P11Handle, error) {
    37  	ctx := pkcs11.New(lib)
    38  	if ctx == nil {
    39  		libEnv := os.Getenv("HSM_LIB")
    40  		log.Printf("lib[%s] invalid, use HSM_LIB[%s] from env\n", lib, libEnv)
    41  		ctx = pkcs11.New(libEnv)
    42  		if ctx == nil {
    43  			return nil, fmt.Errorf("[PKCS11] error: fail to initialize [%s]", libEnv)
    44  		}
    45  	}
    46  
    47  	if sessionCacheSize <= 0 {
    48  		sessionSizeStr := os.Getenv("HSM_SESSION_CACHE_SIZE")
    49  		sessionSize, err := strconv.Atoi(sessionSizeStr)
    50  		if err == nil && sessionSize > 0 {
    51  			log.Printf("sessionCacheSize[%d] invalid, use HSM_SESSION_CACHE_SIZE[%s] from env\n",
    52  				sessionCacheSize, sessionSizeStr)
    53  			sessionCacheSize = sessionSize
    54  		} else {
    55  			log.Printf("sessionCacheSize[%d] and HSM_SESSION_CACHE_SIZE[%s] invalid, use default size[%d]\n",
    56  				sessionCacheSize, sessionSizeStr, defaultSessionSize)
    57  			sessionCacheSize = defaultSessionSize
    58  		}
    59  	}
    60  
    61  	err := ctx.Initialize()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	slots, err := ctx.GetSlotList(true)
    67  	if err != nil {
    68  		return nil, fmt.Errorf("PKCS11 error: fail to get slot list [%v]", err)
    69  	}
    70  
    71  	found := false
    72  	var slot uint
    73  	slot, found = findSlot(ctx, slots, label)
    74  	if !found {
    75  		labelEnv := os.Getenv("HSM_LABEL")
    76  		log.Printf("label[%s] invalid, use HSM_LABEL[%s] from env\n", label, labelEnv)
    77  		slot, found = findSlot(ctx, slots, labelEnv)
    78  		if !found {
    79  			return nil, fmt.Errorf("PKCS11 error: fail to find token with label[%s] or HSM_LABEL[%s]", label, labelEnv)
    80  		}
    81  	}
    82  
    83  	var session pkcs11.SessionHandle
    84  	for i := 0; i < 3; i++ {
    85  		session, err = ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
    86  		if err == nil {
    87  			break
    88  		}
    89  		time.Sleep(time.Millisecond * 100)
    90  	}
    91  	if err != nil {
    92  		return nil, fmt.Errorf("PKCS11 error: fail to open session [%v]", err)
    93  	}
    94  
    95  	err = ctx.Login(session, pkcs11.CKU_USER, password)
    96  	if err != nil {
    97  		passEnv := os.Getenv("HSM_PASSWORD")
    98  		log.Printf("password[%s] invalid, use HSM_PASSWORD[%s] from env\n",
    99  			hex.EncodeToString([]byte(password)), hex.EncodeToString([]byte(passEnv)))
   100  		err = ctx.Login(session, pkcs11.CKU_USER, passEnv)
   101  		if err != nil {
   102  			return nil, fmt.Errorf("PKCS11 error: fail to login session [%v]", err)
   103  		}
   104  	}
   105  
   106  	sessions := make(chan pkcs11.SessionHandle, sessionCacheSize)
   107  	p11Handle := &P11Handle{
   108  		ctx:              ctx,
   109  		sessions:         sessions,
   110  		slot:             slot,
   111  		sessionCacheSize: sessionCacheSize,
   112  		hash:             hash,
   113  		pin:              password,
   114  	}
   115  	p11Handle.returnSession(nil, session)
   116  
   117  	return p11Handle, nil
   118  }
   119  
   120  func (p11 *P11Handle) getSession() (pkcs11.SessionHandle, error) {
   121  	var session pkcs11.SessionHandle
   122  	select {
   123  	case session = <-p11.sessions:
   124  		return session, nil
   125  	default:
   126  		var err error
   127  		for i := 0; i < 3; i++ {
   128  			session, err = p11.ctx.OpenSession(p11.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
   129  			if err == nil {
   130  				break
   131  			}
   132  			time.Sleep(time.Millisecond * 100)
   133  		}
   134  		if err != nil {
   135  			return 0, errors.WithMessage(err, "fail to open session after 3 times attempt")
   136  		}
   137  
   138  		err = p11.ctx.Login(session, pkcs11.CKU_USER, p11.pin)
   139  		if err != nil && err != pkcs11.Error(pkcs11.CKR_USER_ALREADY_LOGGED_IN) {
   140  			_ = p11.ctx.CloseSession(session)
   141  			return 0, errors.WithMessage(err, "login failed")
   142  		}
   143  		return session, nil
   144  	}
   145  }
   146  
   147  func (p11 *P11Handle) returnSession(err error, session pkcs11.SessionHandle) {
   148  	if err != nil {
   149  		log.Printf("PKCS11 session invalidated, closing session: %v", err)
   150  		_ = p11.ctx.CloseSession(session)
   151  		return
   152  	}
   153  	select {
   154  	case p11.sessions <- session:
   155  		return
   156  	default:
   157  		_ = p11.ctx.CloseSession(session)
   158  		return
   159  	}
   160  }
   161  
   162  func findSlot(ctx *pkcs11.Ctx, slots []uint, label string) (uint, bool) {
   163  	var slot uint
   164  	var found bool
   165  	for _, s := range slots {
   166  		info, err := ctx.GetTokenInfo(s)
   167  		if err != nil {
   168  			continue
   169  		}
   170  		if info.Label == label {
   171  			found = true
   172  			slot = s
   173  			break
   174  		}
   175  	}
   176  	return slot, found
   177  }
   178  
   179  func listSlot(ctx *pkcs11.Ctx) (map[string]string, error) {
   180  	slots, err := ctx.GetSlotList(true)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	res := make(map[string]string)
   186  	for i, s := range slots {
   187  		info, err := ctx.GetTokenInfo(s)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  		res[fmt.Sprintf("%d", i)] = info.Label
   192  	}
   193  	return res, nil
   194  }