github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/roughtime/client.go (about)

     1  // Copyright (c) 2016, Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package roughtime contains a version of the Roughtime client adapted to use Tao.
    16  package roughtime
    17  
    18  import (
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"os"
    27  	"strings"
    28  	"time"
    29  
    30  	"github.com/jlmucb/cloudproxy/go/apps/roughtime/agl_roughtime/config"
    31  	"github.com/jlmucb/cloudproxy/go/apps/roughtime/agl_roughtime/monotime"
    32  	"github.com/jlmucb/cloudproxy/go/apps/roughtime/agl_roughtime/protocol"
    33  	"github.com/jlmucb/cloudproxy/go/tao"
    34  
    35  	"math/big"
    36  	mathrand "math/rand"
    37  
    38  	"golang.org/x/crypto/ed25519"
    39  )
    40  
    41  const (
    42  	// defaultMaxRadius is the maximum radius that we'll accept from a
    43  	// server.
    44  	defaultMaxRadius = 10 * time.Second
    45  
    46  	// defaultMaxDifference is the maximum difference in time between any
    47  	// sample from a server and the quorum-agreed time before we believe
    48  	// that the server might be misbehaving.
    49  	defaultMaxDifference = 60 * time.Second
    50  
    51  	// defaultTimeout is the default maximum time that a server has to
    52  	// answer a query.
    53  	defaultTimeout = 2 * time.Second
    54  
    55  	// defaultNumQueries is the default number of times we will try to
    56  	// query a server.
    57  	defaultNumQueries = 3
    58  )
    59  
    60  // Client represents a Roughtime client and exposes a number of members that
    61  // can be set in order to configure it. The zero value of a Client is always
    62  // ready to use and will set sensible defaults.
    63  type Client struct {
    64  	// Permutation returns a random permutation of [0‥n) that is used to
    65  	// query servers in a random order. If nil, a sensible default is used.
    66  	Permutation func(n int) []int
    67  
    68  	// MaxRadiusUs is the maximum interval radius that will be accepted
    69  	// from a server. If zero, a sensible default is used.
    70  	MaxRadius time.Duration
    71  
    72  	// MaxDifference is the maximum difference in time between any sample
    73  	// from a server and the quorum-agreed time before that sample is
    74  	// considered suspect. If zero, a sensible default is used.
    75  	MaxDifference time.Duration
    76  
    77  	// QueryTimeout is the amount of time a server has to reply to a query.
    78  	// If zero, a sensible default will be used.
    79  	QueryTimeout time.Duration
    80  
    81  	// NumQueries is the maximum number of times a query will be sent to a
    82  	// specific server before giving up. If <= zero, a sensible default
    83  	// will be used.
    84  	NumQueries int
    85  
    86  	// now returns a monotonic duration from some unspecified epoch. If
    87  	// nil, the system monotonic time will be used.
    88  	nowFunc func() time.Duration
    89  
    90  	network string
    91  	domain  *tao.Domain
    92  
    93  	//Number of servers to query
    94  	quorum  int
    95  	servers []config.Server
    96  }
    97  
    98  func NewClient(domainPath, network string, quorum int, servers []config.Server) (*Client, error) {
    99  	// Load domain from a local configuration.
   100  	domain, err := tao.LoadDomain(domainPath, nil)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	c := &Client{
   106  		network: network,
   107  		domain:  domain,
   108  		quorum:  quorum,
   109  		servers: servers,
   110  	}
   111  	return c, nil
   112  }
   113  
   114  func (c *Client) now() time.Duration {
   115  	if c.nowFunc != nil {
   116  		return c.nowFunc()
   117  	}
   118  	return monotime.Now()
   119  }
   120  
   121  func (c *Client) permutation(n int) []int {
   122  	if c.Permutation != nil {
   123  		return c.Permutation(n)
   124  	}
   125  
   126  	var randBuf [8]byte
   127  	if _, err := io.ReadFull(rand.Reader, randBuf[:]); err != nil {
   128  		panic(err)
   129  	}
   130  
   131  	seed := binary.LittleEndian.Uint64(randBuf[:])
   132  	rand := mathrand.New(mathrand.NewSource(int64(seed)))
   133  
   134  	return rand.Perm(n)
   135  }
   136  
   137  func (c *Client) maxRadius() time.Duration {
   138  	if c.MaxRadius != 0 {
   139  		return c.MaxRadius
   140  	}
   141  
   142  	return defaultMaxRadius
   143  }
   144  
   145  func (c *Client) maxDifference() time.Duration {
   146  	if c.MaxDifference != 0 {
   147  		return c.MaxDifference
   148  	}
   149  
   150  	return defaultMaxDifference
   151  }
   152  
   153  func (c *Client) queryTimeout() time.Duration {
   154  	if c.QueryTimeout != 0 {
   155  		return c.QueryTimeout
   156  	}
   157  
   158  	return defaultTimeout
   159  }
   160  
   161  func (c *Client) numQueries() int {
   162  	if c.NumQueries > 0 {
   163  		return c.NumQueries
   164  	}
   165  
   166  	return defaultNumQueries
   167  }
   168  
   169  // LoadChain loads a JSON-format chain from the given JSON data.
   170  func LoadChain(jsonData []byte) (chain *config.Chain, err error) {
   171  	chain = new(config.Chain)
   172  	if err := json.Unmarshal(jsonData, chain); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	for i, link := range chain.Links {
   177  		if link.PublicKeyType != "ed25519" {
   178  			return nil, fmt.Errorf("client: link #%d in chain file has unknown public key type %q", i, link.PublicKeyType)
   179  		}
   180  
   181  		if l := len(link.PublicKey); l != ed25519.PublicKeySize {
   182  			return nil, fmt.Errorf("client: link #%d in chain file has bad public key of length %d", i, l)
   183  		}
   184  
   185  		if l := len(link.NonceOrBlind); l != protocol.NonceSize {
   186  			return nil, fmt.Errorf("client: link #%d in chain file has bad nonce/blind of length %d", i, l)
   187  		}
   188  
   189  		var nonce [protocol.NonceSize]byte
   190  		if i == 0 {
   191  			copy(nonce[:], link.NonceOrBlind[:])
   192  		} else {
   193  			nonce = protocol.CalculateChainNonce(chain.Links[i-1].Reply, link.NonceOrBlind[:])
   194  		}
   195  
   196  		if _, _, err := protocol.VerifyReply(link.Reply, link.PublicKey, nonce); err != nil {
   197  			return nil, fmt.Errorf("client: failed to verify link #%d in chain file", i)
   198  		}
   199  	}
   200  
   201  	return chain, nil
   202  }
   203  
   204  // timeSample represents a time sample from the network.
   205  type timeSample struct {
   206  	// server references the server that was queried.
   207  	server *config.Server
   208  
   209  	// base is a monotonic clock sample that is taken at a time before the
   210  	// network could have answered the query.
   211  	base *big.Int
   212  
   213  	// min is the minimum real-time (in Roughtime UTC microseconds) that
   214  	// could correspond to |base| (i.e. midpoint - radius).
   215  	min *big.Int
   216  
   217  	// max is the maximum real-time (in Roughtime UTC microseconds) that
   218  	// could correspond to |base| (i.e. midpoint + radius + query time).
   219  	max *big.Int
   220  
   221  	// queryDuration contains the amount of time that the server took to
   222  	// answer the query.
   223  	queryDuration time.Duration
   224  }
   225  
   226  // midpoint returns the average of the min and max times.
   227  func (s *timeSample) midpoint() *big.Int {
   228  	ret := new(big.Int).Add(s.min, s.max)
   229  	return ret.Rsh(ret, 1)
   230  }
   231  
   232  // alignTo updates s so that its base value matches that from reference.
   233  func (s *timeSample) alignTo(reference *timeSample) {
   234  	delta := new(big.Int).Sub(s.base, reference.base)
   235  	delta.Div(delta, big.NewInt(int64(time.Microsecond)))
   236  	s.base.Sub(s.base, delta)
   237  	s.min.Sub(s.min, delta)
   238  	s.max.Sub(s.max, delta)
   239  }
   240  
   241  // contains returns true iff p belongs to s
   242  func (s *timeSample) contains(p *big.Int) bool {
   243  	return s.max.Cmp(p) >= 0 && s.min.Cmp(p) <= 0
   244  }
   245  
   246  // overlaps returns true iff s and other have any timespan in common.
   247  func (s *timeSample) overlaps(other *timeSample) bool {
   248  	return s.max.Cmp(other.min) >= 0 && other.max.Cmp(s.min) >= 0
   249  }
   250  
   251  func serverNetworkAddr(network string, server *config.Server) string {
   252  	for _, addr := range server.Addresses {
   253  		if addr.Protocol == network {
   254  			return addr.Address
   255  		}
   256  	}
   257  	return ""
   258  }
   259  
   260  // query sends a request to s, appends it to chain, and returns the resulting
   261  // timeSample.
   262  func (c *Client) query(server *config.Server, chain *config.Chain) (*timeSample, error) {
   263  	var prevReply []byte
   264  	if len(chain.Links) > 0 {
   265  		prevReply = chain.Links[len(chain.Links)-1].Reply
   266  	}
   267  
   268  	var baseTime, replyTime time.Duration
   269  	var reply []byte
   270  	var nonce, blind [protocol.NonceSize]byte
   271  
   272  	for attempts := 0; attempts < c.numQueries(); attempts++ {
   273  		var request []byte
   274  		var err error
   275  		if nonce, blind, request, err = protocol.CreateRequest(rand.Reader, prevReply); err != nil {
   276  			return nil, err
   277  		}
   278  		if len(request) < protocol.MinRequestSize {
   279  			panic("internal error: bad request length")
   280  		}
   281  
   282  		serverAddr := serverNetworkAddr(c.network, server)
   283  		conn, err := tao.Dial(c.network, serverAddr, c.domain.Guard,
   284  			c.domain.Keys.VerifyingKey, nil)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  
   289  		conn.SetReadDeadline(time.Now().Add(c.queryTimeout()))
   290  		baseTime = c.now()
   291  		conn.Write(request)
   292  
   293  		var replyBytes [1024]byte
   294  		n, err := conn.Read(replyBytes[:])
   295  		if err == nil {
   296  			replyTime = c.now()
   297  			reply = replyBytes[:n]
   298  			break
   299  		}
   300  
   301  		if netErr, ok := err.(net.Error); ok {
   302  			if !netErr.Timeout() {
   303  				return nil, errors.New("client: error reading from UDP socket: " + err.Error())
   304  			}
   305  		}
   306  	}
   307  
   308  	if reply == nil {
   309  		return nil, fmt.Errorf("client: no reply from server %q", server.Name)
   310  	}
   311  
   312  	if replyTime < baseTime {
   313  		panic("broken monotonic clock")
   314  	}
   315  	queryDuration := replyTime - baseTime
   316  
   317  	midpoint, radius, err := protocol.VerifyReply(reply, server.PublicKey, nonce)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	if time.Duration(radius)*time.Microsecond > c.maxRadius() {
   323  		return nil, fmt.Errorf("client: radius (%d) too large", radius)
   324  	}
   325  
   326  	nonceOrBlind := blind[:]
   327  	if len(prevReply) == 0 {
   328  		nonceOrBlind = nonce[:]
   329  	}
   330  
   331  	chain.Links = append(chain.Links, config.Link{
   332  		PublicKeyType: "ed25519",
   333  		PublicKey:     server.PublicKey,
   334  		NonceOrBlind:  nonceOrBlind,
   335  		Reply:         reply,
   336  	})
   337  
   338  	queryDurationBig := new(big.Int).SetInt64(int64(queryDuration / time.Microsecond))
   339  	bigRadius := new(big.Int).SetUint64(uint64(radius))
   340  	min := new(big.Int).SetUint64(midpoint)
   341  	min.Sub(min, bigRadius)
   342  	min.Sub(min, queryDurationBig)
   343  
   344  	max := new(big.Int).SetUint64(midpoint)
   345  	max.Add(max, bigRadius)
   346  
   347  	return &timeSample{
   348  		server:        server,
   349  		base:          new(big.Int).SetInt64(int64(baseTime)),
   350  		min:           min,
   351  		max:           max,
   352  		queryDuration: queryDuration,
   353  	}, nil
   354  }
   355  
   356  // LoadServers loads information about known servers from the given JSON data.
   357  // It only extracts information about servers with Ed25519 public keys and UDP
   358  // address. The number of servers skipped because of unsupported requirements
   359  // is returned in numSkipped.
   360  func LoadServers(jsonData []byte) (servers []config.Server, numSkipped int, err error) {
   361  	var serversJSON config.ServersJSON
   362  	if err := json.Unmarshal(jsonData, &serversJSON); err != nil {
   363  		return nil, 0, err
   364  	}
   365  
   366  	seenNames := make(map[string]struct{})
   367  
   368  	for _, candidate := range serversJSON.Servers {
   369  		if _, ok := seenNames[candidate.Name]; ok {
   370  			return nil, 0, fmt.Errorf("client: duplicate server name %q", candidate.Name)
   371  		}
   372  		seenNames[candidate.Name] = struct{}{}
   373  
   374  		if candidate.PublicKeyType != "ed25519" {
   375  			numSkipped++
   376  			continue
   377  		}
   378  
   379  		servers = append(servers, candidate)
   380  	}
   381  
   382  	if len(servers) == 0 {
   383  		return nil, 0, errors.New("client: no usable servers found")
   384  	}
   385  
   386  	return servers, 0, nil
   387  }
   388  
   389  // trimChain drops elements from the beginning of chain, as needed, so that its
   390  // length is <= n.
   391  func trimChain(chain *config.Chain, n int) {
   392  	if n <= 0 {
   393  		chain.Links = nil
   394  		return
   395  	}
   396  
   397  	if len(chain.Links) <= n {
   398  		return
   399  	}
   400  
   401  	numToTrim := len(chain.Links) - n
   402  	for i := 0; i < numToTrim; i++ {
   403  		// The NonceOrBlind of the first element is special because
   404  		// it's an nonce. All the others are blinds and are combined
   405  		// with the previous reply to make the nonce. That's not
   406  		// possible for the first element because there is no previous
   407  		// reply. Therefore, when removing the first element the blind
   408  		// of the next element needs to be converted to an nonce.
   409  		nonce := protocol.CalculateChainNonce(chain.Links[0].Reply, chain.Links[1].NonceOrBlind[:])
   410  		chain.Links[1].NonceOrBlind = nonce[:]
   411  		chain.Links = chain.Links[1:]
   412  	}
   413  }
   414  
   415  // intersection returns the timespan common to all the elements in samples,
   416  // which must be aligned to the same base. The caller must ensure that such a
   417  // timespan exists.
   418  func intersection(samples []*timeSample) *timeSample {
   419  	ret := &timeSample{
   420  		base: samples[0].base,
   421  		min:  new(big.Int).Set(samples[0].min),
   422  		max:  new(big.Int).Set(samples[0].max),
   423  	}
   424  
   425  	for _, sample := range samples[1:] {
   426  		if ret.min.Cmp(sample.min) < 0 {
   427  			ret.min.Set(sample.min)
   428  		}
   429  		if ret.max.Cmp(sample.max) > 0 {
   430  			ret.max.Set(sample.max)
   431  		}
   432  	}
   433  
   434  	return ret
   435  }
   436  
   437  // findNOverlapping finds an n-element subset of samples where all the
   438  // members overlap. It returns the intersection if such a subset exists.
   439  func findNOverlapping(samples []*timeSample, n int) (sampleIntersection *timeSample, ok bool) {
   440  	switch {
   441  	case n <= 0:
   442  		return nil, false
   443  	case n == 1:
   444  		return samples[0], true
   445  	}
   446  
   447  	overlapping := make([]*timeSample, 0, n)
   448  
   449  	for _, initial := range samples {
   450  		// An intersection of any subset of intervals will be an interval that contains
   451  		// the starting point of one of the intervals (possibly as its own starting point).
   452  		point := initial.min
   453  
   454  		for _, candidate := range samples {
   455  			if candidate.contains(point) {
   456  				overlapping = append(overlapping, candidate)
   457  			}
   458  
   459  			if len(overlapping) == n {
   460  				return intersection(overlapping), true
   461  			}
   462  		}
   463  
   464  		overlapping = overlapping[:0]
   465  	}
   466  
   467  	return nil, false
   468  }
   469  
   470  // TimeResult is the result of trying to establish the current time by querying
   471  // a number of servers.
   472  type TimeResult struct {
   473  	// MonoUTCDelta may be nil, in which case a time could not be
   474  	// established. Otherwise it contains the difference between the
   475  	// Roughtime epoch and the monotonic clock.
   476  	MonoUTCDelta *time.Duration
   477  
   478  	// ServerErrors maps from server name to query error.
   479  	ServerErrors map[string]error
   480  
   481  	// ServerInfo contains information about each server that was queried.
   482  	ServerInfo map[string]ServerInfo
   483  
   484  	// OutOfRangeAnswer is true if one or more of the queries contained a
   485  	// significantly incorrect time, as defined by MaxDifference. In this
   486  	// case, the reply will have been recorded in the chain.
   487  	OutOfRangeAnswer bool
   488  }
   489  
   490  // ServerInfo contains information from a specific server.
   491  type ServerInfo struct {
   492  	// QueryDuration is the amount of time that the server took to answer.
   493  	QueryDuration time.Duration
   494  
   495  	// Min and Max specify the time window given by the server. These
   496  	// values have been adjusted so that they are comparible across
   497  	// servers, even though they are queried at different times.
   498  	Min, Max *big.Int
   499  }
   500  
   501  // EstablishTime queries a number of servers until it has a quorum of
   502  // overlapping results, or it runs out of servers. Results from the querying
   503  // the servers are appended to chain.
   504  func (c *Client) EstablishTime(chain *config.Chain, quorum int, servers []config.Server) (TimeResult, error) {
   505  	perm := c.permutation(len(servers))
   506  	var samples []*timeSample
   507  	var intersection *timeSample
   508  	var result TimeResult
   509  
   510  	for len(perm) > 0 {
   511  		server := &servers[perm[0]]
   512  		perm = perm[1:]
   513  
   514  		sample, err := c.query(server, chain)
   515  		if err != nil {
   516  			if result.ServerErrors == nil {
   517  				result.ServerErrors = make(map[string]error)
   518  			}
   519  			result.ServerErrors[server.Name] = err
   520  			continue
   521  		}
   522  
   523  		if len(samples) > 0 {
   524  			sample.alignTo(samples[0])
   525  		}
   526  		samples = append(samples, sample)
   527  
   528  		if result.ServerInfo == nil {
   529  			result.ServerInfo = make(map[string]ServerInfo)
   530  		}
   531  		result.ServerInfo[server.Name] = ServerInfo{
   532  			QueryDuration: sample.queryDuration,
   533  			Min:           sample.min,
   534  			Max:           sample.max,
   535  		}
   536  
   537  		var ok bool
   538  		if intersection, ok = findNOverlapping(samples, quorum); ok {
   539  			break
   540  		}
   541  		intersection = nil
   542  	}
   543  
   544  	if intersection == nil {
   545  		return result, nil
   546  	}
   547  	midpoint := intersection.midpoint()
   548  
   549  	maxDifference := new(big.Int).SetUint64(uint64(c.maxDifference() / time.Microsecond))
   550  	for _, sample := range samples {
   551  		delta := new(big.Int).Sub(midpoint, sample.midpoint())
   552  		delta.Abs(delta)
   553  
   554  		if delta.Cmp(maxDifference) > 0 {
   555  			result.OutOfRangeAnswer = true
   556  			break
   557  		}
   558  	}
   559  
   560  	midpoint.Mul(midpoint, big.NewInt(1000))
   561  	delta := new(big.Int).Sub(midpoint, intersection.base)
   562  	if delta.BitLen() > 63 {
   563  		return result, errors.New("client: cannot represent difference between monotonic and UTC time")
   564  	}
   565  	monoUTCDelta := time.Duration(delta.Int64())
   566  	result.MonoUTCDelta = &monoUTCDelta
   567  
   568  	return result, nil
   569  }
   570  
   571  func (c *Client) Do(chain *config.Chain) (*config.Chain, error) {
   572  	result, err := c.EstablishTime(chain, c.quorum, c.servers)
   573  	if err != nil {
   574  		return nil, err
   575  	}
   576  
   577  	for serverName, err := range result.ServerErrors {
   578  		fmt.Fprintf(os.Stderr, "Failed to query %q: %s\n", serverName, err)
   579  	}
   580  
   581  	maxLenServerName := 0
   582  	for name := range result.ServerInfo {
   583  		if len(name) > maxLenServerName {
   584  			maxLenServerName = len(name)
   585  		}
   586  	}
   587  
   588  	for name, info := range result.ServerInfo {
   589  		fmt.Printf("%s:%s %d–%d (answered in %s)\n", name, strings.Repeat(" ", maxLenServerName-len(name)), info.Min, info.Max, info.QueryDuration)
   590  	}
   591  
   592  	if result.MonoUTCDelta == nil {
   593  		fmt.Fprintf(os.Stderr, "Failed to get %d servers to agree on the time.\n", c.quorum)
   594  	} else {
   595  		nowUTC := time.Unix(0, int64(monotime.Now()+*result.MonoUTCDelta))
   596  		nowRealTime := time.Now()
   597  
   598  		fmt.Printf("real-time delta: %s\n", nowRealTime.Sub(nowUTC))
   599  	}
   600  
   601  	if result.OutOfRangeAnswer {
   602  		fmt.Fprintf(os.Stderr, "One or more of the answers was significantly out of range.\n")
   603  	}
   604  
   605  	trimChain(chain, maxChainSize)
   606  	return chain, nil
   607  }