gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/sm2/util.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  package sm2
    10  
    11  import (
    12  	"crypto/ecdsa"
    13  	"crypto/elliptic"
    14  	"encoding/hex"
    15  	"errors"
    16  	"fmt"
    17  	"gitee.com/ks-custle/core-gm/utils"
    18  	"io"
    19  	"math/big"
    20  	"strings"
    21  	"sync"
    22  )
    23  
    24  var zero = big.NewInt(0)
    25  
    26  // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0
    27  func toBytes(curve elliptic.Curve, value *big.Int) []byte {
    28  	// 大整数的字节数组
    29  	bytes := value.Bytes()
    30  	// 需要的长度: (256 + 7) / 8 = 32
    31  	byteLen := (curve.Params().BitSize + 7) >> 3
    32  	if byteLen == len(bytes) {
    33  		return bytes
    34  	}
    35  	// 左侧补0
    36  	result := make([]byte, byteLen)
    37  	copy(result[byteLen-len(bytes):], bytes)
    38  	return result
    39  }
    40  
    41  // 将曲线上的点座标(x,y)转为未压缩字节数组
    42  //
    43  //	参考: GB/T 32918.1-2016 4.2.9
    44  func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    45  	return elliptic.Marshal(curve, x, y)
    46  }
    47  
    48  // 将曲线上的点座标(x,y)转为压缩字节数组
    49  //
    50  //	返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数
    51  //	参考: GB/T 32918.1-2016 4.2.9
    52  func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    53  	// buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33
    54  	buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
    55  	// 将x的字节数组填入右侧32个字节
    56  	copy(buffer[1:], toBytes(curve, x))
    57  	// 首位字节是C1压缩标识
    58  	// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
    59  	// 又因为p是奇素数,所以两个y必然一奇一偶
    60  	if getLastBitOfY(x, y) > 0 {
    61  		// y最右侧一位为1,即奇数,压缩标识为 3
    62  		buffer[0] = compressed03
    63  	} else {
    64  		// y最右侧一位为0,即偶数,压缩标识为 2
    65  		buffer[0] = compressed02
    66  	}
    67  	return buffer
    68  }
    69  
    70  // 将曲线上的点座标(x,y)转为混合字节数组
    71  //
    72  //	参考: GB/T 32918.1-2016 4.2.9
    73  func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    74  	// buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个)
    75  	buffer := elliptic.Marshal(curve, x, y)
    76  	// 修改首位的压缩标识
    77  	// TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用?
    78  	if getLastBitOfY(x, y) > 0 {
    79  		// y最右侧一位为1,即奇数,压缩标识为 7
    80  		buffer[0] = mixed07
    81  	} else {
    82  		// y最右侧一位为0,即偶数,压缩标识为 6
    83  		buffer[0] = mixed06
    84  	}
    85  	return buffer
    86  }
    87  
    88  // 获取y最后一位的值
    89  //
    90  //	x坐标为0时,直接返回0
    91  //	参考: GB/T 32918.1-2016 A.5.2
    92  func getLastBitOfY(x, y *big.Int) uint {
    93  	// x坐标为0时,直接返回0
    94  	if x.Cmp(zero) == 0 {
    95  		return 0
    96  	}
    97  	// 返回y最右侧一位的值
    98  	return y.Bit(0)
    99  }
   100  
   101  func toPointXY(bytes []byte) *big.Int {
   102  	return new(big.Int).SetBytes(bytes)
   103  }
   104  
   105  // 根据x坐标计算y坐标
   106  //
   107  //	参考: GB/T 32918.1-2016 A.5.2 B.1.4
   108  func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
   109  	// x3 : x^3
   110  	x3 := new(big.Int).Mul(x, x)
   111  	x3.Mul(x3, x)
   112  	// threeX : 3x
   113  	threeX := new(big.Int).Lsh(x, 1) // x*2
   114  	threeX.Add(threeX, x)            // x*2 + x = 3x
   115  
   116  	x3.Sub(x3, threeX)           // x^3 - 3x
   117  	x3.Add(x3, curve.Params().B) // x^3 - 3x + b
   118  	x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p
   119  	// y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p
   120  	// 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p)
   121  	// 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度?
   122  	y := x3.ModSqrt(x3, curve.Params().P)
   123  
   124  	if y == nil {
   125  		return nil, errors.New("can't calculate y based on x")
   126  	}
   127  	return y, nil
   128  }
   129  
   130  // 字节数组转为曲线上的点坐标
   131  //
   132  //	返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33)
   133  //	参考: GB/T 32918.1-2016 4.2.10 A.5.2
   134  func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
   135  	if len(bytes) < 1+(curve.Params().BitSize/8) {
   136  		return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
   137  	}
   138  	// 获取压缩标识
   139  	format := bytes[0]
   140  	byteLen := (curve.Params().BitSize + 7) >> 3
   141  	switch format {
   142  	case uncompressed, mixed06, mixed07: // what's the mixed format purpose?
   143  		// 未压缩,或混合模式下,直接将x,y分别取出转换
   144  		if len(bytes) < 1+byteLen*2 {
   145  			return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
   146  		}
   147  		x := toPointXY(bytes[1 : 1+byteLen])
   148  		y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
   149  		if !curve.IsOnCurve(x, y) {
   150  			return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
   151  		}
   152  		return x, y, 1 + byteLen*2, nil
   153  	case compressed02, compressed03:
   154  		// 压缩模式下
   155  		if len(bytes) < 1+byteLen {
   156  			return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
   157  		}
   158  		if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
   159  			// y² = x³ - 3x + b, prime curves
   160  			x := toPointXY(bytes[1 : 1+byteLen])
   161  			// 根据x推算y数值
   162  			y, err := calculatePrimeCurveY(curve, x)
   163  			if err != nil {
   164  				return nil, nil, 0, err
   165  			}
   166  			// 计算出的y的值与压缩标识冲突的话,则 y = p - y
   167  			// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
   168  			// 又因为p是奇素数,所以两个y必然一奇一偶
   169  			if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
   170  				y.Sub(curve.Params().P, y)
   171  			}
   172  			return x, y, 1 + byteLen, nil
   173  		}
   174  		return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
   175  	}
   176  	return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
   177  }
   178  
   179  var (
   180  	closedChanOnce sync.Once
   181  	closedChan     chan struct{}
   182  )
   183  
   184  // maybeReadByte reads a single byte from r with ~50% probability. This is used
   185  // to ensure that callers do not depend on non-guaranteed behaviour, e.g.
   186  // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream.
   187  //
   188  // This does not affect tests that pass a stream of fixed bytes as the random
   189  // source (e.g. a zeroReader).
   190  func maybeReadByte(r io.Reader) {
   191  	closedChanOnce.Do(func() {
   192  		closedChan = make(chan struct{})
   193  		close(closedChan)
   194  	})
   195  
   196  	select {
   197  	case <-closedChan:
   198  		return
   199  	case <-closedChan:
   200  		var buf [1]byte
   201  		_, err := r.Read(buf[:])
   202  		if err != nil {
   203  			panic(err)
   204  		}
   205  	}
   206  }
   207  
   208  //goland:noinspection GoUnusedExportedFunction
   209  func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) {
   210  	ecPriv := &ecdsa.PrivateKey{}
   211  	ecPriv.Curve = sm2Priv.Curve
   212  	ecPriv.D = sm2Priv.D
   213  	ecPriv.X = sm2Priv.X
   214  	ecPriv.Y = sm2Priv.Y
   215  	return ecPriv, nil
   216  }
   217  
   218  //goland:noinspection GoUnusedExportedFunction
   219  func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) {
   220  	sm2Priv := &PrivateKey{}
   221  	sm2Priv.Curve = ecPriv.Curve
   222  	if sm2Priv.Curve != P256Sm2() {
   223  		return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换")
   224  	}
   225  	sm2Priv.D = ecPriv.D
   226  	sm2Priv.X = ecPriv.X
   227  	sm2Priv.Y = ecPriv.Y
   228  	return sm2Priv, nil
   229  }
   230  
   231  // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
   232  // SM2公私钥与hex相互转换
   233  
   234  // ReadSm2PrivFromHex 将hex字符串转为sm2私钥
   235  //
   236  //	@param Dhex 16进制字符串,对应sm2.PrivateKey.D
   237  //	@return *PrivateKey sm2私钥
   238  //	@return error
   239  func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) {
   240  	c := P256Sm2()
   241  	d, err := hex.DecodeString(Dhex)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	k := new(big.Int).SetBytes(d)
   246  	params := c.Params()
   247  	one := new(big.Int).SetInt64(1)
   248  	n := new(big.Int).Sub(params.N, one)
   249  	if k.Cmp(n) >= 0 {
   250  		return nil, errors.New("privateKey's D is overflow")
   251  	}
   252  	priv := new(PrivateKey)
   253  	priv.PublicKey.Curve = c
   254  	priv.D = k
   255  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
   256  	return priv, nil
   257  }
   258  
   259  // WriteSm2PrivToHex 将sm2私钥D转为hex字符串
   260  //
   261  //	@param key sm2私钥
   262  //	@return string
   263  func WriteSm2PrivToHex(key *PrivateKey) string {
   264  	return key.D.Text(16)
   265  }
   266  
   267  // ReadSm2PubFromHex 将hex字符串转为sm2公钥
   268  //
   269  //	@param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串
   270  //	@return *PublicKey sm2公钥
   271  //	@return error
   272  func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) {
   273  	q, err := hex.DecodeString(Qhex)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  	if len(q) == 65 && q[0] == byte(0x04) {
   278  		q = q[1:]
   279  	}
   280  	if len(q) != 64 {
   281  		return nil, errors.New("publicKey is not uncompressed")
   282  	}
   283  	pub := new(PublicKey)
   284  	pub.Curve = P256Sm2()
   285  	pub.X = new(big.Int).SetBytes(q[:32])
   286  	pub.Y = new(big.Int).SetBytes(q[32:])
   287  	return pub, nil
   288  }
   289  
   290  // WriteSm2PubToHex 将sm2公钥转为hex字符串
   291  //
   292  //	@param key sm2公钥
   293  //	@return string
   294  func WriteSm2PubToHex(key *PublicKey) string {
   295  	x := key.X.Bytes()
   296  	y := key.Y.Bytes()
   297  	if n := len(x); n < 32 {
   298  		x = append(utils.ZeroByteSlice()[:32-n], x...)
   299  	}
   300  	if n := len(y); n < 32 {
   301  		y = append(utils.ZeroByteSlice()[:32-n], y...)
   302  	}
   303  	var c []byte
   304  	c = append(c, x...)
   305  	c = append(c, y...)
   306  	c = append([]byte{0x04}, c...)
   307  	return hex.EncodeToString(c)
   308  }
   309  
   310  // SM2公私钥与hex相互转换
   311  // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑