roughtime.googlesource.com/roughtime.git@v0.0.0-20201210012726-dd529367052d/go/protocol/protocol_test.go (about)

     1  // Copyright 2016 The Roughtime Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License. */
    14  
    15  package protocol
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"testing"
    22  	"testing/quick"
    23  
    24  	"golang.org/x/crypto/ed25519"
    25  )
    26  
    27  func testEncodeDecodeRoundtrip(msg map[uint32][]byte) bool {
    28  	encoded, err := Encode(msg)
    29  	if err != nil {
    30  		return true
    31  	}
    32  
    33  	decoded, err := Decode(encoded)
    34  	if err != nil {
    35  		return false
    36  	}
    37  
    38  	if len(msg) != len(decoded) {
    39  		return false
    40  	}
    41  
    42  	for tag, payload := range msg {
    43  		otherPayload, ok := decoded[tag]
    44  		if !ok {
    45  			return false
    46  		}
    47  		if !bytes.Equal(payload, otherPayload) {
    48  			return false
    49  		}
    50  	}
    51  
    52  	return true
    53  }
    54  
    55  func TestEncodeDecode(t *testing.T) {
    56  	quick.Check(testEncodeDecodeRoundtrip, &quick.Config{
    57  		MaxCountScale: 10,
    58  	})
    59  }
    60  
    61  func TestRequestSize(t *testing.T) {
    62  	_, _, request, err := CreateRequest(rand.Reader, nil)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	if len(request) != MinRequestSize {
    67  		t.Errorf("got %d byte request, want %d bytes", len(request), MinRequestSize)
    68  	}
    69  }
    70  
    71  func createServerIdentity(t *testing.T) (cert, rootPublicKey, onlinePrivateKey []byte) {
    72  	rootPublicKey, rootPrivateKey, err := ed25519.GenerateKey(rand.Reader)
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  
    77  	onlinePublicKey, onlinePrivateKey, err := ed25519.GenerateKey(rand.Reader)
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  
    82  	if cert, err = CreateCertificate(0, 100, onlinePublicKey, rootPrivateKey); err != nil {
    83  		t.Fatal(err)
    84  	}
    85  
    86  	return cert, rootPublicKey, onlinePrivateKey
    87  }
    88  
    89  func TestRoundtrip(t *testing.T) {
    90  	cert, rootPublicKey, onlinePrivateKey := createServerIdentity(t)
    91  
    92  	for _, numRequests := range []int{1, 2, 3, 4, 5, 15, 16, 17} {
    93  		nonces := make([][NonceSize]byte, numRequests)
    94  		for i := range nonces {
    95  			binary.LittleEndian.PutUint32(nonces[i][:], uint32(i))
    96  		}
    97  
    98  		noncesSlice := make([][]byte, 0, numRequests)
    99  		for i := range nonces {
   100  			noncesSlice = append(noncesSlice, nonces[i][:])
   101  		}
   102  
   103  		const (
   104  			expectedMidpoint = 50
   105  			expectedRadius   = 5
   106  		)
   107  
   108  		replies, err := CreateReplies(noncesSlice, expectedMidpoint, expectedRadius, cert, onlinePrivateKey)
   109  		if err != nil {
   110  			t.Fatal(err)
   111  		}
   112  
   113  		if len(replies) != len(nonces) {
   114  			t.Fatalf("received %d replies for %d nonces", len(replies), len(nonces))
   115  		}
   116  
   117  		for i, reply := range replies {
   118  			midpoint, radius, err := VerifyReply(reply, rootPublicKey, nonces[i])
   119  			if err != nil {
   120  				t.Errorf("error parsing reply #%d: %s", i, err)
   121  				continue
   122  			}
   123  
   124  			if midpoint != expectedMidpoint {
   125  				t.Errorf("reply #%d gave a midpoint of %d, want %d", i, midpoint, expectedMidpoint)
   126  			}
   127  			if radius != expectedRadius {
   128  				t.Errorf("reply #%d gave a radius of %d, want %d", i, radius, expectedRadius)
   129  			}
   130  		}
   131  	}
   132  }
   133  
   134  func TestChaining(t *testing.T) {
   135  	// This test demonstrates how a claim of misbehaviour from a client
   136  	// would be checked. The client creates a two element chain in this
   137  	// example where the first server says that the time is 10 and the
   138  	// second says that it's 5.
   139  	certA, rootPublicKeyA, onlinePrivateKeyA := createServerIdentity(t)
   140  	certB, rootPublicKeyB, onlinePrivateKeyB := createServerIdentity(t)
   141  
   142  	nonce1, _, _, err := CreateRequest(rand.Reader, nil)
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  
   147  	replies1, err := CreateReplies([][]byte{nonce1[:]}, 10, 0, certA, onlinePrivateKeyA)
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  
   152  	nonce2, blind2, _, err := CreateRequest(rand.Reader, replies1[0])
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  
   157  	replies2, err := CreateReplies([][]byte{nonce2[:]}, 5, 0, certB, onlinePrivateKeyB)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	// The client would present a series of tuples of (server identity,
   163  	// nonce/blind, reply) as its claim of misbehaviour. The first element
   164  	// contains a nonce where as all other elements contain just the
   165  	// blinding value, as the nonce used for that request is calculated
   166  	// from that and the previous reply.
   167  	type claimStep struct {
   168  		serverPublicKey []byte
   169  		nonceOrBlind    [NonceSize]byte
   170  		reply           []byte
   171  	}
   172  
   173  	claim := []claimStep{
   174  		claimStep{rootPublicKeyA, nonce1, replies1[0]},
   175  		claimStep{rootPublicKeyB, blind2, replies2[0]},
   176  	}
   177  
   178  	// In order to verify a claim, one would check each of the replies
   179  	// based on the calculated nonce.
   180  	var lastMidpoint uint64
   181  	var misbehaviourFound bool
   182  	for i, step := range claim {
   183  		var nonce [NonceSize]byte
   184  		if i == 0 {
   185  			copy(nonce[:], step.nonceOrBlind[:])
   186  		} else {
   187  			nonce = CalculateChainNonce(claim[i-1].reply, step.nonceOrBlind[:])
   188  		}
   189  		midpoint, _, err := VerifyReply(step.reply, step.serverPublicKey, nonce)
   190  		if err != nil {
   191  			t.Fatal(err)
   192  		}
   193  
   194  		// This example doesn't take the radius into account.
   195  		if i > 0 && midpoint < lastMidpoint {
   196  			misbehaviourFound = true
   197  		}
   198  		lastMidpoint = midpoint
   199  	}
   200  
   201  	if !misbehaviourFound {
   202  		t.Error("did not find expected misbehaviour")
   203  	}
   204  }