gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/sm2/util.go (about) 1 // Copyright (c) 2022 zhaochun 2 // core-gm is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 package sm2 10 11 import ( 12 "crypto/ecdsa" 13 "crypto/elliptic" 14 "encoding/hex" 15 "errors" 16 "fmt" 17 "gitee.com/ks-custle/core-gm/utils" 18 "io" 19 "math/big" 20 "strings" 21 "sync" 22 ) 23 24 var zero = big.NewInt(0) 25 26 // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0 27 func toBytes(curve elliptic.Curve, value *big.Int) []byte { 28 // 大整数的字节数组 29 bytes := value.Bytes() 30 // 需要的长度: (256 + 7) / 8 = 32 31 byteLen := (curve.Params().BitSize + 7) >> 3 32 if byteLen == len(bytes) { 33 return bytes 34 } 35 // 左侧补0 36 result := make([]byte, byteLen) 37 copy(result[byteLen-len(bytes):], bytes) 38 return result 39 } 40 41 // 将曲线上的点座标(x,y)转为未压缩字节数组 42 // 43 // 参考: GB/T 32918.1-2016 4.2.9 44 func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 45 return elliptic.Marshal(curve, x, y) 46 } 47 48 // 将曲线上的点座标(x,y)转为压缩字节数组 49 // 50 // 返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数 51 // 参考: GB/T 32918.1-2016 4.2.9 52 func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 53 // buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33 54 buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) 55 // 将x的字节数组填入右侧32个字节 56 copy(buffer[1:], toBytes(curve, x)) 57 // 首位字节是C1压缩标识 58 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 59 // 又因为p是奇素数,所以两个y必然一奇一偶 60 if getLastBitOfY(x, y) > 0 { 61 // y最右侧一位为1,即奇数,压缩标识为 3 62 buffer[0] = compressed03 63 } else { 64 // y最右侧一位为0,即偶数,压缩标识为 2 65 buffer[0] = compressed02 66 } 67 return buffer 68 } 69 70 // 将曲线上的点座标(x,y)转为混合字节数组 71 // 72 // 参考: GB/T 32918.1-2016 4.2.9 73 func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 74 // buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个) 75 buffer := elliptic.Marshal(curve, x, y) 76 // 修改首位的压缩标识 77 // TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用? 78 if getLastBitOfY(x, y) > 0 { 79 // y最右侧一位为1,即奇数,压缩标识为 7 80 buffer[0] = mixed07 81 } else { 82 // y最右侧一位为0,即偶数,压缩标识为 6 83 buffer[0] = mixed06 84 } 85 return buffer 86 } 87 88 // 获取y最后一位的值 89 // 90 // x坐标为0时,直接返回0 91 // 参考: GB/T 32918.1-2016 A.5.2 92 func getLastBitOfY(x, y *big.Int) uint { 93 // x坐标为0时,直接返回0 94 if x.Cmp(zero) == 0 { 95 return 0 96 } 97 // 返回y最右侧一位的值 98 return y.Bit(0) 99 } 100 101 func toPointXY(bytes []byte) *big.Int { 102 return new(big.Int).SetBytes(bytes) 103 } 104 105 // 根据x坐标计算y坐标 106 // 107 // 参考: GB/T 32918.1-2016 A.5.2 B.1.4 108 func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) { 109 // x3 : x^3 110 x3 := new(big.Int).Mul(x, x) 111 x3.Mul(x3, x) 112 // threeX : 3x 113 threeX := new(big.Int).Lsh(x, 1) // x*2 114 threeX.Add(threeX, x) // x*2 + x = 3x 115 116 x3.Sub(x3, threeX) // x^3 - 3x 117 x3.Add(x3, curve.Params().B) // x^3 - 3x + b 118 x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p 119 // y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p 120 // 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p) 121 // 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度? 122 y := x3.ModSqrt(x3, curve.Params().P) 123 124 if y == nil { 125 return nil, errors.New("can't calculate y based on x") 126 } 127 return y, nil 128 } 129 130 // 字节数组转为曲线上的点坐标 131 // 132 // 返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33) 133 // 参考: GB/T 32918.1-2016 4.2.10 A.5.2 134 func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { 135 if len(bytes) < 1+(curve.Params().BitSize/8) { 136 return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) 137 } 138 // 获取压缩标识 139 format := bytes[0] 140 byteLen := (curve.Params().BitSize + 7) >> 3 141 switch format { 142 case uncompressed, mixed06, mixed07: // what's the mixed format purpose? 143 // 未压缩,或混合模式下,直接将x,y分别取出转换 144 if len(bytes) < 1+byteLen*2 { 145 return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) 146 } 147 x := toPointXY(bytes[1 : 1+byteLen]) 148 y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) 149 if !curve.IsOnCurve(x, y) { 150 return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) 151 } 152 return x, y, 1 + byteLen*2, nil 153 case compressed02, compressed03: 154 // 压缩模式下 155 if len(bytes) < 1+byteLen { 156 return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) 157 } 158 if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { 159 // y² = x³ - 3x + b, prime curves 160 x := toPointXY(bytes[1 : 1+byteLen]) 161 // 根据x推算y数值 162 y, err := calculatePrimeCurveY(curve, x) 163 if err != nil { 164 return nil, nil, 0, err 165 } 166 // 计算出的y的值与压缩标识冲突的话,则 y = p - y 167 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 168 // 又因为p是奇素数,所以两个y必然一奇一偶 169 if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) { 170 y.Sub(curve.Params().P, y) 171 } 172 return x, y, 1 + byteLen, nil 173 } 174 return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) 175 } 176 return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) 177 } 178 179 var ( 180 closedChanOnce sync.Once 181 closedChan chan struct{} 182 ) 183 184 // maybeReadByte reads a single byte from r with ~50% probability. This is used 185 // to ensure that callers do not depend on non-guaranteed behaviour, e.g. 186 // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream. 187 // 188 // This does not affect tests that pass a stream of fixed bytes as the random 189 // source (e.g. a zeroReader). 190 func maybeReadByte(r io.Reader) { 191 closedChanOnce.Do(func() { 192 closedChan = make(chan struct{}) 193 close(closedChan) 194 }) 195 196 select { 197 case <-closedChan: 198 return 199 case <-closedChan: 200 var buf [1]byte 201 _, err := r.Read(buf[:]) 202 if err != nil { 203 panic(err) 204 } 205 } 206 } 207 208 //goland:noinspection GoUnusedExportedFunction 209 func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) { 210 ecPriv := &ecdsa.PrivateKey{} 211 ecPriv.Curve = sm2Priv.Curve 212 ecPriv.D = sm2Priv.D 213 ecPriv.X = sm2Priv.X 214 ecPriv.Y = sm2Priv.Y 215 return ecPriv, nil 216 } 217 218 //goland:noinspection GoUnusedExportedFunction 219 func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) { 220 sm2Priv := &PrivateKey{} 221 sm2Priv.Curve = ecPriv.Curve 222 if sm2Priv.Curve != P256Sm2() { 223 return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换") 224 } 225 sm2Priv.D = ecPriv.D 226 sm2Priv.X = ecPriv.X 227 sm2Priv.Y = ecPriv.Y 228 return sm2Priv, nil 229 } 230 231 // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ 232 // SM2公私钥与hex相互转换 233 234 // ReadSm2PrivFromHex 将hex字符串转为sm2私钥 235 // 236 // @param Dhex 16进制字符串,对应sm2.PrivateKey.D 237 // @return *PrivateKey sm2私钥 238 // @return error 239 func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) { 240 c := P256Sm2() 241 d, err := hex.DecodeString(Dhex) 242 if err != nil { 243 return nil, err 244 } 245 k := new(big.Int).SetBytes(d) 246 params := c.Params() 247 one := new(big.Int).SetInt64(1) 248 n := new(big.Int).Sub(params.N, one) 249 if k.Cmp(n) >= 0 { 250 return nil, errors.New("privateKey's D is overflow") 251 } 252 priv := new(PrivateKey) 253 priv.PublicKey.Curve = c 254 priv.D = k 255 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 256 return priv, nil 257 } 258 259 // WriteSm2PrivToHex 将sm2私钥D转为hex字符串 260 // 261 // @param key sm2私钥 262 // @return string 263 func WriteSm2PrivToHex(key *PrivateKey) string { 264 return key.D.Text(16) 265 } 266 267 // ReadSm2PubFromHex 将hex字符串转为sm2公钥 268 // 269 // @param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串 270 // @return *PublicKey sm2公钥 271 // @return error 272 func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) { 273 q, err := hex.DecodeString(Qhex) 274 if err != nil { 275 return nil, err 276 } 277 if len(q) == 65 && q[0] == byte(0x04) { 278 q = q[1:] 279 } 280 if len(q) != 64 { 281 return nil, errors.New("publicKey is not uncompressed") 282 } 283 pub := new(PublicKey) 284 pub.Curve = P256Sm2() 285 pub.X = new(big.Int).SetBytes(q[:32]) 286 pub.Y = new(big.Int).SetBytes(q[32:]) 287 return pub, nil 288 } 289 290 // WriteSm2PubToHex 将sm2公钥转为hex字符串 291 // 292 // @param key sm2公钥 293 // @return string 294 func WriteSm2PubToHex(key *PublicKey) string { 295 x := key.X.Bytes() 296 y := key.Y.Bytes() 297 if n := len(x); n < 32 { 298 x = append(utils.ZeroByteSlice()[:32-n], x...) 299 } 300 if n := len(y); n < 32 { 301 y = append(utils.ZeroByteSlice()[:32-n], y...) 302 } 303 var c []byte 304 c = append(c, x...) 305 c = append(c, y...) 306 c = append([]byte{0x04}, c...) 307 return hex.EncodeToString(c) 308 } 309 310 // SM2公私钥与hex相互转换 311 // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑