gitee.com/lh-her-team/common@v1.5.1/crypto/pkcs11/pkcs11.go (about)

     1  package pkcs11
     2  
     3  import (
     4  	"encoding/hex"
     5  	"fmt"
     6  	"log"
     7  	"os"
     8  	"strconv"
     9  	"time"
    10  
    11  	"github.com/miekg/pkcs11"
    12  	"github.com/pkg/errors"
    13  )
    14  
    15  const (
    16  	defaultSessionSize = 10
    17  )
    18  
// P11Handle bundles a loaded PKCS#11 library context, the slot whose token it
// works with, and a pool of idle sessions that are handed out by getSession
// and given back through returnSession. Construct it with New.
type P11Handle struct {
	// ctx is the library context every PKCS#11 call on this handle goes through.
	ctx              *pkcs11.Ctx
	// sessions holds idle sessions available for reuse; New creates it as a
	// buffered channel with capacity sessionCacheSize.
	sessions         chan pkcs11.SessionHandle
	// slot is the slot ID of the token located by New; getSession opens new
	// sessions on it.
	slot             uint
	// sessionCacheSize is the capacity of the sessions pool.
	sessionCacheSize int
	// hash is stored as passed to New and is not read in this file.
	// NOTE(review): presumably a hash-algorithm name consumed elsewhere in
	// the package — confirm.
	hash             string
	// pin is the user PIN that getSession passes to Login for newly opened
	// sessions.
	pin              string
}
    27  
    28  func New(lib string, label string, password string, sessionCacheSize int, hash string) (*P11Handle, error) {
    29  	ctx := pkcs11.New(lib)
    30  	if ctx == nil {
    31  		libEnv := os.Getenv("HSM_LIB")
    32  		log.Printf("lib[%s] invalid, use HSM_LIB[%s] from env\n", lib, libEnv)
    33  		ctx = pkcs11.New(libEnv)
    34  		if ctx == nil {
    35  			return nil, fmt.Errorf("[PKCS11] error: fail to initialize [%s]", libEnv)
    36  		}
    37  	}
    38  	if sessionCacheSize <= 0 {
    39  		sessionSizeStr := os.Getenv("HSM_SESSION_CACHE_SIZE")
    40  		sessionSize, err := strconv.Atoi(sessionSizeStr)
    41  		if err == nil && sessionSize > 0 {
    42  			log.Printf("sessionCacheSize[%d] invalid, use HSM_SESSION_CACHE_SIZE[%s] from env\n",
    43  				sessionCacheSize, sessionSizeStr)
    44  			sessionCacheSize = sessionSize
    45  		} else {
    46  			log.Printf("sessionCacheSize[%d] and HSM_SESSION_CACHE_SIZE[%s] invalid, use default size[%d]\n",
    47  				sessionCacheSize, sessionSizeStr, defaultSessionSize)
    48  			sessionCacheSize = defaultSessionSize
    49  		}
    50  	}
    51  	err := ctx.Initialize()
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	slots, err := ctx.GetSlotList(true)
    56  	if err != nil {
    57  		return nil, fmt.Errorf("PKCS11 error: fail to get slot list [%v]", err)
    58  	}
    59  	found := false
    60  	var slot uint
    61  	slot, found = findSlot(ctx, slots, label)
    62  	if !found {
    63  		labelEnv := os.Getenv("HSM_LABEL")
    64  		log.Printf("label[%s] invalid, use HSM_LABEL[%s] from env\n", label, labelEnv)
    65  		slot, found = findSlot(ctx, slots, labelEnv)
    66  		if !found {
    67  			return nil, fmt.Errorf("PKCS11 error: fail to find token with label[%s] or HSM_LABEL[%s]", label, labelEnv)
    68  		}
    69  	}
    70  	var session pkcs11.SessionHandle
    71  	for i := 0; i < 3; i++ {
    72  		session, err = ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
    73  		if err == nil {
    74  			break
    75  		}
    76  		time.Sleep(time.Millisecond * 100)
    77  	}
    78  	if err != nil {
    79  		return nil, fmt.Errorf("PKCS11 error: fail to open session [%v]", err)
    80  	}
    81  	err = ctx.Login(session, pkcs11.CKU_USER, password)
    82  	if err != nil {
    83  		passEnv := os.Getenv("HSM_PASSWORD")
    84  		log.Printf("password[%s] invalid, use HSM_PASSWORD[%s] from env\n",
    85  			hex.EncodeToString([]byte(password)), hex.EncodeToString([]byte(passEnv)))
    86  		err = ctx.Login(session, pkcs11.CKU_USER, passEnv)
    87  		if err != nil {
    88  			return nil, fmt.Errorf("PKCS11 error: fail to login session [%v]", err)
    89  		}
    90  	}
    91  	sessions := make(chan pkcs11.SessionHandle, sessionCacheSize)
    92  	p11Handle := &P11Handle{
    93  		ctx:              ctx,
    94  		sessions:         sessions,
    95  		slot:             slot,
    96  		sessionCacheSize: sessionCacheSize,
    97  		hash:             hash,
    98  		pin:              password,
    99  	}
   100  	p11Handle.returnSession(nil, session)
   101  	return p11Handle, nil
   102  }
   103  
   104  func (p11 *P11Handle) getSession() (pkcs11.SessionHandle, error) {
   105  	var session pkcs11.SessionHandle
   106  	select {
   107  	case session = <-p11.sessions:
   108  		return session, nil
   109  	default:
   110  		var err error
   111  		for i := 0; i < 3; i++ {
   112  			session, err = p11.ctx.OpenSession(p11.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
   113  			if err == nil {
   114  				break
   115  			}
   116  			time.Sleep(time.Millisecond * 100)
   117  		}
   118  		if err != nil {
   119  			return 0, errors.WithMessage(err, "fail to open session after 3 times attempt")
   120  		}
   121  		err = p11.ctx.Login(session, pkcs11.CKU_USER, p11.pin)
   122  		if err != nil && err != pkcs11.Error(pkcs11.CKR_USER_ALREADY_LOGGED_IN) {
   123  			_ = p11.ctx.CloseSession(session)
   124  			return 0, errors.WithMessage(err, "login failed")
   125  		}
   126  		return session, nil
   127  	}
   128  }
   129  
   130  func (p11 *P11Handle) returnSession(err error, session pkcs11.SessionHandle) {
   131  	if err != nil {
   132  		log.Printf("PKCS11 session invalidated, closing session: %v", err)
   133  		_ = p11.ctx.CloseSession(session)
   134  		return
   135  	}
   136  	select {
   137  	case p11.sessions <- session:
   138  		return
   139  	default:
   140  		_ = p11.ctx.CloseSession(session)
   141  		return
   142  	}
   143  }
   144  
   145  func findSlot(ctx *pkcs11.Ctx, slots []uint, label string) (uint, bool) {
   146  	var slot uint
   147  	var found bool
   148  	for _, s := range slots {
   149  		info, err := ctx.GetTokenInfo(s)
   150  		if err != nil {
   151  			continue
   152  		}
   153  		if info.Label == label {
   154  			found = true
   155  			slot = s
   156  			break
   157  		}
   158  	}
   159  	return slot, found
   160  }
   161  
   162  func listSlot(ctx *pkcs11.Ctx) (map[string]string, error) {
   163  	slots, err := ctx.GetSlotList(true)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  	res := make(map[string]string)
   168  	for i, s := range slots {
   169  		info, err := ctx.GetTokenInfo(s)
   170  		if err != nil {
   171  			return nil, err
   172  		}
   173  		res[fmt.Sprintf("%d", i)] = info.Label
   174  	}
   175  	return res, nil
   176  }