github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/auth/tidb_sm3.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package auth
    15  
    16  import (
    17  	"encoding/binary"
    18  	"hash"
    19  )
    20  
// The SM3 cryptographic hash algorithm specification is available at http://www.sca.gov.cn/sca/xwdt/2010-12/17/content_1002389.shtml
// The sm3 type below is adapted from https://github.com/tjfoc/gmsm/tree/601ddb090dcf53d7951cc4dcc66276e2b817837c/sm3
    23  // Some other references:
    24  // 	https://datatracker.ietf.org/doc/draft-sca-cfrg-sm3/
    25  
    26  /*
    27  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
    28  Licensed under the Apache License, Version 2.0 (the "License");
    29  you may not use this file except in compliance with the License.
    30  You may obtain a copy of the License at
    31                   http://www.apache.org/licenses/LICENSE-2.0
    32  Unless required by applicable law or agreed to in writing, software
    33  distributed under the License is distributed on an "AS IS" BASIS,
    34  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    35  See the License for the specific language governing permissions and
    36  limitations under the License.
    37  */
    38  
    39  type sm3 struct {
    40  	digest      [8]uint32 // digest represents the partial evaluation of V
    41  	length      uint64    // length of the message
    42  	unhandleMsg []byte
    43  	blockSize   int
    44  	size        int
    45  }
    46  
    47  func ff0(x, y, z uint32) uint32 { return x ^ y ^ z }
    48  
    49  func ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) }
    50  
    51  func gg0(x, y, z uint32) uint32 { return x ^ y ^ z }
    52  
    53  func gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) }
    54  
    55  func p0(x uint32) uint32 { return x ^ leftRotate(x, 9) ^ leftRotate(x, 17) }
    56  
    57  func p1(x uint32) uint32 { return x ^ leftRotate(x, 15) ^ leftRotate(x, 23) }
    58  
    59  func leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) }
    60  
    61  func (sm3 *sm3) pad() []byte {
    62  	msg := sm3.unhandleMsg
    63  	// Append '1'
    64  	msg = append(msg, 0x80)
    65  	// Append until the resulting message length (in bits) is congruent to 448 (mod 512)
    66  	blockSize := 64
    67  	for i := len(msg); i%blockSize != 56; i++ {
    68  		msg = append(msg, 0x00)
    69  	}
    70  	// append message length
    71  	msg = append(msg, uint8(sm3.length>>56&0xff))
    72  	msg = append(msg, uint8(sm3.length>>48&0xff))
    73  	msg = append(msg, uint8(sm3.length>>40&0xff))
    74  	msg = append(msg, uint8(sm3.length>>32&0xff))
    75  	msg = append(msg, uint8(sm3.length>>24&0xff))
    76  	msg = append(msg, uint8(sm3.length>>16&0xff))
    77  	msg = append(msg, uint8(sm3.length>>8&0xff))
    78  	msg = append(msg, uint8(sm3.length>>0&0xff))
    79  
    80  	return msg
    81  }
    82  
    83  func (sm3 *sm3) update(msg []byte) [8]uint32 {
    84  	var w [68]uint32
    85  	var w1 [64]uint32
    86  
    87  	a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7]
    88  	for len(msg) >= 64 {
    89  		for i := 0; i < 16; i++ {
    90  			w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)])
    91  		}
    92  		for i := 16; i < 68; i++ {
    93  			w[i] = p1(w[i-16]^w[i-9]^leftRotate(w[i-3], 15)) ^ leftRotate(w[i-13], 7) ^ w[i-6]
    94  		}
    95  		for i := 0; i < 64; i++ {
    96  			w1[i] = w[i] ^ w[i+4]
    97  		}
    98  		a1, b1, c1, d1, e1, f1, g1, h1 := a, b, c, d, e, f, g, h
    99  		for i := 0; i < 16; i++ {
   100  			ss1 := leftRotate(leftRotate(a1, 12)+e1+leftRotate(0x79cc4519, uint32(i)), 7)
   101  			ss2 := ss1 ^ leftRotate(a1, 12)
   102  			tt1 := ff0(a1, b1, c1) + d1 + ss2 + w1[i]
   103  			tt2 := gg0(e1, f1, g1) + h1 + ss1 + w[i]
   104  			d1 = c1
   105  			c1 = leftRotate(b1, 9)
   106  			b1 = a1
   107  			a1 = tt1
   108  			h1 = g1
   109  			g1 = leftRotate(f1, 19)
   110  			f1 = e1
   111  			e1 = p0(tt2)
   112  		}
   113  		for i := 16; i < 64; i++ {
   114  			ss1 := leftRotate(leftRotate(a1, 12)+e1+leftRotate(0x7a879d8a, uint32(i)), 7)
   115  			ss2 := ss1 ^ leftRotate(a1, 12)
   116  			tt1 := ff1(a1, b1, c1) + d1 + ss2 + w1[i]
   117  			tt2 := gg1(e1, f1, g1) + h1 + ss1 + w[i]
   118  			d1 = c1
   119  			c1 = leftRotate(b1, 9)
   120  			b1 = a1
   121  			a1 = tt1
   122  			h1 = g1
   123  			g1 = leftRotate(f1, 19)
   124  			f1 = e1
   125  			e1 = p0(tt2)
   126  		}
   127  		a ^= a1
   128  		b ^= b1
   129  		c ^= c1
   130  		d ^= d1
   131  		e ^= e1
   132  		f ^= f1
   133  		g ^= g1
   134  		h ^= h1
   135  		msg = msg[64:]
   136  	}
   137  	var digest [8]uint32
   138  	digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6], digest[7] = a, b, c, d, e, f, g, h
   139  	return digest
   140  }
   141  
   142  // BlockSize returns the hash's underlying block size.
   143  // The Write method must be able to accept any amount of data,
   144  // but it may operate more efficiently if all writes are a multiple of the block size.
   145  func (sm3 *sm3) BlockSize() int { return sm3.blockSize }
   146  
   147  // Size returns the number of bytes Sum will return.
   148  func (sm3 *sm3) Size() int { return sm3.size }
   149  
   150  // Reset clears the internal state by zeroing bytes in the state buffer.
   151  // This can be skipped for a newly-created hash state; the default zero-allocated state is correct.
   152  func (sm3 *sm3) Reset() {
   153  	// Reset digest
   154  	sm3.digest[0] = 0x7380166f
   155  	sm3.digest[1] = 0x4914b2b9
   156  	sm3.digest[2] = 0x172442d7
   157  	sm3.digest[3] = 0xda8a0600
   158  	sm3.digest[4] = 0xa96f30bc
   159  	sm3.digest[5] = 0x163138aa
   160  	sm3.digest[6] = 0xe38dee4d
   161  	sm3.digest[7] = 0xb0fb0e4e
   162  
   163  	sm3.length = 0
   164  	sm3.unhandleMsg = []byte{}
   165  	sm3.blockSize = 64
   166  	sm3.size = 32
   167  }
   168  
   169  // Write (via the embedded io.Writer interface) adds more data to the running hash.
   170  // It never returns an error.
   171  func (sm3 *sm3) Write(p []byte) (int, error) {
   172  	toWrite := len(p)
   173  	sm3.length += uint64(len(p) * 8)
   174  	msg := append(sm3.unhandleMsg, p...)
   175  	nblocks := len(msg) / sm3.BlockSize()
   176  	sm3.digest = sm3.update(msg)
   177  	sm3.unhandleMsg = msg[nblocks*sm3.BlockSize():]
   178  
   179  	return toWrite, nil
   180  }
   181  
   182  // Sum appends the current hash to b and returns the resulting slice.
   183  // It does not change the underlying hash state.
   184  func (sm3 *sm3) Sum(in []byte) []byte {
   185  	_, _ = sm3.Write(in)
   186  	msg := sm3.pad()
   187  	// Finalize
   188  	digest := sm3.update(msg)
   189  
   190  	// save hash to in
   191  	needed := sm3.Size()
   192  	if cap(in)-len(in) < needed {
   193  		newIn := make([]byte, len(in), len(in)+needed)
   194  		copy(newIn, in)
   195  		in = newIn
   196  	}
   197  	out := in[len(in) : len(in)+needed]
   198  	for i := 0; i < 8; i++ {
   199  		binary.BigEndian.PutUint32(out[i*4:], digest[i])
   200  	}
   201  	return out
   202  }
   203  
   204  // NewSM3 returns a new hash.Hash computing the Sm3Hash checksum.
   205  func NewSM3() hash.Hash {
   206  	var h sm3
   207  	h.Reset()
   208  	return &h
   209  }
   210  
   211  // Sm3Hash returns the sm3 checksum of the data.
   212  func Sm3Hash(data []byte) []byte {
   213  	h := NewSM3()
   214  	h.Write(data)
   215  	return h.Sum(nil)
   216  }