github.com/mad-day/Yawning-crypto@v0.0.0-20190711051033-5a5f8cca32ec/morus/morus_ref.go (about)

     1  // morus_ref.go - Reference (portable) implementation
     2  //
     3  // To the extent possible under law, Yawning Angel has waived all copyright
     4  // and related or neighboring rights to the software, using the Creative
     5  // Commons "CC0" public domain dedication. See LICENSE or
     6  // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
     7  
     8  package morus
     9  
    10  import (
    11  	"crypto/subtle"
    12  	"math/bits"
    13  )
    14  
const (
	// n1 through n5 are the left-rotation amounts, in bits, that update
	// applies to the words of state rows 0 through 4 respectively.
	n1 = 13
	n2 = 46
	n3 = 38
	n4 = 7
	n5 = 4

	// blockSize is the number of bytes processed per block, and the size
	// of the scratch buffers used to zero-pad a trailing partial block.
	blockSize = 32
)
    24  
// state holds the cipher state as twenty 64-bit words. update treats it as
// five rows of four words each, with row r stored in s[4*r] through s[4*r+3].
type state struct {
	s [20]uint64
}
    28  
// update mixes one message block into the state using five rounds.
//
// msgBlk must hold at least 32 bytes (the function panics otherwise); only
// the first 32 bytes are read, decoded as four 64-bit words with the
// package-level byteOrder. Round 1 does not use the message words; rounds 2
// through 5 each XOR them into the row being updated.
func (s *state) update(msgBlk []byte) {
	var tmp uint64

	// Work on local copies of the five rows (sRC = row R, word C) and write
	// them back in a single assignment at the end.
	s00, s01, s02, s03, s10, s11, s12, s13, s20, s21, s22, s23, s30, s31, s32, s33, s40, s41, s42, s43 := s.s[0], s.s[1], s.s[2], s.s[3], s.s[4], s.s[5], s.s[6], s.s[7], s.s[8], s.s[9], s.s[10], s.s[11], s.s[12], s.s[13], s.s[14], s.s[15], s.s[16], s.s[17], s.s[18], s.s[19]

	_ = msgBlk[31] // Bounds check elimination
	m0 := byteOrder.Uint64(msgBlk[0:8])
	m1 := byteOrder.Uint64(msgBlk[8:16])
	m2 := byteOrder.Uint64(msgBlk[16:24])
	m3 := byteOrder.Uint64(msgBlk[24:32])

	// Round 1: row0 ^= row3 ^ (row1 & row2), rotate each word left by n1,
	// then shift the words of row 3 by one position (s30 takes the old s33).
	s00 ^= s30
	s01 ^= s31
	s02 ^= s32
	s03 ^= s33
	s00 ^= s10 & s20
	s01 ^= s11 & s21
	s02 ^= s12 & s22
	s03 ^= s13 & s23
	s00 = bits.RotateLeft64(s00, n1)
	s01 = bits.RotateLeft64(s01, n1)
	s02 = bits.RotateLeft64(s02, n1)
	s03 = bits.RotateLeft64(s03, n1)
	tmp = s33
	s33 = s32
	s32 = s31
	s31 = s30
	s30 = tmp

	// Round 2: row1 ^= msg ^ row4 ^ (row2 & row3), rotate each word left by
	// n2, then swap the words of row 4 pairwise (0<->2, 1<->3).
	s10 ^= m0
	s11 ^= m1
	s12 ^= m2
	s13 ^= m3
	s10 ^= s40
	s11 ^= s41
	s12 ^= s42
	s13 ^= s43
	s10 ^= s20 & s30
	s11 ^= s21 & s31
	s12 ^= s22 & s32
	s13 ^= s23 & s33
	s10 = bits.RotateLeft64(s10, n2)
	s11 = bits.RotateLeft64(s11, n2)
	s12 = bits.RotateLeft64(s12, n2)
	s13 = bits.RotateLeft64(s13, n2)
	s43, s41 = s41, s43
	s42, s40 = s40, s42

	// Round 3: row2 ^= msg ^ row0 ^ (row3 & row4), rotate each word left by
	// n3, then shift the words of row 0 by one position in the opposite
	// direction to round 1 (s00 takes the old s01).
	s20 ^= m0
	s21 ^= m1
	s22 ^= m2
	s23 ^= m3
	s20 ^= s00
	s21 ^= s01
	s22 ^= s02
	s23 ^= s03
	s20 ^= s30 & s40
	s21 ^= s31 & s41
	s22 ^= s32 & s42
	s23 ^= s33 & s43
	s20 = bits.RotateLeft64(s20, n3)
	s21 = bits.RotateLeft64(s21, n3)
	s22 = bits.RotateLeft64(s22, n3)
	s23 = bits.RotateLeft64(s23, n3)
	tmp = s00
	s00 = s01
	s01 = s02
	s02 = s03
	s03 = tmp

	// Round 4: row3 ^= msg ^ row1 ^ (row4 & row0), rotate each word left by
	// n4, then swap the words of row 1 pairwise (0<->2, 1<->3).
	s30 ^= m0
	s31 ^= m1
	s32 ^= m2
	s33 ^= m3
	s30 ^= s10
	s31 ^= s11
	s32 ^= s12
	s33 ^= s13
	s30 ^= s40 & s00
	s31 ^= s41 & s01
	s32 ^= s42 & s02
	s33 ^= s43 & s03
	s30 = bits.RotateLeft64(s30, n4)
	s31 = bits.RotateLeft64(s31, n4)
	s32 = bits.RotateLeft64(s32, n4)
	s33 = bits.RotateLeft64(s33, n4)
	s13, s11 = s11, s13
	s12, s10 = s10, s12

	// Round 5: row4 ^= msg ^ row2 ^ (row0 & row1), rotate each word left by
	// n5, then shift the words of row 2 by one position (s20 takes the old
	// s23).
	s40 ^= m0
	s41 ^= m1
	s42 ^= m2
	s43 ^= m3
	s40 ^= s20
	s41 ^= s21
	s42 ^= s22
	s43 ^= s23
	s40 ^= s00 & s10
	s41 ^= s01 & s11
	s42 ^= s02 & s12
	s43 ^= s03 & s13
	s40 = bits.RotateLeft64(s40, n5)
	s41 = bits.RotateLeft64(s41, n5)
	s42 = bits.RotateLeft64(s42, n5)
	s43 = bits.RotateLeft64(s43, n5)
	tmp = s23
	s23 = s22
	s22 = s21
	s21 = s20
	s20 = tmp

	s.s[0], s.s[1], s.s[2], s.s[3], s.s[4], s.s[5], s.s[6], s.s[7], s.s[8], s.s[9], s.s[10], s.s[11], s.s[12], s.s[13], s.s[14], s.s[15], s.s[16], s.s[17], s.s[18], s.s[19] = s00, s01, s02, s03, s10, s11, s12, s13, s20, s21, s22, s23, s30, s31, s32, s33, s40, s41, s42, s43
}
   142  
   143  func (s *state) encryptBlock(out, in []byte) {
   144  	_, _ = in[31], out[31] // Bounds check elimination
   145  	in0 := byteOrder.Uint64(in[0:8])
   146  	in1 := byteOrder.Uint64(in[8:16])
   147  	in2 := byteOrder.Uint64(in[16:24])
   148  	in3 := byteOrder.Uint64(in[24:32])
   149  
   150  	out0 := in0 ^ s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
   151  	out1 := in1 ^ s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
   152  	out2 := in2 ^ s.s[2] ^ s.s[7] ^ (s.s[10] & s.s[14])
   153  	out3 := in3 ^ s.s[3] ^ s.s[4] ^ (s.s[11] & s.s[15])
   154  
   155  	s.update(in[:32])
   156  
   157  	// Doing this last lets this work in place.
   158  	byteOrder.PutUint64(out[0:8], out0)
   159  	byteOrder.PutUint64(out[8:16], out1)
   160  	byteOrder.PutUint64(out[16:24], out2)
   161  	byteOrder.PutUint64(out[24:32], out3)
   162  }
   163  
   164  func (s *state) decryptBlockCommon(out, in []byte) {
   165  	_, _ = in[31], out[31] // Bounds check elimination
   166  	in0 := byteOrder.Uint64(in[0:8])
   167  	in1 := byteOrder.Uint64(in[8:16])
   168  	in2 := byteOrder.Uint64(in[16:24])
   169  	in3 := byteOrder.Uint64(in[24:32])
   170  
   171  	out0 := in0 ^ s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
   172  	out1 := in1 ^ s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
   173  	out2 := in2 ^ s.s[2] ^ s.s[7] ^ (s.s[10] & s.s[14])
   174  	out3 := in3 ^ s.s[3] ^ s.s[4] ^ (s.s[11] & s.s[15])
   175  
   176  	byteOrder.PutUint64(out[0:8], out0)
   177  	byteOrder.PutUint64(out[8:16], out1)
   178  	byteOrder.PutUint64(out[16:24], out2)
   179  	byteOrder.PutUint64(out[24:32], out3)
   180  }
   181  
   182  func (s *state) decryptBlock(out, in []byte) {
   183  	s.decryptBlockCommon(out, in)
   184  	s.update(out[:32])
   185  }
   186  
   187  func (s *state) decryptPartialBlock(out, in []byte) {
   188  	var tmp [blockSize]byte
   189  	copy(tmp[:], in)
   190  	s.decryptBlockCommon(tmp[:], tmp[:])
   191  	copy(out, tmp[:])
   192  
   193  	burnBytes(tmp[len(in):])
   194  	s.update(tmp[:])
   195  }
   196  
   197  func (s *state) init(key, iv []byte) {
   198  	_, _ = key[31], iv[15] // Bounds check elimination
   199  	k0 := byteOrder.Uint64(key[0:8])
   200  	k1 := byteOrder.Uint64(key[8:16])
   201  	k2 := byteOrder.Uint64(key[16:24])
   202  	k3 := byteOrder.Uint64(key[24:32])
   203  
   204  	s.s[0] = byteOrder.Uint64(iv[0:8])
   205  	s.s[1] = byteOrder.Uint64(iv[8:16])
   206  	s.s[2], s.s[3] = 0, 0
   207  	s.s[4], s.s[5], s.s[6], s.s[7] = k0, k1, k2, k3
   208  	s.s[8], s.s[9], s.s[10], s.s[11] = 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff
   209  	s.s[12], s.s[13], s.s[14], s.s[15] = 0, 0, 0, 0
   210  	s.s[16] = initializationConstants[0]
   211  	s.s[17] = initializationConstants[1]
   212  	s.s[18] = initializationConstants[2]
   213  	s.s[19] = initializationConstants[3]
   214  
   215  	var tmp [blockSize]byte
   216  	for i := 0; i < 16; i++ {
   217  		s.update(tmp[:])
   218  	}
   219  	s.s[4] ^= k0
   220  	s.s[5] ^= k1
   221  	s.s[6] ^= k2
   222  	s.s[7] ^= k3
   223  
   224  	burnBytes(tmp[:])
   225  }
   226  
   227  func (s *state) absorbData(in []byte) {
   228  	inLen, off := len(in), 0
   229  	if inLen == 0 {
   230  		return
   231  	}
   232  
   233  	for inLen >= blockSize {
   234  		s.update(in[off : off+blockSize])
   235  		inLen, off = inLen-blockSize, off+blockSize
   236  	}
   237  
   238  	if inLen > 0 {
   239  		var tmp [blockSize]byte
   240  		copy(tmp[:], in[off:])
   241  		s.update(tmp[:])
   242  	}
   243  }
   244  
   245  func (s *state) encryptData(out, in []byte) {
   246  	inLen, off := len(in), 0
   247  	if inLen == 0 {
   248  		return
   249  	}
   250  
   251  	for inLen >= blockSize {
   252  		s.encryptBlock(out[off:off+blockSize], in[off:off+blockSize])
   253  		inLen, off = inLen-blockSize, off+blockSize
   254  	}
   255  
   256  	if inLen > 0 {
   257  		var tmp [blockSize]byte
   258  		copy(tmp[:], in[off:])
   259  		s.encryptBlock(tmp[:], tmp[:])
   260  		copy(out[off:], tmp[:])
   261  	}
   262  }
   263  
   264  func (s *state) decryptData(out, in []byte) {
   265  	inLen, off := len(in), 0
   266  	if inLen == 0 {
   267  		return
   268  	}
   269  
   270  	for inLen >= blockSize {
   271  		s.decryptBlock(out[off:off+blockSize], in[off:off+blockSize])
   272  		inLen, off = inLen-blockSize, off+blockSize
   273  	}
   274  
   275  	if inLen > 0 {
   276  		s.decryptPartialBlock(out[off:], in[off:])
   277  	}
   278  }
   279  
   280  func (s *state) finalize(msgLen, adLen uint64, tag []byte) {
   281  	var tmp [blockSize]byte
   282  	byteOrder.PutUint64(tmp[0:8], (adLen << 3))
   283  	byteOrder.PutUint64(tmp[8:16], (msgLen << 3))
   284  
   285  	s.s[16] ^= s.s[0]
   286  	s.s[17] ^= s.s[1]
   287  	s.s[18] ^= s.s[2]
   288  	s.s[19] ^= s.s[3]
   289  
   290  	for i := 0; i < 10; i++ {
   291  		s.update(tmp[:])
   292  	}
   293  
   294  	s.s[0] = s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
   295  	s.s[1] = s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
   296  
   297  	_ = tag[15] // Bounds check elimination
   298  	byteOrder.PutUint64(tag[0:8], s.s[0])
   299  	byteOrder.PutUint64(tag[8:16], s.s[1])
   300  
   301  	burnBytes(tmp[:])
   302  }
   303  
   304  func aeadEncryptRef(c, m, a, nonce, key []byte) []byte {
   305  	var s state
   306  	mLen := len(m)
   307  
   308  	ret, out := sliceForAppend(c, mLen+TagSize)
   309  
   310  	s.init(key, nonce)
   311  	s.absorbData(a)
   312  	s.encryptData(out, m)
   313  	s.finalize(uint64(mLen), uint64(len(a)), out[mLen:])
   314  
   315  	burnUint64s(s.s[:])
   316  
   317  	return ret
   318  }
   319  
   320  func aeadDecryptRef(m, c, a, nonce, key []byte) ([]byte, bool) {
   321  	var s state
   322  	var tag [TagSize]byte
   323  	cLen := len(c)
   324  
   325  	if cLen < TagSize {
   326  		return nil, false
   327  	}
   328  
   329  	mLen := cLen - TagSize
   330  	ret, out := sliceForAppend(m, mLen)
   331  
   332  	s.init(key, nonce)
   333  	s.absorbData(a)
   334  	s.decryptData(out, c[:mLen])
   335  	s.finalize(uint64(mLen), uint64(len(a)), tag[:])
   336  
   337  	srcTag := c[mLen:]
   338  	ok := subtle.ConstantTimeCompare(srcTag, tag[:]) == 1
   339  	if !ok && mLen > 0 {
   340  		// Burn decrypted plaintext on auth failure.
   341  		burnBytes(out[:mLen])
   342  		ret = nil
   343  	}
   344  
   345  	burnUint64s(s.s[:])
   346  
   347  	return ret, ok
   348  }