roughtime.googlesource.com/roughtime.git@v0.0.0-20201210012726-dd529367052d/go/client/client.go (about)

     1  // Copyright 2016 The Roughtime Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //   http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License. */
    14  
    15  // client is a somewhat featured Roughtime client.
    16  package main
    17  
    18  import (
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"time"
    27  
    28  	"math/big"
    29  	mathrand "math/rand"
    30  
    31  	"golang.org/x/crypto/ed25519"
    32  	"roughtime.googlesource.com/go/client/monotime"
    33  	"roughtime.googlesource.com/go/config"
    34  	"roughtime.googlesource.com/go/protocol"
    35  )
    36  
    37  const (
    38  	// defaultMaxRadius is the maximum radius that we'll accept from a
    39  	// server.
    40  	defaultMaxRadius = 10 * time.Second
    41  
    42  	// defaultMaxDifference is the maximum difference in time between any
    43  	// sample from a server and the quorum-agreed time before we believe
    44  	// that the server might be misbehaving.
    45  	defaultMaxDifference = 60 * time.Second
    46  
    47  	// defaultTimeout is the default maximum time that a server has to
    48  	// answer a query.
    49  	defaultTimeout = 2 * time.Second
    50  
    51  	// defaultNumQueries is the default number of times we will try to
    52  	// query a server.
    53  	defaultNumQueries = 3
    54  )
    55  
// Client represents a Roughtime client and exposes a number of members that
// can be set in order to configure it. The zero value of a Client is always
// ready to use and will set sensible defaults.
type Client struct {
	// Permutation returns a random permutation of [0‥n) that is used to
	// query servers in a random order. If nil, a sensible default is used.
	Permutation func(n int) []int

	// MaxRadius is the maximum interval radius that will be accepted
	// from a server. A server's radius is reported in microseconds and is
	// converted to a time.Duration before being compared with this value.
	// If zero, a sensible default is used.
	MaxRadius time.Duration

	// MaxDifference is the maximum difference in time between any sample
	// from a server and the quorum-agreed time before that sample is
	// considered suspect. If zero, a sensible default is used.
	MaxDifference time.Duration

	// QueryTimeout is the amount of time a server has to reply to a query.
	// The deadline is applied separately to each attempt. If zero, a
	// sensible default will be used.
	QueryTimeout time.Duration

	// NumQueries is the maximum number of times a query will be sent to a
	// specific server before giving up. If <= zero, a sensible default
	// will be used.
	NumQueries int

	// nowFunc returns a monotonic duration from some unspecified epoch and
	// is what now() consults when it is set. If nil, the system monotonic
	// time will be used.
	nowFunc func() time.Duration
}
    86  
    87  func (c *Client) now() time.Duration {
    88  	if c.nowFunc != nil {
    89  		return c.nowFunc()
    90  	}
    91  	return monotime.Now()
    92  }
    93  
    94  func (c *Client) permutation(n int) []int {
    95  	if c.Permutation != nil {
    96  		return c.Permutation(n)
    97  	}
    98  
    99  	var randBuf [8]byte
   100  	if _, err := io.ReadFull(rand.Reader, randBuf[:]); err != nil {
   101  		panic(err)
   102  	}
   103  
   104  	seed := binary.LittleEndian.Uint64(randBuf[:])
   105  	rand := mathrand.New(mathrand.NewSource(int64(seed)))
   106  
   107  	return rand.Perm(n)
   108  }
   109  
   110  func (c *Client) maxRadius() time.Duration {
   111  	if c.MaxRadius != 0 {
   112  		return c.MaxRadius
   113  	}
   114  
   115  	return defaultMaxRadius
   116  }
   117  
   118  func (c *Client) maxDifference() time.Duration {
   119  	if c.MaxDifference != 0 {
   120  		return c.MaxDifference
   121  	}
   122  
   123  	return defaultMaxDifference
   124  }
   125  
   126  func (c *Client) queryTimeout() time.Duration {
   127  	if c.QueryTimeout != 0 {
   128  		return c.QueryTimeout
   129  	}
   130  
   131  	return defaultTimeout
   132  }
   133  
   134  func (c *Client) numQueries() int {
   135  	if c.NumQueries > 0 {
   136  		return c.NumQueries
   137  	}
   138  
   139  	return defaultNumQueries
   140  }
   141  
   142  // LoadChain loads a JSON-format chain from the given JSON data.
   143  func LoadChain(jsonData []byte) (chain *config.Chain, err error) {
   144  	chain = new(config.Chain)
   145  	if err := json.Unmarshal(jsonData, chain); err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	for i, link := range chain.Links {
   150  		if link.PublicKeyType != "ed25519" {
   151  			return nil, fmt.Errorf("client: link #%d in chain file has unknown public key type %q", i, link.PublicKeyType)
   152  		}
   153  
   154  		if l := len(link.PublicKey); l != ed25519.PublicKeySize {
   155  			return nil, fmt.Errorf("client: link #%d in chain file has bad public key of length %d", i, l)
   156  		}
   157  
   158  		if l := len(link.NonceOrBlind); l != protocol.NonceSize {
   159  			return nil, fmt.Errorf("client: link #%d in chain file has bad nonce/blind of length %d", i, l)
   160  		}
   161  
   162  		var nonce [protocol.NonceSize]byte
   163  		if i == 0 {
   164  			copy(nonce[:], link.NonceOrBlind[:])
   165  		} else {
   166  			nonce = protocol.CalculateChainNonce(chain.Links[i-1].Reply, link.NonceOrBlind[:])
   167  		}
   168  
   169  		if _, _, err := protocol.VerifyReply(link.Reply, link.PublicKey, nonce); err != nil {
   170  			return nil, fmt.Errorf("client: failed to verify link #%d in chain file", i)
   171  		}
   172  	}
   173  
   174  	return chain, nil
   175  }
   176  
// timeSample represents a time sample from the network.
type timeSample struct {
	// server references the server that was queried.
	server *config.Server

	// base is a monotonic clock sample that is taken at a time before the
	// network could have answered the query. It holds a time.Duration
	// value, i.e. nanoseconds since the monotonic clock's epoch.
	base *big.Int

	// min is the minimum real-time (in Roughtime UTC microseconds) that
	// could correspond to |base| (i.e. midpoint - radius - query time). The
	// query time is subtracted because the server may have produced its
	// midpoint as late as the moment the reply was read.
	min *big.Int

	// max is the maximum real-time (in Roughtime UTC microseconds) that
	// could correspond to |base| (i.e. midpoint + radius), since the server
	// cannot have produced its midpoint before |base| was sampled.
	max *big.Int

	// queryDuration contains the amount of time that the server took to
	// answer the query, measured locally from just before the request was
	// written until the reply was read.
	queryDuration time.Duration
}
   198  
   199  // midpoint returns the average of the min and max times.
   200  func (s *timeSample) midpoint() *big.Int {
   201  	ret := new(big.Int).Add(s.min, s.max)
   202  	return ret.Rsh(ret, 1)
   203  }
   204  
   205  // alignTo updates s so that its base value matches that from reference.
   206  func (s *timeSample) alignTo(reference *timeSample) {
   207  	delta := new(big.Int).Sub(s.base, reference.base)
   208  	delta.Div(delta, big.NewInt(int64(time.Microsecond)))
   209  	s.base.Sub(s.base, delta)
   210  	s.min.Sub(s.min, delta)
   211  	s.max.Sub(s.max, delta)
   212  }
   213  
   214  // contains returns true iff p belongs to s
   215  func (s *timeSample) contains(p *big.Int) bool {
   216  	return s.max.Cmp(p) >= 0 && s.min.Cmp(p) <= 0
   217  }
   218  
   219  // overlaps returns true iff s and other have any timespan in common.
   220  func (s *timeSample) overlaps(other *timeSample) bool {
   221  	return s.max.Cmp(other.min) >= 0 && other.max.Cmp(s.min) >= 0
   222  }
   223  
   224  // query sends a request to s, appends it to chain, and returns the resulting
   225  // timeSample.
   226  func (c *Client) query(server *config.Server, chain *config.Chain) (*timeSample, error) {
   227  	var prevReply []byte
   228  	if len(chain.Links) > 0 {
   229  		prevReply = chain.Links[len(chain.Links)-1].Reply
   230  	}
   231  
   232  	var baseTime, replyTime time.Duration
   233  	var reply []byte
   234  	var nonce, blind [protocol.NonceSize]byte
   235  
   236  	for attempts := 0; attempts < c.numQueries(); attempts++ {
   237  		var request []byte
   238  		var err error
   239  		if nonce, blind, request, err = protocol.CreateRequest(rand.Reader, prevReply); err != nil {
   240  			return nil, err
   241  		}
   242  		if len(request) < protocol.MinRequestSize {
   243  			panic("internal error: bad request length")
   244  		}
   245  
   246  		udpAddr, err := serverUDPAddr(server)
   247  		if err != nil {
   248  			panic(err)
   249  		}
   250  
   251  		conn, err := net.DialUDP("udp", nil, udpAddr)
   252  		if err != nil {
   253  			return nil, err
   254  		}
   255  
   256  		conn.SetReadDeadline(time.Now().Add(c.queryTimeout()))
   257  		baseTime = c.now()
   258  		conn.Write(request)
   259  
   260  		var replyBytes [1024]byte
   261  		n, err := conn.Read(replyBytes[:])
   262  		if err == nil {
   263  			replyTime = c.now()
   264  			reply = replyBytes[:n]
   265  			break
   266  		}
   267  
   268  		if netErr, ok := err.(net.Error); ok {
   269  			if !netErr.Timeout() {
   270  				return nil, errors.New("client: error reading from UDP socket: " + err.Error())
   271  			}
   272  		}
   273  	}
   274  
   275  	if reply == nil {
   276  		return nil, fmt.Errorf("client: no reply from server %q", server.Name)
   277  	}
   278  
   279  	if replyTime < baseTime {
   280  		panic("broken monotonic clock")
   281  	}
   282  	queryDuration := replyTime - baseTime
   283  
   284  	midpoint, radius, err := protocol.VerifyReply(reply, server.PublicKey, nonce)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	if time.Duration(radius)*time.Microsecond > c.maxRadius() {
   290  		return nil, fmt.Errorf("client: radius (%d) too large", radius)
   291  	}
   292  
   293  	nonceOrBlind := blind[:]
   294  	if len(prevReply) == 0 {
   295  		nonceOrBlind = nonce[:]
   296  	}
   297  
   298  	chain.Links = append(chain.Links, config.Link{
   299  		PublicKeyType: "ed25519",
   300  		PublicKey:     server.PublicKey,
   301  		NonceOrBlind:  nonceOrBlind,
   302  		Reply:         reply,
   303  	})
   304  
   305  	queryDurationBig := new(big.Int).SetInt64(int64(queryDuration / time.Microsecond))
   306  	bigRadius := new(big.Int).SetUint64(uint64(radius))
   307  	min := new(big.Int).SetUint64(midpoint)
   308  	min.Sub(min, bigRadius)
   309  	min.Sub(min, queryDurationBig)
   310  
   311  	max := new(big.Int).SetUint64(midpoint)
   312  	max.Add(max, bigRadius)
   313  
   314  	return &timeSample{
   315  		server:        server,
   316  		base:          new(big.Int).SetInt64(int64(baseTime)),
   317  		min:           min,
   318  		max:           max,
   319  		queryDuration: queryDuration,
   320  	}, nil
   321  }
   322  
   323  func serverUDPAddr(server *config.Server) (*net.UDPAddr, error) {
   324  	for _, addr := range server.Addresses {
   325  		if addr.Protocol != "udp" {
   326  			continue
   327  		}
   328  
   329  		return net.ResolveUDPAddr("udp", addr.Address)
   330  	}
   331  
   332  	return nil, nil
   333  }
   334  
   335  // LoadServers loads information about known servers from the given JSON data.
   336  // It only extracts information about servers with Ed25519 public keys and UDP
   337  // address. The number of servers skipped because of unsupported requirements
   338  // is returned in numSkipped.
   339  func LoadServers(jsonData []byte) (servers []config.Server, numSkipped int, err error) {
   340  	var serversJSON config.ServersJSON
   341  	if err := json.Unmarshal(jsonData, &serversJSON); err != nil {
   342  		return nil, 0, err
   343  	}
   344  
   345  	seenNames := make(map[string]struct{})
   346  
   347  	for _, candidate := range serversJSON.Servers {
   348  		if _, ok := seenNames[candidate.Name]; ok {
   349  			return nil, 0, fmt.Errorf("client: duplicate server name %q", candidate.Name)
   350  		}
   351  		seenNames[candidate.Name] = struct{}{}
   352  
   353  		if candidate.PublicKeyType != "ed25519" {
   354  			numSkipped++
   355  			continue
   356  		}
   357  
   358  		udpAddr, err := serverUDPAddr(&candidate)
   359  
   360  		if err != nil {
   361  			return nil, 0, fmt.Errorf("client: server %q lists invalid UDP address: %s", candidate.Name, err)
   362  		}
   363  
   364  		if udpAddr == nil {
   365  			numSkipped++
   366  			continue
   367  		}
   368  
   369  		servers = append(servers, candidate)
   370  	}
   371  
   372  	if len(servers) == 0 {
   373  		return nil, 0, errors.New("client: no usable servers found")
   374  	}
   375  
   376  	return servers, 0, nil
   377  }
   378  
   379  // trimChain drops elements from the beginning of chain, as needed, so that its
   380  // length is <= n.
   381  func trimChain(chain *config.Chain, n int) {
   382  	if n <= 0 {
   383  		chain.Links = nil
   384  		return
   385  	}
   386  
   387  	if len(chain.Links) <= n {
   388  		return
   389  	}
   390  
   391  	numToTrim := len(chain.Links) - n
   392  	for i := 0; i < numToTrim; i++ {
   393  		// The NonceOrBlind of the first element is special because
   394  		// it's an nonce. All the others are blinds and are combined
   395  		// with the previous reply to make the nonce. That's not
   396  		// possible for the first element because there is no previous
   397  		// reply. Therefore, when removing the first element the blind
   398  		// of the next element needs to be converted to an nonce.
   399  		nonce := protocol.CalculateChainNonce(chain.Links[0].Reply, chain.Links[1].NonceOrBlind[:])
   400  		chain.Links[1].NonceOrBlind = nonce[:]
   401  		chain.Links = chain.Links[1:]
   402  	}
   403  }
   404  
   405  // intersection returns the timespan common to all the elements in samples,
   406  // which must be aligned to the same base. The caller must ensure that such a
   407  // timespan exists.
   408  func intersection(samples []*timeSample) *timeSample {
   409  	ret := &timeSample{
   410  		base: samples[0].base,
   411  		min:  new(big.Int).Set(samples[0].min),
   412  		max:  new(big.Int).Set(samples[0].max),
   413  	}
   414  
   415  	for _, sample := range samples[1:] {
   416  		if ret.min.Cmp(sample.min) < 0 {
   417  			ret.min.Set(sample.min)
   418  		}
   419  		if ret.max.Cmp(sample.max) > 0 {
   420  			ret.max.Set(sample.max)
   421  		}
   422  	}
   423  
   424  	return ret
   425  }
   426  
   427  // findNOverlapping finds an n-element subset of samples where all the
   428  // members overlap. It returns the intersection if such a subset exists.
   429  func findNOverlapping(samples []*timeSample, n int) (sampleIntersection *timeSample, ok bool) {
   430  	switch {
   431  	case n <= 0:
   432  		return nil, false
   433  	case n == 1:
   434  		return samples[0], true
   435  	}
   436  
   437  	overlapping := make([]*timeSample, 0, n)
   438  
   439  	for _, initial := range samples {
   440  		// An intersection of any subset of intervals will be an interval that contains
   441  		// the starting point of one of the intervals (possibly as its own starting point).
   442  		point := initial.min
   443  
   444  		for _, candidate := range samples {
   445  			if candidate.contains(point) {
   446  				overlapping = append(overlapping, candidate)
   447  			}
   448  
   449  			if len(overlapping) == n {
   450  				return intersection(overlapping), true
   451  			}
   452  		}
   453  
   454  		overlapping = overlapping[:0]
   455  	}
   456  
   457  	return nil, false
   458  }
   459  
// TimeResult is the result of trying to establish the current time by querying
// a number of servers.
type TimeResult struct {
	// MonoUTCDelta may be nil, in which case a time could not be
	// established. Otherwise it contains the difference between the
	// Roughtime epoch and the monotonic clock. It is computed as (agreed
	// Roughtime UTC, in nanoseconds) minus (the monotonic clock reading of
	// the reference sample), so adding it to a reading of the same
	// monotonic clock yields the corresponding Roughtime UTC time.
	MonoUTCDelta *time.Duration

	// ServerErrors maps from server name to query error. It is nil when no
	// attempted query failed.
	ServerErrors map[string]error

	// ServerInfo contains information about each server whose query
	// succeeded; failed servers appear in ServerErrors instead. It is nil
	// when no query succeeded.
	ServerInfo map[string]ServerInfo

	// OutOfRangeAnswer is true if one or more of the queries contained a
	// significantly incorrect time, as defined by MaxDifference. In this
	// case, the reply will have been recorded in the chain.
	OutOfRangeAnswer bool
}
   479  
// ServerInfo contains information from a specific server.
type ServerInfo struct {
	// QueryDuration is the amount of time that the server took to answer,
	// measured locally as the round trip of the successful attempt.
	QueryDuration time.Duration

	// Min and Max specify the time window given by the server, in
	// Roughtime UTC microseconds. These values have been adjusted so that
	// they are comparable across servers, even though they are queried at
	// different times.
	Min, Max *big.Int
}
   490  
   491  // EstablishTime queries a number of servers until it has a quorum of
   492  // overlapping results, or it runs out of servers. Results from the querying
   493  // the servers are appended to chain.
   494  func (c *Client) EstablishTime(chain *config.Chain, quorum int, servers []config.Server) (TimeResult, error) {
   495  	perm := c.permutation(len(servers))
   496  	var samples []*timeSample
   497  	var intersection *timeSample
   498  	var result TimeResult
   499  
   500  	for len(perm) > 0 {
   501  		server := &servers[perm[0]]
   502  		perm = perm[1:]
   503  
   504  		sample, err := c.query(server, chain)
   505  		if err != nil {
   506  			if result.ServerErrors == nil {
   507  				result.ServerErrors = make(map[string]error)
   508  			}
   509  			result.ServerErrors[server.Name] = err
   510  			continue
   511  		}
   512  
   513  		if len(samples) > 0 {
   514  			sample.alignTo(samples[0])
   515  		}
   516  		samples = append(samples, sample)
   517  
   518  		if result.ServerInfo == nil {
   519  			result.ServerInfo = make(map[string]ServerInfo)
   520  		}
   521  		result.ServerInfo[server.Name] = ServerInfo{
   522  			QueryDuration: sample.queryDuration,
   523  			Min:           sample.min,
   524  			Max:           sample.max,
   525  		}
   526  
   527  		var ok bool
   528  		if intersection, ok = findNOverlapping(samples, quorum); ok {
   529  			break
   530  		}
   531  		intersection = nil
   532  	}
   533  
   534  	if intersection == nil {
   535  		return result, nil
   536  	}
   537  	midpoint := intersection.midpoint()
   538  
   539  	maxDifference := new(big.Int).SetUint64(uint64(c.maxDifference() / time.Microsecond))
   540  	for _, sample := range samples {
   541  		delta := new(big.Int).Sub(midpoint, sample.midpoint())
   542  		delta.Abs(delta)
   543  
   544  		if delta.Cmp(maxDifference) > 0 {
   545  			result.OutOfRangeAnswer = true
   546  			break
   547  		}
   548  	}
   549  
   550  	midpoint.Mul(midpoint, big.NewInt(1000))
   551  	delta := new(big.Int).Sub(midpoint, intersection.base)
   552  	if delta.BitLen() > 63 {
   553  		return result, errors.New("client: cannot represent difference between monotonic and UTC time")
   554  	}
   555  	monoUTCDelta := time.Duration(delta.Int64())
   556  	result.MonoUTCDelta = &monoUTCDelta
   557  
   558  	return result, nil
   559  }