github.com/cloudflare/circl@v1.5.0/xof/k12/k12.go (about)

// Package k12 implements the KangarooTwelve XOF.
     2  //
     3  // KangarooTwelve is being standardised at the CFRG working group
     4  // of the IRTF. This package implements draft 10.
     5  //
     6  // https://datatracker.ietf.org/doc/draft-irtf-cfrg-kangarootwelve/10/
     7  package k12
     8  
     9  import (
    10  	"encoding/binary"
    11  
    12  	"github.com/cloudflare/circl/internal/sha3"
    13  	"github.com/cloudflare/circl/simd/keccakf1600"
    14  )
    15  
// chunkSize is the size in bytes of the chunks into which the message is
// split; called "B" in the KangarooTwelve draft.
const chunkSize = 8192
    17  
    18  // KangarooTwelve splits the message into chunks of 8192 bytes each.
    19  // The first chunk is absorbed directly in a TurboSHAKE128 instance, which
    20  // we call the stalk. The subsequent chunks aren't absorbed directly, but
    21  // instead their hash is absorbed: they're like leafs on a stalk.
    22  // If we have a fast TurboSHAKE128 available, we buffer chunks until we have
    23  // enough to do the parallel TurboSHAKE128. If not, we absorb directly into
    24  // a separate TurboSHAKE128 state.
    25  
// State is the incremental hashing state of a KangarooTwelve computation.
// The zero value is not usable; create one with NewDraft10.
type State struct {
	initialTodo int // Bytes left to absorb for the first chunk.

	// TurboSHAKE128 instance into which the first chunk — and afterwards
	// the 32-byte chaining values of the other chunks — are absorbed.
	stalk sha3.State

	context []byte // context string "C" provided by the user

	// buffer of incoming data so we can do parallel TurboSHAKE128:
	// nil when we haven't absorbed the first chunk yet;
	// empty if we have, but we do not have a fast parallel TurboSHAKE128;
	// and chunkSize*lanes in length if we have.
	buf []byte

	offset int // offset in buf or bytes written to leaf

	// Number of chunk hashes ("CV_i") absorbed into the stalk.
	chunk uint

	// TurboSHAKE128 instance to compute the leaf in case we don't have
	// a fast parallel TurboSHAKE128, viz when lanes == 1.
	leaf *sha3.State

	lanes uint8 // number of TurboSHAKE128s to compute in parallel
}
    50  
    51  // NewDraft10 creates a new instance of Kangaroo12 draft version -10.
    52  func NewDraft10(c []byte) State {
    53  	var lanes byte = 1
    54  
    55  	if keccakf1600.IsEnabledX4() {
    56  		lanes = 4
    57  	} else if keccakf1600.IsEnabledX2() {
    58  		lanes = 2
    59  	}
    60  
    61  	return newDraft10(c, lanes)
    62  }
    63  
    64  func newDraft10(c []byte, lanes byte) State {
    65  	return State{
    66  		initialTodo: chunkSize,
    67  		stalk:       sha3.NewTurboShake128(0x07),
    68  		context:     c,
    69  		lanes:       lanes,
    70  	}
    71  }
    72  
    73  func (s *State) Reset() {
    74  	s.initialTodo = chunkSize
    75  	s.stalk.Reset()
    76  	s.stalk.SwitchDS(0x07)
    77  	s.buf = nil
    78  	s.offset = 0
    79  	s.chunk = 0
    80  }
    81  
    82  func (s *State) Clone() State {
    83  	stalk := s.stalk.Clone().(*sha3.State)
    84  	ret := State{
    85  		initialTodo: s.initialTodo,
    86  		stalk:       *stalk,
    87  		context:     s.context,
    88  		offset:      s.offset,
    89  		chunk:       s.chunk,
    90  		lanes:       s.lanes,
    91  	}
    92  
    93  	if s.leaf != nil {
    94  		ret.leaf = s.leaf.Clone().(*sha3.State)
    95  	}
    96  
    97  	if s.buf != nil {
    98  		ret.buf = make([]byte, len(s.buf))
    99  		copy(ret.buf, s.buf)
   100  	}
   101  
   102  	return ret
   103  }
   104  
   105  func Draft10Sum(hash []byte, msg []byte, c []byte) {
   106  	// TODO Tweak number of lanes depending on the length of the message
   107  	s := NewDraft10(c)
   108  	_, _ = s.Write(msg)
   109  	_, _ = s.Read(hash)
   110  }
   111  
// Write absorbs more message data into the state. It never fails and
// always returns len(p), nil, matching the io.Writer signature.
//
// The first chunkSize bytes go straight into the stalk. Everything after
// that is either fed to the single leaf TurboSHAKE128 (lanes == 1) or
// buffered in s.buf until lanes full chunks are available for the
// parallel absorb in writeX.
func (s *State) Write(p []byte) (int, error) {
	written := len(p)

	// The first chunk is written directly to the stalk.
	if s.initialTodo > 0 {
		taken := s.initialTodo
		if len(p) < taken {
			taken = len(p)
		}
		headP := p[:taken]
		_, _ = s.stalk.Write(headP)
		s.initialTodo -= taken
		p = p[taken:]
	}

	if len(p) == 0 {
		return written, nil
	}

	// If this is the first bit of data written after the initial chunk,
	// we're out of the fast-path and allocate some buffers.
	if s.buf == nil {
		if s.lanes != 1 {
			s.buf = make([]byte, int(s.lanes)*chunkSize)
		} else {
			// We create the buffer to signal we're past the first chunk,
			// but do not use it.
			s.buf = make([]byte, 0)
			h := sha3.NewTurboShake128(0x0B) // 0x0B: leaf domain separator
			s.leaf = &h
		}
		// Absorb the fixed 8-byte marker (03 00*7) that separates the
		// first chunk from the chaining values, and switch the stalk to
		// the final-node domain separator 0x06.
		_, _ = s.stalk.Write([]byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
		s.stalk.SwitchDS(0x06)
	}

	// If we're just using one lane, we don't need to cache in a buffer
	// for parallel hashing. Instead, we feed directly to TurboSHAKE.
	if s.lanes == 1 {
		for len(p) > 0 {
			// Write to current leaf.
			to := chunkSize - s.offset
			if len(p) < to {
				to = len(p)
			}
			_, _ = s.leaf.Write(p[:to])
			p = p[to:]
			s.offset += to

			// Did we fill the chunk?
			if s.offset == chunkSize {
				// Yes: absorb its 32-byte chaining value into the stalk
				// and start a fresh leaf.
				var cv [32]byte
				_, _ = s.leaf.Read(cv[:])
				_, _ = s.stalk.Write(cv[:])
				s.leaf.Reset()
				s.offset = 0
				s.chunk++
			}
		}

		return written, nil
	}

	// If we can't fill all our lanes or the buffer isn't empty, we write the
	// data to the buffer.
	if s.offset != 0 || len(p) < len(s.buf) {
		to := len(s.buf) - s.offset
		if len(p) < to {
			to = len(p)
		}
		p2 := p[:to]
		p = p[to:]
		copy(s.buf[s.offset:], p2)
		s.offset += to
	}

	// Absorb the buffer if we filled it
	if s.offset == len(s.buf) {
		s.writeX(s.buf)
		s.offset = 0
	}

	// Note that at this point we may assume that s.offset = 0 if len(p) != 0
	if len(p) != 0 && s.offset != 0 {
		panic("shouldn't happen")
	}

	// Absorb a bunch of chunks at the same time.
	if len(p) >= int(s.lanes)*chunkSize {
		p = s.writeX(p)
	}

	// Put the remainder in the buffer.
	if len(p) > 0 {
		copy(s.buf, p)
		s.offset = len(p)
	}

	return written, nil
}
   211  
   212  // Absorb a multiple of a multiple of lanes * chunkSize.
   213  // Returns the remainder.
   214  func (s *State) writeX(p []byte) []byte {
   215  	switch s.lanes {
   216  	case 4:
   217  		return s.writeX4(p)
   218  	default:
   219  		return s.writeX2(p)
   220  	}
   221  }
   222  
// writeX4 absorbs chunks four at a time with the 4-way interleaved Keccak
// permutation, writing the four resulting 32-byte chaining values to the
// stalk per iteration. It consumes the largest prefix of p that is a
// multiple of 4*chunkSize and returns the remainder.
//
// Layout arithmetic: TurboSHAKE128's rate is 168 bytes (21 uint64 words),
// and chunkSize = 8192 = 48*168 + 128, so each chunk is absorbed as 48
// full 168-byte blocks followed by one final 128-byte partial block.
func (s *State) writeX4(p []byte) []byte {
	for len(p) >= 4*chunkSize {
		var x4 keccakf1600.StateX4
		// NOTE(review): the true flag presumably selects the reduced-round
		// (Turbo) permutation — confirm against the keccakf1600 package.
		a := x4.Initialize(true)

		// Absorb the 48 full rate-sized blocks of each of the 4 chunks.
		// Lane j of the interleaved state a gets chunk j, i.e. word i of
		// lane j lives at a[i*4+j].
		for offset := 0; offset < 48*168; offset += 168 {
			for i := 0; i < 21; i++ {
				a[i*4] ^= binary.LittleEndian.Uint64(
					p[8*i+offset:],
				)
				a[i*4+1] ^= binary.LittleEndian.Uint64(
					p[chunkSize+8*i+offset:],
				)
				a[i*4+2] ^= binary.LittleEndian.Uint64(
					p[chunkSize*2+8*i+offset:],
				)
				a[i*4+3] ^= binary.LittleEndian.Uint64(
					p[chunkSize*3+8*i+offset:],
				)
			}

			x4.Permute()
		}

		// Absorb the final 128 bytes (16 words) of each chunk.
		for i := 0; i < 16; i++ {
			a[i*4] ^= binary.LittleEndian.Uint64(
				p[8*i+48*168:],
			)
			a[i*4+1] ^= binary.LittleEndian.Uint64(
				p[chunkSize+8*i+48*168:],
			)
			a[i*4+2] ^= binary.LittleEndian.Uint64(
				p[chunkSize*2+8*i+48*168:],
			)
			a[i*4+3] ^= binary.LittleEndian.Uint64(
				p[chunkSize*3+8*i+48*168:],
			)
		}

		// Padding: domain separator 0x0B (same as the single-lane leaf
		// TurboSHAKE128) right after the data, and the final 0x80 bit in
		// the last byte of the 168-byte rate (top byte of word 20).
		a[16*4] ^= 0x0b
		a[16*4+1] ^= 0x0b
		a[16*4+2] ^= 0x0b
		a[16*4+3] ^= 0x0b
		a[20*4] ^= 0x80 << 56
		a[20*4+1] ^= 0x80 << 56
		a[20*4+2] ^= 0x80 << 56
		a[20*4+3] ^= 0x80 << 56

		x4.Permute()

		// Squeeze 32 bytes (4 words) per lane: the chaining values CV_i,
		// laid out consecutively per chunk before going into the stalk.
		var buf [32 * 4]byte
		for i := 0; i < 4; i++ {
			binary.LittleEndian.PutUint64(buf[8*i:], a[4*i])
			binary.LittleEndian.PutUint64(buf[32+8*i:], a[4*i+1])
			binary.LittleEndian.PutUint64(buf[32*2+8*i:], a[4*i+2])
			binary.LittleEndian.PutUint64(buf[32*3+8*i:], a[4*i+3])
		}

		_, _ = s.stalk.Write(buf[:])
		p = p[chunkSize*4:]
		s.chunk += 4
	}

	return p
}
   288  
// writeX2 absorbs chunks two at a time with the 2-way interleaved Keccak
// permutation, writing the two resulting 32-byte chaining values to the
// stalk per iteration. It consumes the largest prefix of p that is a
// multiple of 2*chunkSize and returns the remainder.
//
// Same layout as writeX4: each 8192-byte chunk is 48 full 168-byte
// rate blocks plus a final 128-byte partial block; word i of lane j of
// the interleaved state lives at a[i*2+j].
func (s *State) writeX2(p []byte) []byte {
	// TODO On M2 Pro, 1/3 of the time is spent on this function
	// and LittleEndian.Uint64 excluding the actual permutation.
	// Rewriting in assembler might be worthwhile.
	for len(p) >= 2*chunkSize {
		var x2 keccakf1600.StateX2
		// NOTE(review): the true flag presumably selects the reduced-round
		// (Turbo) permutation — confirm against the keccakf1600 package.
		a := x2.Initialize(true)

		// Absorb the 48 full rate-sized blocks of both chunks.
		for offset := 0; offset < 48*168; offset += 168 {
			for i := 0; i < 21; i++ {
				a[i*2] ^= binary.LittleEndian.Uint64(
					p[8*i+offset:],
				)
				a[i*2+1] ^= binary.LittleEndian.Uint64(
					p[chunkSize+8*i+offset:],
				)
			}

			x2.Permute()
		}

		// Absorb the final 128 bytes (16 words) of each chunk.
		for i := 0; i < 16; i++ {
			a[i*2] ^= binary.LittleEndian.Uint64(
				p[8*i+48*168:],
			)
			a[i*2+1] ^= binary.LittleEndian.Uint64(
				p[chunkSize+8*i+48*168:],
			)
		}

		// Padding: leaf domain separator 0x0B after the data and the
		// final 0x80 bit in the last byte of the rate (word 20).
		a[16*2] ^= 0x0b
		a[16*2+1] ^= 0x0b
		a[20*2] ^= 0x80 << 56
		a[20*2+1] ^= 0x80 << 56

		x2.Permute()

		// Squeeze the two 32-byte chaining values and absorb them into
		// the stalk.
		var buf [32 * 2]byte
		for i := 0; i < 4; i++ {
			binary.LittleEndian.PutUint64(buf[8*i:], a[2*i])
			binary.LittleEndian.PutUint64(buf[32+8*i:], a[2*i+1])
		}

		_, _ = s.stalk.Write(buf[:])
		p = p[chunkSize*2:]
		s.chunk += 2
	}

	return p
}
   339  
   340  func (s *State) Read(p []byte) (int, error) {
   341  	if s.stalk.IsAbsorbing() {
   342  		// Write context string C
   343  		_, _ = s.Write(s.context)
   344  
   345  		// Write length_encode( |C| )
   346  		var buf [9]byte
   347  		binary.BigEndian.PutUint64(buf[:8], uint64(len(s.context)))
   348  
   349  		// Find first non-zero digit in big endian encoding of context length
   350  		i := 0
   351  		for buf[i] == 0 && i < 8 {
   352  			i++
   353  		}
   354  
   355  		buf[8] = byte(8 - i) // number of bytes to represent |C|
   356  		_, _ = s.Write(buf[i:])
   357  
   358  		// We need to write the chunk number if we're past the first chunk.
   359  		if s.buf != nil {
   360  			// Write last remaining chunk(s)
   361  			var cv [32]byte
   362  			if s.lanes == 1 {
   363  				if s.offset != 0 {
   364  					_, _ = s.leaf.Read(cv[:])
   365  					_, _ = s.stalk.Write(cv[:])
   366  					s.chunk++
   367  				}
   368  			} else {
   369  				remainingBuf := s.buf[:s.offset]
   370  				for len(remainingBuf) > 0 {
   371  					h := sha3.NewTurboShake128(0x0B)
   372  					to := chunkSize
   373  					if len(remainingBuf) < to {
   374  						to = len(remainingBuf)
   375  					}
   376  					_, _ = h.Write(remainingBuf[:to])
   377  					_, _ = h.Read(cv[:])
   378  					_, _ = s.stalk.Write(cv[:])
   379  					s.chunk++
   380  					remainingBuf = remainingBuf[to:]
   381  				}
   382  			}
   383  
   384  			// Write length_encode( chunk )
   385  			binary.BigEndian.PutUint64(buf[:8], uint64(s.chunk))
   386  
   387  			// Find first non-zero digit in big endian encoding of number of chunks
   388  			i = 0
   389  			for buf[i] == 0 && i < 8 {
   390  				i++
   391  			}
   392  
   393  			buf[8] = byte(8 - i) // number of bytes to represent number of chunks.
   394  			_, _ = s.stalk.Write(buf[i:])
   395  			_, _ = s.stalk.Write([]byte{0xff, 0xff})
   396  		}
   397  	}
   398  
   399  	return s.stalk.Read(p)
   400  }