gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/sm2/util.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // gmgo is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  package sm2
    10  
    11  import (
    12  	"crypto/ecdsa"
    13  	"crypto/elliptic"
    14  	"encoding/hex"
    15  	"errors"
    16  	"fmt"
    17  	"gitee.com/zhaochuninhefei/gmgo/utils"
    18  	"io"
    19  	"math/big"
    20  	"strings"
    21  	"sync"
    22  )
    23  
    24  var zero = big.NewInt(0)
    25  
    26  // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0
    27  func toBytes(curve elliptic.Curve, value *big.Int) []byte {
    28  	// 大整数的字节数组
    29  	bytes := value.Bytes()
    30  	// 需要的长度: (256 + 7) / 8 = 32
    31  	byteLen := (curve.Params().BitSize + 7) >> 3
    32  	if byteLen == len(bytes) {
    33  		return bytes
    34  	}
    35  	// 左侧补0
    36  	result := make([]byte, byteLen)
    37  	copy(result[byteLen-len(bytes):], bytes)
    38  	return result
    39  }
    40  
    41  // 将曲线上的点座标(x,y)转为未压缩字节数组
    42  //  参考: GB/T 32918.1-2016 4.2.9
    43  func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    44  	return elliptic.Marshal(curve, x, y)
    45  }
    46  
    47  // 将曲线上的点座标(x,y)转为压缩字节数组
    48  //  返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数
    49  //  参考: GB/T 32918.1-2016 4.2.9
    50  func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    51  	// buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33
    52  	buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
    53  	// 将x的字节数组填入右侧32个字节
    54  	copy(buffer[1:], toBytes(curve, x))
    55  	// 首位字节是C1压缩标识
    56  	// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
    57  	// 又因为p是奇素数,所以两个y必然一奇一偶
    58  	if getLastBitOfY(x, y) > 0 {
    59  		// y最右侧一位为1,即奇数,压缩标识为 3
    60  		buffer[0] = compressed03
    61  	} else {
    62  		// y最右侧一位为0,即偶数,压缩标识为 2
    63  		buffer[0] = compressed02
    64  	}
    65  	return buffer
    66  }
    67  
    68  // 将曲线上的点座标(x,y)转为混合字节数组
    69  //  参考: GB/T 32918.1-2016 4.2.9
    70  func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    71  	// buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个)
    72  	buffer := elliptic.Marshal(curve, x, y)
    73  	// 修改首位的压缩标识
    74  	// TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用?
    75  	if getLastBitOfY(x, y) > 0 {
    76  		// y最右侧一位为1,即奇数,压缩标识为 7
    77  		buffer[0] = mixed07
    78  	} else {
    79  		// y最右侧一位为0,即偶数,压缩标识为 6
    80  		buffer[0] = mixed06
    81  	}
    82  	return buffer
    83  }
    84  
    85  // 获取y最后一位的值
    86  //  x坐标为0时,直接返回0
    87  //  参考: GB/T 32918.1-2016 A.5.2
    88  func getLastBitOfY(x, y *big.Int) uint {
    89  	// x坐标为0时,直接返回0
    90  	if x.Cmp(zero) == 0 {
    91  		return 0
    92  	}
    93  	// 返回y最右侧一位的值
    94  	return y.Bit(0)
    95  }
    96  
    97  func toPointXY(bytes []byte) *big.Int {
    98  	return new(big.Int).SetBytes(bytes)
    99  }
   100  
   101  // 根据x坐标计算y坐标
   102  //  参考: GB/T 32918.1-2016 A.5.2 B.1.4
   103  func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
   104  	// x3 : x^3
   105  	x3 := new(big.Int).Mul(x, x)
   106  	x3.Mul(x3, x)
   107  	// threeX : 3x
   108  	threeX := new(big.Int).Lsh(x, 1) // x*2
   109  	threeX.Add(threeX, x)            // x*2 + x = 3x
   110  
   111  	x3.Sub(x3, threeX)           // x^3 - 3x
   112  	x3.Add(x3, curve.Params().B) // x^3 - 3x + b
   113  	x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p
   114  	// y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p
   115  	// 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p)
   116  	// 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度?
   117  	y := x3.ModSqrt(x3, curve.Params().P)
   118  
   119  	if y == nil {
   120  		return nil, errors.New("can't calculate y based on x")
   121  	}
   122  	return y, nil
   123  }
   124  
   125  // 字节数组转为曲线上的点坐标
   126  //  返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33)
   127  //  参考: GB/T 32918.1-2016 4.2.10 A.5.2
   128  func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
   129  	if len(bytes) < 1+(curve.Params().BitSize/8) {
   130  		return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
   131  	}
   132  	// 获取压缩标识
   133  	format := bytes[0]
   134  	byteLen := (curve.Params().BitSize + 7) >> 3
   135  	switch format {
   136  	case uncompressed, mixed06, mixed07: // what's the mixed format purpose?
   137  		// 未压缩,或混合模式下,直接将x,y分别取出转换
   138  		if len(bytes) < 1+byteLen*2 {
   139  			return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
   140  		}
   141  		x := toPointXY(bytes[1 : 1+byteLen])
   142  		y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
   143  		if !curve.IsOnCurve(x, y) {
   144  			return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
   145  		}
   146  		return x, y, 1 + byteLen*2, nil
   147  	case compressed02, compressed03:
   148  		// 压缩模式下
   149  		if len(bytes) < 1+byteLen {
   150  			return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
   151  		}
   152  		if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
   153  			// y² = x³ - 3x + b, prime curves
   154  			x := toPointXY(bytes[1 : 1+byteLen])
   155  			// 根据x推算y数值
   156  			y, err := calculatePrimeCurveY(curve, x)
   157  			if err != nil {
   158  				return nil, nil, 0, err
   159  			}
   160  			// 计算出的y的值与压缩标识冲突的话,则 y = p - y
   161  			// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
   162  			// 又因为p是奇素数,所以两个y必然一奇一偶
   163  			if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
   164  				y.Sub(curve.Params().P, y)
   165  			}
   166  			return x, y, 1 + byteLen, nil
   167  		}
   168  		return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
   169  	}
   170  	return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
   171  }
   172  
   173  var (
   174  	closedChanOnce sync.Once
   175  	closedChan     chan struct{}
   176  )
   177  
   178  // maybeReadByte reads a single byte from r with ~50% probability. This is used
   179  // to ensure that callers do not depend on non-guaranteed behaviour, e.g.
   180  // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream.
   181  //
   182  // This does not affect tests that pass a stream of fixed bytes as the random
   183  // source (e.g. a zeroReader).
   184  func maybeReadByte(r io.Reader) {
   185  	closedChanOnce.Do(func() {
   186  		closedChan = make(chan struct{})
   187  		close(closedChan)
   188  	})
   189  
   190  	select {
   191  	case <-closedChan:
   192  		return
   193  	case <-closedChan:
   194  		var buf [1]byte
   195  		_, err := r.Read(buf[:])
   196  		if err != nil {
   197  			panic(err)
   198  		}
   199  	}
   200  }
   201  
   202  //goland:noinspection GoUnusedExportedFunction
   203  func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) {
   204  	ecPriv := &ecdsa.PrivateKey{}
   205  	ecPriv.Curve = sm2Priv.Curve
   206  	ecPriv.D = sm2Priv.D
   207  	ecPriv.X = sm2Priv.X
   208  	ecPriv.Y = sm2Priv.Y
   209  	return ecPriv, nil
   210  }
   211  
   212  //goland:noinspection GoUnusedExportedFunction
   213  func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) {
   214  	sm2Priv := &PrivateKey{}
   215  	sm2Priv.Curve = ecPriv.Curve
   216  	if sm2Priv.Curve != P256Sm2() {
   217  		return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换")
   218  	}
   219  	sm2Priv.D = ecPriv.D
   220  	sm2Priv.X = ecPriv.X
   221  	sm2Priv.Y = ecPriv.Y
   222  	return sm2Priv, nil
   223  }
   224  
   225  // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
   226  // SM2公私钥与hex相互转换
   227  
   228  // ReadSm2PrivFromHex 将hex字符串转为sm2私钥
   229  //  @param Dhex 16进制字符串,对应sm2.PrivateKey.D
   230  //  @return *PrivateKey sm2私钥
   231  //  @return error
   232  func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) {
   233  	c := P256Sm2()
   234  	d, err := hex.DecodeString(Dhex)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	k := new(big.Int).SetBytes(d)
   239  	params := c.Params()
   240  	one := new(big.Int).SetInt64(1)
   241  	n := new(big.Int).Sub(params.N, one)
   242  	if k.Cmp(n) >= 0 {
   243  		return nil, errors.New("privateKey's D is overflow")
   244  	}
   245  	priv := new(PrivateKey)
   246  	priv.PublicKey.Curve = c
   247  	priv.D = k
   248  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
   249  	return priv, nil
   250  }
   251  
   252  // WriteSm2PrivToHex 将sm2私钥D转为hex字符串
   253  //  @param key sm2私钥
   254  //  @return string
   255  func WriteSm2PrivToHex(key *PrivateKey) string {
   256  	return key.D.Text(16)
   257  }
   258  
   259  // ReadSm2PubFromHex 将hex字符串转为sm2公钥
   260  //  @param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串
   261  //  @return *PublicKey sm2公钥
   262  //  @return error
   263  func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) {
   264  	q, err := hex.DecodeString(Qhex)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  	if len(q) == 65 && q[0] == byte(0x04) {
   269  		q = q[1:]
   270  	}
   271  	if len(q) != 64 {
   272  		return nil, errors.New("publicKey is not uncompressed")
   273  	}
   274  	pub := new(PublicKey)
   275  	pub.Curve = P256Sm2()
   276  	pub.X = new(big.Int).SetBytes(q[:32])
   277  	pub.Y = new(big.Int).SetBytes(q[32:])
   278  	return pub, nil
   279  }
   280  
   281  // WriteSm2PubToHex 将sm2公钥转为hex字符串
   282  //  @param key sm2公钥
   283  //  @return string
   284  func WriteSm2PubToHex(key *PublicKey) string {
   285  	x := key.X.Bytes()
   286  	y := key.Y.Bytes()
   287  	if n := len(x); n < 32 {
   288  		x = append(utils.ZeroByteSlice()[:32-n], x...)
   289  	}
   290  	if n := len(y); n < 32 {
   291  		y = append(utils.ZeroByteSlice()[:32-n], y...)
   292  	}
   293  	var c []byte
   294  	c = append(c, x...)
   295  	c = append(c, y...)
   296  	c = append([]byte{0x04}, c...)
   297  	return hex.EncodeToString(c)
   298  }
   299  
   300  // SM2公私钥与hex相互转换
   301  // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑