github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm2/util.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 package sm2 4 5 import ( 6 "crypto/ecdsa" 7 "crypto/elliptic" 8 "encoding/hex" 9 "errors" 10 "fmt" 11 "io" 12 "math/big" 13 "strings" 14 "sync" 15 16 "github.com/hxx258456/ccgo/utils" 17 ) 18 19 var zero = big.NewInt(0) 20 21 // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0 22 func toBytes(curve elliptic.Curve, value *big.Int) []byte { 23 // 大整数的字节数组 24 bytes := value.Bytes() 25 // 需要的长度: (256 + 7) / 8 = 32 26 byteLen := (curve.Params().BitSize + 7) >> 3 27 if byteLen == len(bytes) { 28 return bytes 29 } 30 // 左侧补0 31 result := make([]byte, byteLen) 32 copy(result[byteLen-len(bytes):], bytes) 33 return result 34 } 35 36 // 将曲线上的点座标(x,y)转为未压缩字节数组 37 // 38 // 参考: GB/T 32918.1-2016 4.2.9 39 func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 40 return elliptic.Marshal(curve, x, y) 41 } 42 43 // 将曲线上的点座标(x,y)转为压缩字节数组 44 // 45 // 返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数 46 // 参考: GB/T 32918.1-2016 4.2.9 47 func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 48 // buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33 49 buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) 50 // 将x的字节数组填入右侧32个字节 51 copy(buffer[1:], toBytes(curve, x)) 52 // 首位字节是C1压缩标识 53 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 54 // 又因为p是奇素数,所以两个y必然一奇一偶 55 if getLastBitOfY(x, y) > 0 { 56 // y最右侧一位为1,即奇数,压缩标识为 3 57 buffer[0] = compressed03 58 } else { 59 // y最右侧一位为0,即偶数,压缩标识为 2 60 buffer[0] = compressed02 61 } 62 return buffer 63 } 64 65 // 将曲线上的点座标(x,y)转为混合字节数组 66 // 67 // 参考: GB/T 32918.1-2016 4.2.9 68 func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { 69 // buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个) 70 buffer := elliptic.Marshal(curve, x, y) 71 // 修改首位的压缩标识 72 // TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用? 73 if getLastBitOfY(x, y) > 0 { 74 // y最右侧一位为1,即奇数,压缩标识为 7 75 buffer[0] = mixed07 76 } else { 77 // y最右侧一位为0,即偶数,压缩标识为 6 78 buffer[0] = mixed06 79 } 80 return buffer 81 } 82 83 // 获取y最后一位的值 84 // 85 // x坐标为0时,直接返回0 86 // 参考: GB/T 32918.1-2016 A.5.2 87 func getLastBitOfY(x, y *big.Int) uint { 88 // x坐标为0时,直接返回0 89 if x.Cmp(zero) == 0 { 90 return 0 91 } 92 // 返回y最右侧一位的值 93 return y.Bit(0) 94 } 95 96 func toPointXY(bytes []byte) *big.Int { 97 return new(big.Int).SetBytes(bytes) 98 } 99 100 // 根据x坐标计算y坐标 101 // 102 // 参考: GB/T 32918.1-2016 A.5.2 B.1.4 103 func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) { 104 // x3 : x^3 105 x3 := new(big.Int).Mul(x, x) 106 x3.Mul(x3, x) 107 // threeX : 3x 108 threeX := new(big.Int).Lsh(x, 1) // x*2 109 threeX.Add(threeX, x) // x*2 + x = 3x 110 111 x3.Sub(x3, threeX) // x^3 - 3x 112 x3.Add(x3, curve.Params().B) // x^3 - 3x + b 113 x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p 114 // y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p 115 // 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p) 116 // 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度? 117 y := x3.ModSqrt(x3, curve.Params().P) 118 119 if y == nil { 120 return nil, errors.New("can't calculate y based on x") 121 } 122 return y, nil 123 } 124 125 // 字节数组转为曲线上的点坐标 126 // 127 // 返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33) 128 // 参考: GB/T 32918.1-2016 4.2.10 A.5.2 129 func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { 130 if len(bytes) < 1+(curve.Params().BitSize/8) { 131 return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) 132 } 133 // 获取压缩标识 134 format := bytes[0] 135 byteLen := (curve.Params().BitSize + 7) >> 3 136 switch format { 137 case uncompressed, mixed06, mixed07: // what's the mixed format purpose? 138 // 未压缩,或混合模式下,直接将x,y分别取出转换 139 if len(bytes) < 1+byteLen*2 { 140 return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) 141 } 142 x := toPointXY(bytes[1 : 1+byteLen]) 143 y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) 144 if !curve.IsOnCurve(x, y) { 145 return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name) 146 } 147 return x, y, 1 + byteLen*2, nil 148 case compressed02, compressed03: 149 // 压缩模式下 150 if len(bytes) < 1+byteLen { 151 return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) 152 } 153 if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { 154 // y² = x³ - 3x + b, prime curves 155 x := toPointXY(bytes[1 : 1+byteLen]) 156 // 根据x推算y数值 157 y, err := calculatePrimeCurveY(curve, x) 158 if err != nil { 159 return nil, nil, 0, err 160 } 161 // 计算出的y的值与压缩标识冲突的话,则 y = p - y 162 // 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。 163 // 又因为p是奇素数,所以两个y必然一奇一偶 164 if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) { 165 y.Sub(curve.Params().P, y) 166 } 167 return x, y, 1 + byteLen, nil 168 } 169 return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) 170 } 171 return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) 172 } 173 174 var ( 175 closedChanOnce sync.Once 176 closedChan chan struct{} 177 ) 178 179 // maybeReadByte reads a single byte from r with ~50% probability. This is used 180 // to ensure that callers do not depend on non-guaranteed behaviour, e.g. 181 // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream. 182 // 183 // This does not affect tests that pass a stream of fixed bytes as the random 184 // source (e.g. a zeroReader). 185 func maybeReadByte(r io.Reader) { 186 closedChanOnce.Do(func() { 187 closedChan = make(chan struct{}) 188 close(closedChan) 189 }) 190 191 select { 192 case <-closedChan: 193 return 194 case <-closedChan: 195 var buf [1]byte 196 _, err := r.Read(buf[:]) 197 if err != nil { 198 panic(err) 199 } 200 } 201 } 202 203 //goland:noinspection GoUnusedExportedFunction 204 func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) { 205 ecPriv := &ecdsa.PrivateKey{} 206 ecPriv.Curve = sm2Priv.Curve 207 ecPriv.D = sm2Priv.D 208 ecPriv.X = sm2Priv.X 209 ecPriv.Y = sm2Priv.Y 210 return ecPriv, nil 211 } 212 213 //goland:noinspection GoUnusedExportedFunction 214 func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) { 215 sm2Priv := &PrivateKey{} 216 sm2Priv.Curve = ecPriv.Curve 217 if sm2Priv.Curve != P256Sm2() { 218 return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换") 219 } 220 sm2Priv.D = ecPriv.D 221 sm2Priv.X = ecPriv.X 222 sm2Priv.Y = ecPriv.Y 223 return sm2Priv, nil 224 } 225 226 // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ 227 // SM2公私钥与hex相互转换 228 229 // ReadSm2PrivFromHex 将hex字符串转为sm2私钥 230 // 231 // @param Dhex 16进制字符串,对应sm2.PrivateKey.D 232 // @return *PrivateKey sm2私钥 233 // @return error 234 func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) { 235 c := P256Sm2() 236 d, err := hex.DecodeString(Dhex) 237 if err != nil { 238 return nil, err 239 } 240 k := new(big.Int).SetBytes(d) 241 params := c.Params() 242 one := new(big.Int).SetInt64(1) 243 n := new(big.Int).Sub(params.N, one) 244 if k.Cmp(n) >= 0 { 245 return nil, errors.New("privateKey's D is overflow") 246 } 247 priv := new(PrivateKey) 248 priv.PublicKey.Curve = c 249 priv.D = k 250 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 251 return priv, nil 252 } 253 254 // WriteSm2PrivToHex 将sm2私钥D转为hex字符串 255 // 256 // @param key sm2私钥 257 // @return string 258 func WriteSm2PrivToHex(key *PrivateKey) string { 259 return key.D.Text(16) 260 } 261 262 // ReadSm2PubFromHex 将hex字符串转为sm2公钥 263 // 264 // @param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串 265 // @return *PublicKey sm2公钥 266 // @return error 267 func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) { 268 q, err := hex.DecodeString(Qhex) 269 if err != nil { 270 return nil, err 271 } 272 if len(q) == 65 && q[0] == byte(0x04) { 273 q = q[1:] 274 } 275 if len(q) != 64 { 276 return nil, errors.New("publicKey is not uncompressed") 277 } 278 pub := new(PublicKey) 279 pub.Curve = P256Sm2() 280 pub.X = new(big.Int).SetBytes(q[:32]) 281 pub.Y = new(big.Int).SetBytes(q[32:]) 282 return pub, nil 283 } 284 285 // WriteSm2PubToHex 将sm2公钥转为hex字符串 286 // 287 // @param key sm2公钥 288 // @return string 289 func WriteSm2PubToHex(key *PublicKey) string { 290 x := key.X.Bytes() 291 y := key.Y.Bytes() 292 if n := len(x); n < 32 { 293 x = append(utils.ZeroByteSlice()[:32-n], x...) 294 } 295 if n := len(y); n < 32 { 296 y = append(utils.ZeroByteSlice()[:32-n], y...) 297 } 298 var c []byte 299 c = append(c, x...) 300 c = append(c, y...) 301 c = append([]byte{0x04}, c...) 302 return hex.EncodeToString(c) 303 } 304 305 // SM2公私钥与hex相互转换 306 // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑