gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/sm2/util.go (about) 1 // Copyright (c) 2022 zhaochun 2 // gmgo is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 package sm2 10 11 import ( 12 "crypto/ecdsa" 13 "crypto/elliptic" 14 "encoding/hex" 15 "errors" 16 "fmt" 17 "gitee.com/zhaochuninhefei/gmgo/utils" 18 "io" 19 "math/big" 20 "strings" 21 "sync" 22 ) 23 24 var zero = big.NewInt(0) 25 26 // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0 27 func toBytes(curve elliptic.Curve, value *big.Int) []byte { 28 // 大整数的字节数组 29 bytes := value.Bytes() 30 // 需要的长度: (256 + 7) / 8 = 32 31 byteLen := (curve.Params().BitSize + 7) >> 3 32 if byteLen == len(bytes) { 33 return bytes 34 } 35 // 左侧补0 36 result := make([]byte, byteLen) 37 copy(result[byteLen-len(bytes):], bytes) 38 return result 39 } 40 41 // 将曲线上的点座标(x,y)转为未压缩字节数组 42 // 参考: GB/T 32918.1-2016 4.2.9 43 func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 44 return elliptic.Marshal(curve, x, y) 45 } 46 47 // 将曲线上的点座标(x,y)转为压缩字节数组 48 // 返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数 49 // 参考: GB/T 32918.1-2016 4.2.9 50 func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 51 // buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33 52 buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) 53 // 将x的字节数组填入右侧32个字节 54 copy(buffer[1:], toBytes(curve, x)) 55 // 首位字节是C1压缩标识 56 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 57 // 又因为p是奇素数,所以两个y必然一奇一偶 58 if getLastBitOfY(x, y) > 0 { 59 // y最右侧一位为1,即奇数,压缩标识为 3 60 buffer[0] = compressed03 61 } else { 62 // y最右侧一位为0,即偶数,压缩标识为 2 63 buffer[0] = compressed02 64 } 65 return buffer 66 } 67 68 // 将曲线上的点座标(x,y)转为混合字节数组 69 // 参考: GB/T 32918.1-2016 4.2.9 70 func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 71 // buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个) 72 buffer := elliptic.Marshal(curve, x, y) 73 // 修改首位的压缩标识 74 // TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用? 75 if getLastBitOfY(x, y) > 0 { 76 // y最右侧一位为1,即奇数,压缩标识为 7 77 buffer[0] = mixed07 78 } else { 79 // y最右侧一位为0,即偶数,压缩标识为 6 80 buffer[0] = mixed06 81 } 82 return buffer 83 } 84 85 // 获取y最后一位的值 86 // x坐标为0时,直接返回0 87 // 参考: GB/T 32918.1-2016 A.5.2 88 func getLastBitOfY(x, y *big.Int) uint { 89 // x坐标为0时,直接返回0 90 if x.Cmp(zero) == 0 { 91 return 0 92 } 93 // 返回y最右侧一位的值 94 return y.Bit(0) 95 } 96 97 func toPointXY(bytes []byte) *big.Int { 98 return new(big.Int).SetBytes(bytes) 99 } 100 101 // 根据x坐标计算y坐标 102 // 参考: GB/T 32918.1-2016 A.5.2 B.1.4 103 func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) { 104 // x3 : x^3 105 x3 := new(big.Int).Mul(x, x) 106 x3.Mul(x3, x) 107 // threeX : 3x 108 threeX := new(big.Int).Lsh(x, 1) // x*2 109 threeX.Add(threeX, x) // x*2 + x = 3x 110 111 x3.Sub(x3, threeX) // x^3 - 3x 112 x3.Add(x3, curve.Params().B) // x^3 - 3x + b 113 x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p 114 // y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p 115 // 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p) 116 // 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度? 117 y := x3.ModSqrt(x3, curve.Params().P) 118 119 if y == nil { 120 return nil, errors.New("can't calculate y based on x") 121 } 122 return y, nil 123 } 124 125 // 字节数组转为曲线上的点坐标 126 // 返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33) 127 // 参考: GB/T 32918.1-2016 4.2.10 A.5.2 128 func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { 129 if len(bytes) < 1+(curve.Params().BitSize/8) { 130 return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) 131 } 132 // 获取压缩标识 133 format := bytes[0] 134 byteLen := (curve.Params().BitSize + 7) >> 3 135 switch format { 136 case uncompressed, mixed06, mixed07: // what's the mixed format purpose? 137 // 未压缩,或混合模式下,直接将x,y分别取出转换 138 if len(bytes) < 1+byteLen*2 { 139 return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) 140 } 141 x := toPointXY(bytes[1 : 1+byteLen]) 142 y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) 143 if !curve.IsOnCurve(x, y) { 144 return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) 145 } 146 return x, y, 1 + byteLen*2, nil 147 case compressed02, compressed03: 148 // 压缩模式下 149 if len(bytes) < 1+byteLen { 150 return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) 151 } 152 if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { 153 // y² = x³ - 3x + b, prime curves 154 x := toPointXY(bytes[1 : 1+byteLen]) 155 // 根据x推算y数值 156 y, err := calculatePrimeCurveY(curve, x) 157 if err != nil { 158 return nil, nil, 0, err 159 } 160 // 计算出的y的值与压缩标识冲突的话,则 y = p - y 161 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 162 // 又因为p是奇素数,所以两个y必然一奇一偶 163 if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) { 164 y.Sub(curve.Params().P, y) 165 } 166 return x, y, 1 + byteLen, nil 167 } 168 return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) 169 } 170 return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) 171 } 172 173 var ( 174 closedChanOnce sync.Once 175 closedChan chan struct{} 176 ) 177 178 // maybeReadByte reads a single byte from r with ~50% probability. This is used 179 // to ensure that callers do not depend on non-guaranteed behaviour, e.g. 180 // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream. 181 // 182 // This does not affect tests that pass a stream of fixed bytes as the random 183 // source (e.g. a zeroReader). 184 func maybeReadByte(r io.Reader) { 185 closedChanOnce.Do(func() { 186 closedChan = make(chan struct{}) 187 close(closedChan) 188 }) 189 190 select { 191 case <-closedChan: 192 return 193 case <-closedChan: 194 var buf [1]byte 195 _, err := r.Read(buf[:]) 196 if err != nil { 197 panic(err) 198 } 199 } 200 } 201 202 //goland:noinspection GoUnusedExportedFunction 203 func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) { 204 ecPriv := &ecdsa.PrivateKey{} 205 ecPriv.Curve = sm2Priv.Curve 206 ecPriv.D = sm2Priv.D 207 ecPriv.X = sm2Priv.X 208 ecPriv.Y = sm2Priv.Y 209 return ecPriv, nil 210 } 211 212 //goland:noinspection GoUnusedExportedFunction 213 func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) { 214 sm2Priv := &PrivateKey{} 215 sm2Priv.Curve = ecPriv.Curve 216 if sm2Priv.Curve != P256Sm2() { 217 return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换") 218 } 219 sm2Priv.D = ecPriv.D 220 sm2Priv.X = ecPriv.X 221 sm2Priv.Y = ecPriv.Y 222 return sm2Priv, nil 223 } 224 225 // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ 226 // SM2公私钥与hex相互转换 227 228 // ReadSm2PrivFromHex 将hex字符串转为sm2私钥 229 // @param Dhex 16进制字符串,对应sm2.PrivateKey.D 230 // @return *PrivateKey sm2私钥 231 // @return error 232 func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) { 233 c := P256Sm2() 234 d, err := hex.DecodeString(Dhex) 235 if err != nil { 236 return nil, err 237 } 238 k := new(big.Int).SetBytes(d) 239 params := c.Params() 240 one := new(big.Int).SetInt64(1) 241 n := new(big.Int).Sub(params.N, one) 242 if k.Cmp(n) >= 0 { 243 return nil, errors.New("privateKey's D is overflow") 244 } 245 priv := new(PrivateKey) 246 priv.PublicKey.Curve = c 247 priv.D = k 248 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 249 return priv, nil 250 } 251 252 // WriteSm2PrivToHex 将sm2私钥D转为hex字符串 253 // @param key sm2私钥 254 // @return string 255 func WriteSm2PrivToHex(key *PrivateKey) string { 256 return key.D.Text(16) 257 } 258 259 // ReadSm2PubFromHex 将hex字符串转为sm2公钥 260 // @param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串 261 // @return *PublicKey sm2公钥 262 // @return error 263 func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) { 264 q, err := hex.DecodeString(Qhex) 265 if err != nil { 266 return nil, err 267 } 268 if len(q) == 65 && q[0] == byte(0x04) { 269 q = q[1:] 270 } 271 if len(q) != 64 { 272 return nil, errors.New("publicKey is not uncompressed") 273 } 274 pub := new(PublicKey) 275 pub.Curve = P256Sm2() 276 pub.X = new(big.Int).SetBytes(q[:32]) 277 pub.Y = new(big.Int).SetBytes(q[32:]) 278 return pub, nil 279 } 280 281 // WriteSm2PubToHex 将sm2公钥转为hex字符串 282 // @param key sm2公钥 283 // @return string 284 func WriteSm2PubToHex(key *PublicKey) string { 285 x := key.X.Bytes() 286 y := key.Y.Bytes() 287 if n := len(x); n < 32 { 288 x = append(utils.ZeroByteSlice()[:32-n], x...) 289 } 290 if n := len(y); n < 32 { 291 y = append(utils.ZeroByteSlice()[:32-n], y...) 292 } 293 var c []byte 294 c = append(c, x...) 295 c = append(c, y...) 296 c = append([]byte{0x04}, c...) 297 return hex.EncodeToString(c) 298 } 299 300 // SM2公私钥与hex相互转换 301 // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑