github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/roughtime/agl_roughtime/protocol/protocol_test.go (about)

     1  // Copyright 2016 The Roughtime Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License. */
    14  
    15  package protocol
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"testing"
    22  	"testing/quick"
    23  
    24  	"golang.org/x/crypto/ed25519"
    25  )
    26  
    27  func testEncodeDecodeRoundtrip(msg map[uint32][]byte) bool {
    28  	encoded, err := Encode(msg)
    29  	if err != nil {
    30  		return true
    31  	}
    32  
    33  	decoded, err := Decode(encoded)
    34  	if err != nil {
    35  		return false
    36  	}
    37  
    38  	if len(msg) != len(decoded) {
    39  		return false
    40  	}
    41  
    42  	for tag, payload := range msg {
    43  		otherPayload, ok := decoded[tag]
    44  		if !ok {
    45  			return false
    46  		}
    47  		if !bytes.Equal(payload, otherPayload) {
    48  			return false
    49  		}
    50  	}
    51  
    52  	return true
    53  }
    54  
    55  func TestEncodeDecode(t *testing.T) {
    56  	quick.Check(testEncodeDecodeRoundtrip, &quick.Config{
    57  		MaxCountScale: 10,
    58  	})
    59  }
    60  
    61  func TestRequestSize(t *testing.T) {
    62  	_, _, request, err := CreateRequest(rand.Reader, nil)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	if len(request) != MinRequestSize {
    67  		t.Errorf("got %d byte request, want %d bytes", len(request), MinRequestSize)
    68  	}
    69  }
    70  
    71  func createServerIdentity(t *testing.T) (cert, rootPublicKey, onlinePrivateKey []byte) {
    72  	rootPublicKey, rootPrivateKey, err := ed25519.GenerateKey(rand.Reader)
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  
    77  	onlinePublicKey, onlinePrivateKey, err := ed25519.GenerateKey(rand.Reader)
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  
    82  	if cert, err = CreateCertificate(0, 100, onlinePublicKey, rootPrivateKey); err != nil {
    83  		t.Fatal(err)
    84  	}
    85  
    86  	return cert, rootPublicKey, onlinePrivateKey
    87  }
    88  
    89  func TestRoundtrip(t *testing.T) {
    90  	cert, rootPublicKey, onlinePrivateKey := createServerIdentity(t)
    91  
    92  	for _, numRequests := range []int{1, 2, 3, 4, 5, 15, 16, 17} {
    93  		nonces := make([][NonceSize]byte, numRequests)
    94  		for i := range nonces {
    95  			binary.LittleEndian.PutUint32(nonces[i][:], uint32(i))
    96  		}
    97  
    98  		noncesSlice := make([][]byte, 0, numRequests)
    99  		for i := range nonces {
   100  			noncesSlice = append(noncesSlice, nonces[i][:])
   101  		}
   102  
   103  		const (
   104  			expectedMidpoint = 50
   105  			expectedRadius   = 5
   106  		)
   107  
   108  		replies, err := CreateReplies(noncesSlice, expectedMidpoint, expectedRadius, cert, onlinePrivateKey)
   109  		if err != nil {
   110  			t.Fatal(err)
   111  		}
   112  
   113  		for i, reply := range replies {
   114  			midpoint, radius, err := VerifyReply(reply, rootPublicKey, nonces[i])
   115  			if err != nil {
   116  				t.Errorf("error parsing reply #%d: %s", i, err)
   117  				continue
   118  			}
   119  
   120  			if midpoint != expectedMidpoint {
   121  				t.Errorf("reply #%d gave a midpoint of %d, want %d", i, midpoint, expectedMidpoint)
   122  			}
   123  			if radius != expectedRadius {
   124  				t.Errorf("reply #%d gave a radius of %d, want %d", i, radius, expectedRadius)
   125  			}
   126  		}
   127  	}
   128  }
   129  
   130  func TestChaining(t *testing.T) {
   131  	// This test demonstrates how a claim of misbehaviour from a client
   132  	// would be checked. The client creates a two element chain in this
   133  	// example where the first server says that the time is 10 and the
   134  	// second says that it's 5.
   135  	certA, rootPublicKeyA, onlinePrivateKeyA := createServerIdentity(t)
   136  	certB, rootPublicKeyB, onlinePrivateKeyB := createServerIdentity(t)
   137  
   138  	nonce1, _, _, err := CreateRequest(rand.Reader, nil)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  
   143  	replies1, err := CreateReplies([][]byte{nonce1[:]}, 10, 0, certA, onlinePrivateKeyA)
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  
   148  	nonce2, blind2, _, err := CreateRequest(rand.Reader, replies1[0])
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  
   153  	replies2, err := CreateReplies([][]byte{nonce2[:]}, 5, 0, certB, onlinePrivateKeyB)
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  
   158  	// The client would present a series of tuples of (server identity,
   159  	// nonce/blind, reply) as its claim of misbehaviour. The first element
   160  	// contains a nonce where as all other elements contain just the
   161  	// blinding value, as the nonce used for that request is calculated
   162  	// from that and the previous reply.
   163  	type claimStep struct {
   164  		serverPublicKey []byte
   165  		nonceOrBlind    [NonceSize]byte
   166  		reply           []byte
   167  	}
   168  
   169  	claim := []claimStep{
   170  		claimStep{rootPublicKeyA, nonce1, replies1[0]},
   171  		claimStep{rootPublicKeyB, blind2, replies2[0]},
   172  	}
   173  
   174  	// In order to verify a claim, one would check each of the replies
   175  	// based on the calculated nonce.
   176  	var lastMidpoint uint64
   177  	var misbehaviourFound bool
   178  	for i, step := range claim {
   179  		var nonce [NonceSize]byte
   180  		if i == 0 {
   181  			copy(nonce[:], step.nonceOrBlind[:])
   182  		} else {
   183  			nonce = CalculateChainNonce(claim[i-1].reply, step.nonceOrBlind[:])
   184  		}
   185  		midpoint, _, err := VerifyReply(step.reply, step.serverPublicKey, nonce)
   186  		if err != nil {
   187  			t.Fatal(err)
   188  		}
   189  
   190  		// This example doesn't take the radius into account.
   191  		if i > 0 && midpoint < lastMidpoint {
   192  			misbehaviourFound = true
   193  		}
   194  		lastMidpoint = midpoint
   195  	}
   196  
   197  	if !misbehaviourFound {
   198  		t.Error("did not find expected misbehaviour")
   199  	}
   200  }