github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm2/util.go (about)

     1  // Copyright 2022 s1ren@github.com/hxx258456.
     2  
     3  package sm2
     4  
     5  import (
     6  	"crypto/ecdsa"
     7  	"crypto/elliptic"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/big"
    13  	"strings"
    14  	"sync"
    15  
    16  	"github.com/hxx258456/ccgo/utils"
    17  )
    18  
    19  var zero = big.NewInt(0)
    20  
    21  // 将大整数转为字节数组,并根据曲线位数计算出的字节数组长度对左侧补0
    22  func toBytes(curve elliptic.Curve, value *big.Int) []byte {
    23  	// 大整数的字节数组
    24  	bytes := value.Bytes()
    25  	// 需要的长度: (256 + 7) / 8 = 32
    26  	byteLen := (curve.Params().BitSize + 7) >> 3
    27  	if byteLen == len(bytes) {
    28  		return bytes
    29  	}
    30  	// 左侧补0
    31  	result := make([]byte, byteLen)
    32  	copy(result[byteLen-len(bytes):], bytes)
    33  	return result
    34  }
    35  
    36  // 将曲线上的点座标(x,y)转为未压缩字节数组
    37  //
    38  //	参考: GB/T 32918.1-2016 4.2.9
    39  func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    40  	return elliptic.Marshal(curve, x, y)
    41  }
    42  
    43  // 将曲线上的点座标(x,y)转为压缩字节数组
    44  //
    45  //	返回的字节数组长度33, 第一位是C1压缩标识, 2代表y是偶数, 3代表y是奇数
    46  //	参考: GB/T 32918.1-2016 4.2.9
    47  func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    48  	// buffer长度: (曲线位数(256) + 7) / 8 + 1 = 33
    49  	buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
    50  	// 将x的字节数组填入右侧32个字节
    51  	copy(buffer[1:], toBytes(curve, x))
    52  	// 首位字节是C1压缩标识
    53  	// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
    54  	// 又因为p是奇素数,所以两个y必然一奇一偶
    55  	if getLastBitOfY(x, y) > 0 {
    56  		// y最右侧一位为1,即奇数,压缩标识为 3
    57  		buffer[0] = compressed03
    58  	} else {
    59  		// y最右侧一位为0,即偶数,压缩标识为 2
    60  		buffer[0] = compressed02
    61  	}
    62  	return buffer
    63  }
    64  
    65  // 将曲线上的点座标(x,y)转为混合字节数组
    66  //
    67  //	参考: GB/T 32918.1-2016 4.2.9
    68  func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
    69  	// buffer是未做压缩的序列化字节数组, 长度65, 4 + x字节数组(32个) + y字节数组(32个)
    70  	buffer := elliptic.Marshal(curve, x, y)
    71  	// 修改首位的压缩标识
    72  	// TODO: 混合模式有何意义? C1实际并未压缩,把首位标识改为混合标识有啥用?
    73  	if getLastBitOfY(x, y) > 0 {
    74  		// y最右侧一位为1,即奇数,压缩标识为 7
    75  		buffer[0] = mixed07
    76  	} else {
    77  		// y最右侧一位为0,即偶数,压缩标识为 6
    78  		buffer[0] = mixed06
    79  	}
    80  	return buffer
    81  }
    82  
    83  // 获取y最后一位的值
    84  //
    85  //	x坐标为0时,直接返回0
    86  //	参考: GB/T 32918.1-2016 A.5.2
    87  func getLastBitOfY(x, y *big.Int) uint {
    88  	// x坐标为0时,直接返回0
    89  	if x.Cmp(zero) == 0 {
    90  		return 0
    91  	}
    92  	// 返回y最右侧一位的值
    93  	return y.Bit(0)
    94  }
    95  
    96  func toPointXY(bytes []byte) *big.Int {
    97  	return new(big.Int).SetBytes(bytes)
    98  }
    99  
   100  // 根据x坐标计算y坐标
   101  //
   102  //	参考: GB/T 32918.1-2016 A.5.2 B.1.4
   103  func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
   104  	// x3 : x^3
   105  	x3 := new(big.Int).Mul(x, x)
   106  	x3.Mul(x3, x)
   107  	// threeX : 3x
   108  	threeX := new(big.Int).Lsh(x, 1) // x*2
   109  	threeX.Add(threeX, x)            // x*2 + x = 3x
   110  
   111  	x3.Sub(x3, threeX)           // x^3 - 3x
   112  	x3.Add(x3, curve.Params().B) // x^3 - 3x + b
   113  	x3.Mod(x3, curve.Params().P) // (x^3 - 3x + b) mod p
   114  	// y² ≡ x³ - 3x + b (mod p) 的意思: y^2 和 (x^3 - 3x + b) 同余于p
   115  	// 但是上一步已经对x3做了一次模运算,所以下面的计算实际上是 y² ≡ ((x³ - 3x + b) mod p) (mod p)
   116  	// 两次模运算和一次模运算的结果其实是一样的: 23对10取余是3,3再对10取余还是3,大概用更小的x3可以加快计算速度?
   117  	y := x3.ModSqrt(x3, curve.Params().P)
   118  
   119  	if y == nil {
   120  		return nil, errors.New("can't calculate y based on x")
   121  	}
   122  	return y, nil
   123  }
   124  
   125  // 字节数组转为曲线上的点坐标
   126  //
   127  //	返回x,y数值,以及字节数组长度(未压缩/混合:65, 压缩:33)
   128  //	参考: GB/T 32918.1-2016 4.2.10 A.5.2
   129  func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
   130  	if len(bytes) < 1+(curve.Params().BitSize/8) {
   131  		return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
   132  	}
   133  	// 获取压缩标识
   134  	format := bytes[0]
   135  	byteLen := (curve.Params().BitSize + 7) >> 3
   136  	switch format {
   137  	case uncompressed, mixed06, mixed07: // what's the mixed format purpose?
   138  		// 未压缩,或混合模式下,直接将x,y分别取出转换
   139  		if len(bytes) < 1+byteLen*2 {
   140  			return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
   141  		}
   142  		x := toPointXY(bytes[1 : 1+byteLen])
   143  		y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
   144  		if !curve.IsOnCurve(x, y) {
   145  			return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
   146  		}
   147  		return x, y, 1 + byteLen*2, nil
   148  	case compressed02, compressed03:
   149  		// 压缩模式下
   150  		if len(bytes) < 1+byteLen {
   151  			return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
   152  		}
   153  		if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
   154  			// y² = x³ - 3x + b, prime curves
   155  			x := toPointXY(bytes[1 : 1+byteLen])
   156  			// 根据x推算y数值
   157  			y, err := calculatePrimeCurveY(curve, x)
   158  			if err != nil {
   159  				return nil, nil, 0, err
   160  			}
   161  			// 计算出的y的值与压缩标识冲突的话,则 y = p - y
   162  			// 因为椭圆曲线取模后的点是沿 y=p/2 这条线对称的,即一个x可能对应着两个y,这两个y关于 p/2 对称,因此 y1 = p - y2。
   163  			// 又因为p是奇素数,所以两个y必然一奇一偶
   164  			if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
   165  				y.Sub(curve.Params().P, y)
   166  			}
   167  			return x, y, 1 + byteLen, nil
   168  		}
   169  		return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
   170  	}
   171  	return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
   172  }
   173  
   174  var (
   175  	closedChanOnce sync.Once
   176  	closedChan     chan struct{}
   177  )
   178  
   179  // maybeReadByte reads a single byte from r with ~50% probability. This is used
   180  // to ensure that callers do not depend on non-guaranteed behaviour, e.g.
   181  // assuming that rsa.GenerateKey is deterministic w.r.t. a given random stream.
   182  //
   183  // This does not affect tests that pass a stream of fixed bytes as the random
   184  // source (e.g. a zeroReader).
   185  func maybeReadByte(r io.Reader) {
   186  	closedChanOnce.Do(func() {
   187  		closedChan = make(chan struct{})
   188  		close(closedChan)
   189  	})
   190  
   191  	select {
   192  	case <-closedChan:
   193  		return
   194  	case <-closedChan:
   195  		var buf [1]byte
   196  		_, err := r.Read(buf[:])
   197  		if err != nil {
   198  			panic(err)
   199  		}
   200  	}
   201  }
   202  
   203  //goland:noinspection GoUnusedExportedFunction
   204  func ConvertSM2Priv2ECPriv(sm2Priv *PrivateKey) (*ecdsa.PrivateKey, error) {
   205  	ecPriv := &ecdsa.PrivateKey{}
   206  	ecPriv.Curve = sm2Priv.Curve
   207  	ecPriv.D = sm2Priv.D
   208  	ecPriv.X = sm2Priv.X
   209  	ecPriv.Y = sm2Priv.Y
   210  	return ecPriv, nil
   211  }
   212  
   213  //goland:noinspection GoUnusedExportedFunction
   214  func ConvertECPriv2SM2Priv(ecPriv *ecdsa.PrivateKey) (*PrivateKey, error) {
   215  	sm2Priv := &PrivateKey{}
   216  	sm2Priv.Curve = ecPriv.Curve
   217  	if sm2Priv.Curve != P256Sm2() {
   218  		return nil, errors.New("sm2.ConvertECPriv2SM2Priv: 源私钥并未使用SM2曲线,无法转换")
   219  	}
   220  	sm2Priv.D = ecPriv.D
   221  	sm2Priv.X = ecPriv.X
   222  	sm2Priv.Y = ecPriv.Y
   223  	return sm2Priv, nil
   224  }
   225  
   226  // ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
   227  // SM2公私钥与hex相互转换
   228  
   229  // ReadSm2PrivFromHex 将hex字符串转为sm2私钥
   230  //
   231  //	@param Dhex 16进制字符串,对应sm2.PrivateKey.D
   232  //	@return *PrivateKey sm2私钥
   233  //	@return error
   234  func ReadSm2PrivFromHex(Dhex string) (*PrivateKey, error) {
   235  	c := P256Sm2()
   236  	d, err := hex.DecodeString(Dhex)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  	k := new(big.Int).SetBytes(d)
   241  	params := c.Params()
   242  	one := new(big.Int).SetInt64(1)
   243  	n := new(big.Int).Sub(params.N, one)
   244  	if k.Cmp(n) >= 0 {
   245  		return nil, errors.New("privateKey's D is overflow")
   246  	}
   247  	priv := new(PrivateKey)
   248  	priv.PublicKey.Curve = c
   249  	priv.D = k
   250  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
   251  	return priv, nil
   252  }
   253  
   254  // WriteSm2PrivToHex 将sm2私钥D转为hex字符串
   255  //
   256  //	@param key sm2私钥
   257  //	@return string
   258  func WriteSm2PrivToHex(key *PrivateKey) string {
   259  	return key.D.Text(16)
   260  }
   261  
   262  // ReadSm2PubFromHex 将hex字符串转为sm2公钥
   263  //
   264  //	@param Qhex sm2公钥座标x,y的字节数组拼接后的hex转码字符串
   265  //	@return *PublicKey sm2公钥
   266  //	@return error
   267  func ReadSm2PubFromHex(Qhex string) (*PublicKey, error) {
   268  	q, err := hex.DecodeString(Qhex)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  	if len(q) == 65 && q[0] == byte(0x04) {
   273  		q = q[1:]
   274  	}
   275  	if len(q) != 64 {
   276  		return nil, errors.New("publicKey is not uncompressed")
   277  	}
   278  	pub := new(PublicKey)
   279  	pub.Curve = P256Sm2()
   280  	pub.X = new(big.Int).SetBytes(q[:32])
   281  	pub.Y = new(big.Int).SetBytes(q[32:])
   282  	return pub, nil
   283  }
   284  
   285  // WriteSm2PubToHex 将sm2公钥转为hex字符串
   286  //
   287  //	@param key sm2公钥
   288  //	@return string
   289  func WriteSm2PubToHex(key *PublicKey) string {
   290  	x := key.X.Bytes()
   291  	y := key.Y.Bytes()
   292  	if n := len(x); n < 32 {
   293  		x = append(utils.ZeroByteSlice()[:32-n], x...)
   294  	}
   295  	if n := len(y); n < 32 {
   296  		y = append(utils.ZeroByteSlice()[:32-n], y...)
   297  	}
   298  	var c []byte
   299  	c = append(c, x...)
   300  	c = append(c, y...)
   301  	c = append([]byte{0x04}, c...)
   302  	return hex.EncodeToString(c)
   303  }
   304  
   305  // SM2公私钥与hex相互转换
   306  // ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑