gitee.com/lh-her-team/common@v1.5.1/crypto/pkcs11/pkcs11.go (about) 1 package pkcs11 2 3 import ( 4 "encoding/hex" 5 "fmt" 6 "log" 7 "os" 8 "strconv" 9 "time" 10 11 "github.com/miekg/pkcs11" 12 "github.com/pkg/errors" 13 ) 14 15 const ( 16 defaultSessionSize = 10 17 ) 18 19 type P11Handle struct { 20 ctx *pkcs11.Ctx 21 sessions chan pkcs11.SessionHandle 22 slot uint 23 sessionCacheSize int 24 hash string 25 pin string 26 } 27 28 func New(lib string, label string, password string, sessionCacheSize int, hash string) (*P11Handle, error) { 29 ctx := pkcs11.New(lib) 30 if ctx == nil { 31 libEnv := os.Getenv("HSM_LIB") 32 log.Printf("lib[%s] invalid, use HSM_LIB[%s] from env\n", lib, libEnv) 33 ctx = pkcs11.New(libEnv) 34 if ctx == nil { 35 return nil, fmt.Errorf("[PKCS11] error: fail to initialize [%s]", libEnv) 36 } 37 } 38 if sessionCacheSize <= 0 { 39 sessionSizeStr := os.Getenv("HSM_SESSION_CACHE_SIZE") 40 sessionSize, err := strconv.Atoi(sessionSizeStr) 41 if err == nil && sessionSize > 0 { 42 log.Printf("sessionCacheSize[%d] invalid, use HSM_SESSION_CACHE_SIZE[%s] from env\n", 43 sessionCacheSize, sessionSizeStr) 44 sessionCacheSize = sessionSize 45 } else { 46 log.Printf("sessionCacheSize[%d] and HSM_SESSION_CACHE_SIZE[%s] invalid, use default size[%d]\n", 47 sessionCacheSize, sessionSizeStr, defaultSessionSize) 48 sessionCacheSize = defaultSessionSize 49 } 50 } 51 err := ctx.Initialize() 52 if err != nil { 53 return nil, err 54 } 55 slots, err := ctx.GetSlotList(true) 56 if err != nil { 57 return nil, fmt.Errorf("PKCS11 error: fail to get slot list [%v]", err) 58 } 59 found := false 60 var slot uint 61 slot, found = findSlot(ctx, slots, label) 62 if !found { 63 labelEnv := os.Getenv("HSM_LABEL") 64 log.Printf("label[%s] invalid, use HSM_LABEL[%s] from env\n", label, labelEnv) 65 slot, found = findSlot(ctx, slots, labelEnv) 66 if !found { 67 return nil, fmt.Errorf("PKCS11 error: fail to find token with label[%s] or HSM_LABEL[%s]", label, labelEnv) 68 } 69 } 70 var session pkcs11.SessionHandle 71 for i := 0; i < 3; i++ { 72 session, err = ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) 73 if err == nil { 74 break 75 } 76 time.Sleep(time.Millisecond * 100) 77 } 78 if err != nil { 79 return nil, fmt.Errorf("PKCS11 error: fail to open session [%v]", err) 80 } 81 err = ctx.Login(session, pkcs11.CKU_USER, password) 82 if err != nil { 83 passEnv := os.Getenv("HSM_PASSWORD") 84 log.Printf("password[%s] invalid, use HSM_PASSWORD[%s] from env\n", 85 hex.EncodeToString([]byte(password)), hex.EncodeToString([]byte(passEnv))) 86 err = ctx.Login(session, pkcs11.CKU_USER, passEnv) 87 if err != nil { 88 return nil, fmt.Errorf("PKCS11 error: fail to login session [%v]", err) 89 } 90 } 91 sessions := make(chan pkcs11.SessionHandle, sessionCacheSize) 92 p11Handle := &P11Handle{ 93 ctx: ctx, 94 sessions: sessions, 95 slot: slot, 96 sessionCacheSize: sessionCacheSize, 97 hash: hash, 98 pin: password, 99 } 100 p11Handle.returnSession(nil, session) 101 return p11Handle, nil 102 } 103 104 func (p11 *P11Handle) getSession() (pkcs11.SessionHandle, error) { 105 var session pkcs11.SessionHandle 106 select { 107 case session = <-p11.sessions: 108 return session, nil 109 default: 110 var err error 111 for i := 0; i < 3; i++ { 112 session, err = p11.ctx.OpenSession(p11.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) 113 if err == nil { 114 break 115 } 116 time.Sleep(time.Millisecond * 100) 117 } 118 if err != nil { 119 return 0, errors.WithMessage(err, "fail to open session after 3 times attempt") 120 } 121 err = p11.ctx.Login(session, pkcs11.CKU_USER, p11.pin) 122 if err != nil && err != pkcs11.Error(pkcs11.CKR_USER_ALREADY_LOGGED_IN) { 123 _ = p11.ctx.CloseSession(session) 124 return 0, errors.WithMessage(err, "login failed") 125 } 126 return session, nil 127 } 128 } 129 130 func (p11 *P11Handle) returnSession(err error, session pkcs11.SessionHandle) { 131 if err != nil { 132 log.Printf("PKCS11 session invalidated, closing session: %v", err) 133 _ = p11.ctx.CloseSession(session) 134 return 135 } 136 select { 137 case p11.sessions <- session: 138 return 139 default: 140 _ = p11.ctx.CloseSession(session) 141 return 142 } 143 } 144 145 func findSlot(ctx *pkcs11.Ctx, slots []uint, label string) (uint, bool) { 146 var slot uint 147 var found bool 148 for _, s := range slots { 149 info, err := ctx.GetTokenInfo(s) 150 if err != nil { 151 continue 152 } 153 if info.Label == label { 154 found = true 155 slot = s 156 break 157 } 158 } 159 return slot, found 160 } 161 162 func listSlot(ctx *pkcs11.Ctx) (map[string]string, error) { 163 slots, err := ctx.GetSlotList(true) 164 if err != nil { 165 return nil, err 166 } 167 res := make(map[string]string) 168 for i, s := range slots { 169 info, err := ctx.GetTokenInfo(s) 170 if err != nil { 171 return nil, err 172 } 173 res[fmt.Sprintf("%d", i)] = info.Label 174 } 175 return res, nil 176 }