github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/nacl/secretbox/secretbox_reader.go (about)

     1  /*
     2   * Copyright (c) 2017, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Copyright 2012 The Go Authors. All rights reserved.
    21  // Use of this source code is governed by a BSD-style
    22  // license that can be found in the LICENSE file.
    23  
    24  package secretbox // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/nacl/secretbox"
    25  
    26  import (
    27  	"crypto/subtle"
    28  	"encoding/binary"
    29  	"fmt"
    30  	"io"
    31  
    32  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"
    33  	"golang.org/x/crypto/salsa20/salsa"
    34  )
    35  
    36  // NewOpenReadSeeker is a streaming variant of Open.
    37  //
    38  // NewOpenReadSeeker is intended only for use in Psiphon with a payload that is
    39  // independently authenticated; and consideration has been given only for client-side
    40  // operation. Non-optimized reference implementation poly1305 and salsa20 code is used.
    41  //
    42  // The box is accessed through an io.ReadSeeker, which allows for an initial
    43  // poly1305 verification pass followed by a payload decryption pass, both
    44  // without loading the entire box into memory. As such, this implementation
    45  // should not be subject to the use-before-authentication or truncation attacks
    46  // discussed here:
    47  // https://github.com/golang/crypto/commit/9ba3862cf6a5452ae579de98f9364dd2e544844c#diff-9a969aca62172940631ad143523794ee
    48  // https://github.com/golang/go/issues/17673#issuecomment-275732868
    49  func NewOpenReadSeeker(box io.ReadSeeker, nonce *[24]byte, key *[32]byte) (io.ReadSeeker, error) {
    50  
    51  	r := &salsa20ReadSeeker{
    52  		box:   box,
    53  		nonce: *nonce,
    54  		key:   *key,
    55  	}
    56  
    57  	err := r.reset()
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	return r, nil
    63  }
    64  
    65  type salsa20ReadSeeker struct {
    66  	box         io.ReadSeeker
    67  	nonce       [24]byte
    68  	key         [32]byte
    69  	subKey      [32]byte
    70  	counter     [16]byte
    71  	block       [64]byte
    72  	blockOffset int
    73  }
    74  
    75  // Open x/crypto/nacl/secretbox/secretbox.go, adapted to streaming and rewinding.
    76  func (r *salsa20ReadSeeker) reset() error {
    77  
    78  	// See comments in Open
    79  
    80  	_, err := r.box.Seek(0, io.SeekStart)
    81  	if err != nil {
    82  		return fmt.Errorf("initial seek failed: %s", err)
    83  	}
    84  
    85  	var tag [poly1305.TagSize]byte
    86  
    87  	_, err = io.ReadFull(r.box, tag[:])
    88  	if err != nil {
    89  		return fmt.Errorf("read tag failed: %s", err)
    90  	}
    91  
    92  	var subKey [32]byte
    93  	var counter [16]byte
    94  	setup(&subKey, &counter, &r.nonce, &r.key)
    95  
    96  	// The Poly1305 key is generated by encrypting 32 bytes of zeros. Since
    97  	// Salsa20 works with 64-byte blocks, we also generate 32 bytes of
    98  	// keystream as a side effect.
    99  	var firstBlock [64]byte
   100  	salsa.XORKeyStream(firstBlock[:], firstBlock[:], &counter, &subKey)
   101  
   102  	var poly1305Key [32]byte
   103  	copy(poly1305Key[:], firstBlock[:])
   104  
   105  	err = poly1305VerifyReader(&tag, r.box, &poly1305Key)
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	_, err = r.box.Seek(int64(len(tag)), io.SeekStart)
   111  	if err != nil {
   112  		return fmt.Errorf("rewind seek failed: %s", err)
   113  	}
   114  
   115  	counter[8] = 1
   116  
   117  	r.subKey = subKey
   118  	r.counter = counter
   119  
   120  	// We XOR up to 32 bytes of box with the keystream generated from
   121  	// the first block.
   122  
   123  	r.block = firstBlock
   124  	r.blockOffset = 32
   125  
   126  	return nil
   127  }
   128  
   129  func (r *salsa20ReadSeeker) Read(p []byte) (int, error) {
   130  
   131  	n, err := r.box.Read(p)
   132  
   133  	for i := 0; i < n; i++ {
   134  		if r.blockOffset == 64 {
   135  			salsa20Core(&r.block, &r.counter, &r.subKey, &salsa.Sigma)
   136  
   137  			u := uint32(1)
   138  			for i := 8; i < 16; i++ {
   139  				u += uint32(r.counter[i])
   140  				r.counter[i] = byte(u)
   141  				u >>= 8
   142  			}
   143  			r.blockOffset = 0
   144  		}
   145  		p[i] = p[i] ^ r.block[r.blockOffset]
   146  		r.blockOffset++
   147  	}
   148  
   149  	return n, err
   150  }
   151  
   152  func (r *salsa20ReadSeeker) Seek(offset int64, whence int) (int64, error) {
   153  
   154  	// Currently only supports Seek(0, io.SeekStart) as required for Psiphon.
   155  
   156  	if offset != 0 || whence != io.SeekStart {
   157  		return -1, fmt.Errorf("unsupported")
   158  	}
   159  
   160  	// TODO: could skip poly1305 verify after 1st reset.
   161  
   162  	err := r.reset()
   163  	if err != nil {
   164  		return -1, err
   165  	}
   166  
   167  	return 0, nil
   168  }
   169  
   170  // Verify from crypto/poly1305/poly1305.go, modifed to use an io.Reader.
   171  func poly1305VerifyReader(mac *[16]byte, m io.Reader, key *[32]byte) error {
   172  	var tmp [16]byte
   173  	err := poly1305SumReader(&tmp, m, key)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if subtle.ConstantTimeCompare(tmp[:], mac[:]) != 1 {
   178  		return fmt.Errorf("verify failed")
   179  	}
   180  	return nil
   181  }
   182  
   183  // Sum from crypto/poly1305/sum_ref.go, modifed to use an io.Reader.
   184  func poly1305SumReader(out *[poly1305.TagSize]byte, msg io.Reader, key *[32]byte) error {
   185  	var (
   186  		h0, h1, h2, h3, h4 uint32 // the hash accumulators
   187  		r0, r1, r2, r3, r4 uint64 // the r part of the key
   188  	)
   189  
   190  	r0 = uint64(binary.LittleEndian.Uint32(key[0:]) & 0x3ffffff)
   191  	r1 = uint64((binary.LittleEndian.Uint32(key[3:]) >> 2) & 0x3ffff03)
   192  	r2 = uint64((binary.LittleEndian.Uint32(key[6:]) >> 4) & 0x3ffc0ff)
   193  	r3 = uint64((binary.LittleEndian.Uint32(key[9:]) >> 6) & 0x3f03fff)
   194  	r4 = uint64((binary.LittleEndian.Uint32(key[12:]) >> 8) & 0x00fffff)
   195  
   196  	R1, R2, R3, R4 := r1*5, r2*5, r3*5, r4*5
   197  
   198  	var in [poly1305.TagSize]byte
   199  
   200  	for {
   201  		n, err := msg.Read(in[:])
   202  
   203  		if n == poly1305.TagSize {
   204  
   205  			// h += msg
   206  			h0 += binary.LittleEndian.Uint32(in[0:]) & 0x3ffffff
   207  			h1 += (binary.LittleEndian.Uint32(in[3:]) >> 2) & 0x3ffffff
   208  			h2 += (binary.LittleEndian.Uint32(in[6:]) >> 4) & 0x3ffffff
   209  			h3 += (binary.LittleEndian.Uint32(in[9:]) >> 6) & 0x3ffffff
   210  			h4 += (binary.LittleEndian.Uint32(in[12:]) >> 8) | (1 << 24)
   211  
   212  		} else if n > 0 {
   213  
   214  			in[n] = 0x01
   215  			for i := n + 1; i < poly1305.TagSize; i++ {
   216  				in[i] = 0
   217  			}
   218  
   219  			// h += msg
   220  			h0 += binary.LittleEndian.Uint32(in[0:]) & 0x3ffffff
   221  			h1 += (binary.LittleEndian.Uint32(in[3:]) >> 2) & 0x3ffffff
   222  			h2 += (binary.LittleEndian.Uint32(in[6:]) >> 4) & 0x3ffffff
   223  			h3 += (binary.LittleEndian.Uint32(in[9:]) >> 6) & 0x3ffffff
   224  			h4 += (binary.LittleEndian.Uint32(in[12:]) >> 8)
   225  		}
   226  
   227  		if n > 0 {
   228  
   229  			// h *= r
   230  			d0 := (uint64(h0) * r0) + (uint64(h1) * R4) + (uint64(h2) * R3) + (uint64(h3) * R2) + (uint64(h4) * R1)
   231  			d1 := (d0 >> 26) + (uint64(h0) * r1) + (uint64(h1) * r0) + (uint64(h2) * R4) + (uint64(h3) * R3) + (uint64(h4) * R2)
   232  			d2 := (d1 >> 26) + (uint64(h0) * r2) + (uint64(h1) * r1) + (uint64(h2) * r0) + (uint64(h3) * R4) + (uint64(h4) * R3)
   233  			d3 := (d2 >> 26) + (uint64(h0) * r3) + (uint64(h1) * r2) + (uint64(h2) * r1) + (uint64(h3) * r0) + (uint64(h4) * R4)
   234  			d4 := (d3 >> 26) + (uint64(h0) * r4) + (uint64(h1) * r3) + (uint64(h2) * r2) + (uint64(h3) * r1) + (uint64(h4) * r0)
   235  
   236  			// h %= p
   237  			h0 = uint32(d0) & 0x3ffffff
   238  			h1 = uint32(d1) & 0x3ffffff
   239  			h2 = uint32(d2) & 0x3ffffff
   240  			h3 = uint32(d3) & 0x3ffffff
   241  			h4 = uint32(d4) & 0x3ffffff
   242  
   243  			h0 += uint32(d4>>26) * 5
   244  			h1 += h0 >> 26
   245  			h0 = h0 & 0x3ffffff
   246  		}
   247  
   248  		if err == io.EOF {
   249  			break
   250  		}
   251  
   252  		if err != nil {
   253  			return err
   254  		}
   255  	}
   256  
   257  	// h %= p reduction
   258  	h2 += h1 >> 26
   259  	h1 &= 0x3ffffff
   260  	h3 += h2 >> 26
   261  	h2 &= 0x3ffffff
   262  	h4 += h3 >> 26
   263  	h3 &= 0x3ffffff
   264  	h0 += 5 * (h4 >> 26)
   265  	h4 &= 0x3ffffff
   266  	h1 += h0 >> 26
   267  	h0 &= 0x3ffffff
   268  
   269  	// h - p
   270  	t0 := h0 + 5
   271  	t1 := h1 + (t0 >> 26)
   272  	t2 := h2 + (t1 >> 26)
   273  	t3 := h3 + (t2 >> 26)
   274  	t4 := h4 + (t3 >> 26) - (1 << 26)
   275  	t0 &= 0x3ffffff
   276  	t1 &= 0x3ffffff
   277  	t2 &= 0x3ffffff
   278  	t3 &= 0x3ffffff
   279  
   280  	// select h if h < p else h - p
   281  	t_mask := (t4 >> 31) - 1
   282  	h_mask := ^t_mask
   283  	h0 = (h0 & h_mask) | (t0 & t_mask)
   284  	h1 = (h1 & h_mask) | (t1 & t_mask)
   285  	h2 = (h2 & h_mask) | (t2 & t_mask)
   286  	h3 = (h3 & h_mask) | (t3 & t_mask)
   287  	h4 = (h4 & h_mask) | (t4 & t_mask)
   288  
   289  	// h %= 2^128
   290  	h0 |= h1 << 26
   291  	h1 = ((h1 >> 6) | (h2 << 20))
   292  	h2 = ((h2 >> 12) | (h3 << 14))
   293  	h3 = ((h3 >> 18) | (h4 << 8))
   294  
   295  	// s: the s part of the key
   296  	// tag = (h + s) % (2^128)
   297  	t := uint64(h0) + uint64(binary.LittleEndian.Uint32(key[16:]))
   298  	h0 = uint32(t)
   299  	t = uint64(h1) + uint64(binary.LittleEndian.Uint32(key[20:])) + (t >> 32)
   300  	h1 = uint32(t)
   301  	t = uint64(h2) + uint64(binary.LittleEndian.Uint32(key[24:])) + (t >> 32)
   302  	h2 = uint32(t)
   303  	t = uint64(h3) + uint64(binary.LittleEndian.Uint32(key[28:])) + (t >> 32)
   304  	h3 = uint32(t)
   305  
   306  	binary.LittleEndian.PutUint32(out[0:], h0)
   307  	binary.LittleEndian.PutUint32(out[4:], h1)
   308  	binary.LittleEndian.PutUint32(out[8:], h2)
   309  	binary.LittleEndian.PutUint32(out[12:], h3)
   310  
   311  	return nil
   312  }
   313  
   314  // core from x/crypto/salsa20/salsa/salsa20_ref.go.
   315  func salsa20Core(out *[64]byte, in *[16]byte, k *[32]byte, c *[16]byte) {
   316  	j0 := uint32(c[0]) | uint32(c[1])<<8 | uint32(c[2])<<16 | uint32(c[3])<<24
   317  	j1 := uint32(k[0]) | uint32(k[1])<<8 | uint32(k[2])<<16 | uint32(k[3])<<24
   318  	j2 := uint32(k[4]) | uint32(k[5])<<8 | uint32(k[6])<<16 | uint32(k[7])<<24
   319  	j3 := uint32(k[8]) | uint32(k[9])<<8 | uint32(k[10])<<16 | uint32(k[11])<<24
   320  	j4 := uint32(k[12]) | uint32(k[13])<<8 | uint32(k[14])<<16 | uint32(k[15])<<24
   321  	j5 := uint32(c[4]) | uint32(c[5])<<8 | uint32(c[6])<<16 | uint32(c[7])<<24
   322  	j6 := uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
   323  	j7 := uint32(in[4]) | uint32(in[5])<<8 | uint32(in[6])<<16 | uint32(in[7])<<24
   324  	j8 := uint32(in[8]) | uint32(in[9])<<8 | uint32(in[10])<<16 | uint32(in[11])<<24
   325  	j9 := uint32(in[12]) | uint32(in[13])<<8 | uint32(in[14])<<16 | uint32(in[15])<<24
   326  	j10 := uint32(c[8]) | uint32(c[9])<<8 | uint32(c[10])<<16 | uint32(c[11])<<24
   327  	j11 := uint32(k[16]) | uint32(k[17])<<8 | uint32(k[18])<<16 | uint32(k[19])<<24
   328  	j12 := uint32(k[20]) | uint32(k[21])<<8 | uint32(k[22])<<16 | uint32(k[23])<<24
   329  	j13 := uint32(k[24]) | uint32(k[25])<<8 | uint32(k[26])<<16 | uint32(k[27])<<24
   330  	j14 := uint32(k[28]) | uint32(k[29])<<8 | uint32(k[30])<<16 | uint32(k[31])<<24
   331  	j15 := uint32(c[12]) | uint32(c[13])<<8 | uint32(c[14])<<16 | uint32(c[15])<<24
   332  
   333  	x0, x1, x2, x3, x4, x5, x6, x7, x8 := j0, j1, j2, j3, j4, j5, j6, j7, j8
   334  	x9, x10, x11, x12, x13, x14, x15 := j9, j10, j11, j12, j13, j14, j15
   335  
   336  	const rounds = 20
   337  
   338  	for i := 0; i < rounds; i += 2 {
   339  		u := x0 + x12
   340  		x4 ^= u<<7 | u>>(32-7)
   341  		u = x4 + x0
   342  		x8 ^= u<<9 | u>>(32-9)
   343  		u = x8 + x4
   344  		x12 ^= u<<13 | u>>(32-13)
   345  		u = x12 + x8
   346  		x0 ^= u<<18 | u>>(32-18)
   347  
   348  		u = x5 + x1
   349  		x9 ^= u<<7 | u>>(32-7)
   350  		u = x9 + x5
   351  		x13 ^= u<<9 | u>>(32-9)
   352  		u = x13 + x9
   353  		x1 ^= u<<13 | u>>(32-13)
   354  		u = x1 + x13
   355  		x5 ^= u<<18 | u>>(32-18)
   356  
   357  		u = x10 + x6
   358  		x14 ^= u<<7 | u>>(32-7)
   359  		u = x14 + x10
   360  		x2 ^= u<<9 | u>>(32-9)
   361  		u = x2 + x14
   362  		x6 ^= u<<13 | u>>(32-13)
   363  		u = x6 + x2
   364  		x10 ^= u<<18 | u>>(32-18)
   365  
   366  		u = x15 + x11
   367  		x3 ^= u<<7 | u>>(32-7)
   368  		u = x3 + x15
   369  		x7 ^= u<<9 | u>>(32-9)
   370  		u = x7 + x3
   371  		x11 ^= u<<13 | u>>(32-13)
   372  		u = x11 + x7
   373  		x15 ^= u<<18 | u>>(32-18)
   374  
   375  		u = x0 + x3
   376  		x1 ^= u<<7 | u>>(32-7)
   377  		u = x1 + x0
   378  		x2 ^= u<<9 | u>>(32-9)
   379  		u = x2 + x1
   380  		x3 ^= u<<13 | u>>(32-13)
   381  		u = x3 + x2
   382  		x0 ^= u<<18 | u>>(32-18)
   383  
   384  		u = x5 + x4
   385  		x6 ^= u<<7 | u>>(32-7)
   386  		u = x6 + x5
   387  		x7 ^= u<<9 | u>>(32-9)
   388  		u = x7 + x6
   389  		x4 ^= u<<13 | u>>(32-13)
   390  		u = x4 + x7
   391  		x5 ^= u<<18 | u>>(32-18)
   392  
   393  		u = x10 + x9
   394  		x11 ^= u<<7 | u>>(32-7)
   395  		u = x11 + x10
   396  		x8 ^= u<<9 | u>>(32-9)
   397  		u = x8 + x11
   398  		x9 ^= u<<13 | u>>(32-13)
   399  		u = x9 + x8
   400  		x10 ^= u<<18 | u>>(32-18)
   401  
   402  		u = x15 + x14
   403  		x12 ^= u<<7 | u>>(32-7)
   404  		u = x12 + x15
   405  		x13 ^= u<<9 | u>>(32-9)
   406  		u = x13 + x12
   407  		x14 ^= u<<13 | u>>(32-13)
   408  		u = x14 + x13
   409  		x15 ^= u<<18 | u>>(32-18)
   410  	}
   411  	x0 += j0
   412  	x1 += j1
   413  	x2 += j2
   414  	x3 += j3
   415  	x4 += j4
   416  	x5 += j5
   417  	x6 += j6
   418  	x7 += j7
   419  	x8 += j8
   420  	x9 += j9
   421  	x10 += j10
   422  	x11 += j11
   423  	x12 += j12
   424  	x13 += j13
   425  	x14 += j14
   426  	x15 += j15
   427  
   428  	out[0] = byte(x0)
   429  	out[1] = byte(x0 >> 8)
   430  	out[2] = byte(x0 >> 16)
   431  	out[3] = byte(x0 >> 24)
   432  
   433  	out[4] = byte(x1)
   434  	out[5] = byte(x1 >> 8)
   435  	out[6] = byte(x1 >> 16)
   436  	out[7] = byte(x1 >> 24)
   437  
   438  	out[8] = byte(x2)
   439  	out[9] = byte(x2 >> 8)
   440  	out[10] = byte(x2 >> 16)
   441  	out[11] = byte(x2 >> 24)
   442  
   443  	out[12] = byte(x3)
   444  	out[13] = byte(x3 >> 8)
   445  	out[14] = byte(x3 >> 16)
   446  	out[15] = byte(x3 >> 24)
   447  
   448  	out[16] = byte(x4)
   449  	out[17] = byte(x4 >> 8)
   450  	out[18] = byte(x4 >> 16)
   451  	out[19] = byte(x4 >> 24)
   452  
   453  	out[20] = byte(x5)
   454  	out[21] = byte(x5 >> 8)
   455  	out[22] = byte(x5 >> 16)
   456  	out[23] = byte(x5 >> 24)
   457  
   458  	out[24] = byte(x6)
   459  	out[25] = byte(x6 >> 8)
   460  	out[26] = byte(x6 >> 16)
   461  	out[27] = byte(x6 >> 24)
   462  
   463  	out[28] = byte(x7)
   464  	out[29] = byte(x7 >> 8)
   465  	out[30] = byte(x7 >> 16)
   466  	out[31] = byte(x7 >> 24)
   467  
   468  	out[32] = byte(x8)
   469  	out[33] = byte(x8 >> 8)
   470  	out[34] = byte(x8 >> 16)
   471  	out[35] = byte(x8 >> 24)
   472  
   473  	out[36] = byte(x9)
   474  	out[37] = byte(x9 >> 8)
   475  	out[38] = byte(x9 >> 16)
   476  	out[39] = byte(x9 >> 24)
   477  
   478  	out[40] = byte(x10)
   479  	out[41] = byte(x10 >> 8)
   480  	out[42] = byte(x10 >> 16)
   481  	out[43] = byte(x10 >> 24)
   482  
   483  	out[44] = byte(x11)
   484  	out[45] = byte(x11 >> 8)
   485  	out[46] = byte(x11 >> 16)
   486  	out[47] = byte(x11 >> 24)
   487  
   488  	out[48] = byte(x12)
   489  	out[49] = byte(x12 >> 8)
   490  	out[50] = byte(x12 >> 16)
   491  	out[51] = byte(x12 >> 24)
   492  
   493  	out[52] = byte(x13)
   494  	out[53] = byte(x13 >> 8)
   495  	out[54] = byte(x13 >> 16)
   496  	out[55] = byte(x13 >> 24)
   497  
   498  	out[56] = byte(x14)
   499  	out[57] = byte(x14 >> 8)
   500  	out[58] = byte(x14 >> 16)
   501  	out[59] = byte(x14 >> 24)
   502  
   503  	out[60] = byte(x15)
   504  	out[61] = byte(x15 >> 8)
   505  	out[62] = byte(x15 >> 16)
   506  	out[63] = byte(x15 >> 24)
   507  }