github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/teams/proofs.go (about)

     1  package teams
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	"golang.org/x/net/context"
     8  	"golang.org/x/sync/errgroup"
     9  
    10  	"github.com/davecgh/go-spew/spew"
    11  	"github.com/keybase/client/go/libkb"
    12  	"github.com/keybase/client/go/protocol/keybase1"
    13  )
    14  
    15  // newProofTerm creates a new proof term.
    16  // `lm` can be nil (it is for teams since SetTeamLinkMap is used)
    17  func newProofTerm(i keybase1.UserOrTeamID, s keybase1.SignatureMetadata, lm linkMapT) proofTerm {
    18  	return proofTerm{leafID: i, sigMeta: s, linkMap: lm}
    19  }
    20  
// linkMapT maps a sigchain sequence number to the ID of the link found at
// that position of the chain.
type linkMapT map[keybase1.Seqno]keybase1.LinkID
    22  
// proofTerm is one side of a happens-before proof: a single link on a
// sigchain.
//
//   - leafID is the user or team whose chain the link lives on.
//   - sigMeta is the link's signature metadata; the seqno, the chain type
//     (public or not) and the previously signed merkle root are read from it.
//   - linkMap maps seqnos to link IDs for the leaf's chain. It may be nil;
//     for team leaves, findLink prefers the map stored in
//     proofSetT.teamLinkMaps.
type proofTerm struct {
	leafID  keybase1.UserOrTeamID
	sigMeta keybase1.SignatureMetadata
	linkMap linkMapT
}
    28  
    29  func (t *proofTerm) shortForm() string {
    30  	return fmt.Sprintf("%v@%v", t.sigMeta.SigChainLocation.Seqno, t.leafID)
    31  
    32  }
    33  
// proofTermBookends pairs a left proof term with an optional right one.
// NOTE(review): the type is not used in this part of the file; presumably the
// two terms delimit an interval on a chain and a nil `right` means the
// interval is still open — confirm against the callers.
type proofTermBookends struct {
	left  proofTerm
	right *proofTerm
}
    38  
    39  type proof struct {
    40  	a      proofTerm
    41  	b      proofTerm
    42  	reason string
    43  }
    44  
    45  func (p *proof) shortForm() string {
    46  	return fmt.Sprintf("%v --> %v '%v'", p.a.shortForm(), p.b.shortForm(), p.reason)
    47  }
    48  
    49  type proofIndex struct {
    50  	a keybase1.UserOrTeamID
    51  	b keybase1.UserOrTeamID
    52  }
    53  
    54  func (t proofTerm) seqno() keybase1.Seqno { return t.sigMeta.SigChainLocation.Seqno }
    55  func (t proofTerm) isPublic() bool {
    56  	return t.sigMeta.SigChainLocation.SeqType == keybase1.SeqType_PUBLIC
    57  }
    58  
    59  // comparison method only valid if `t` and `u` are known to be on the same chain
    60  func (t proofTerm) lessThanOrEqual(u proofTerm) bool {
    61  	return t.seqno() <= u.seqno()
    62  }
    63  
    64  // comparison method only valid if `t` and `u` are known to be on the same chain
    65  func (t proofTerm) equal(u proofTerm) bool {
    66  	return t.seqno() == u.seqno()
    67  }
    68  
    69  // comparison method only valid if `t` and `u` are known to be on the same chain
    70  func (t proofTerm) max(u proofTerm) proofTerm {
    71  	if t.lessThanOrEqual(u) {
    72  		return u
    73  	}
    74  	return t
    75  }
    76  
    77  // comparison method only valid if `t` and `u` are known to be on the same chain
    78  func (t proofTerm) min(u proofTerm) proofTerm {
    79  	if t.lessThanOrEqual(u) {
    80  		return t
    81  	}
    82  	return u
    83  }
    84  
    85  func newProofIndex(a keybase1.UserOrTeamID, b keybase1.UserOrTeamID) proofIndex {
    86  	return proofIndex{b, a}
    87  }
    88  
// proofSetT accumulates the happens-before proofs that still need to be
// checked.
//
//   - Contextified gives access to the global context (logging via p.G()).
//   - proofs groups the needed proofs by the pair of leaves they relate; see
//     AddNeededHappensBeforeProof for how entries are added and tightened.
//   - teamLinkMaps holds the latest link map per team, stored through
//     SetTeamLinkMap and consulted by proof.findLink for team leaves.
type proofSetT struct {
	libkb.Contextified
	proofs       map[proofIndex][]proof
	teamLinkMaps map[keybase1.TeamID]linkMapT
}
    94  
    95  func newProofSet(g *libkb.GlobalContext) *proofSetT {
    96  	return &proofSetT{
    97  		Contextified: libkb.NewContextified(g),
    98  		proofs:       make(map[proofIndex][]proof),
    99  		teamLinkMaps: make(map[keybase1.TeamID]linkMapT),
   100  	}
   101  }
   102  
   103  // AddNeededHappensBeforeProof adds a new needed proof to the proof set. The
   104  // proof is that `a` happened before `b`.  If there are other proofs in the proof set
   105  // that prove the same thing, then we can tighten those proofs with a and b if
   106  // it makes sense.  For instance, if there is an existing proof that c<d,
   107  // but we know that c<a and b<d, then it suffices to replace c<d with a<b as
   108  // the needed proof. Each proof in the proof set in the end will correspond
   109  // to a merkle tree lookup, so it makes sense to be stingy. Return the modified
   110  // proof set with the new proofs needed, but the original argument p will
   111  // be mutated.
   112  func (p *proofSetT) AddNeededHappensBeforeProof(ctx context.Context, a proofTerm, b proofTerm, reason string) {
   113  
   114  	var action string
   115  	defer func() {
   116  		if action != "discard-easy" && !ShouldSuppressLogging(ctx) {
   117  			p.G().Log.CDebugf(ctx, "proofSet add(%v --> %v) [%v] '%v'", a.shortForm(), b.shortForm(), action, reason)
   118  		}
   119  	}()
   120  
   121  	idx := newProofIndex(a.leafID, b.leafID)
   122  
   123  	if idx.a.Equal(idx.b) {
   124  		// If both terms are on the same chain
   125  		if a.lessThanOrEqual(b) {
   126  			// The proof is self-evident.
   127  			// Discard it.
   128  			action = "discard-easy"
   129  			return
   130  		}
   131  		// The proof is self-evident FALSE.
   132  		// Add it and return immediately so the rest of this function doesn't have to trip over it.
   133  		// It should be failed later by the checker.
   134  		action = "added-easy-false"
   135  		p.proofs[idx] = append(p.proofs[idx], proof{a, b, reason})
   136  		return
   137  	}
   138  
   139  	set := p.proofs[idx]
   140  	for i := len(set) - 1; i >= 0; i-- {
   141  		existing := set[i]
   142  		if existing.a.lessThanOrEqual(a) && b.lessThanOrEqual(existing.b) {
   143  			// If the new proof is surrounded by the old proof.
   144  			existing.a = existing.a.max(a)
   145  			existing.b = existing.b.min(b)
   146  			set[i] = existing
   147  			action = "collapsed"
   148  			return
   149  		}
   150  		if existing.a.equal(a) && existing.b.lessThanOrEqual(b) {
   151  			// If the new proof is the same on the left and weaker on the right.
   152  			// Discard the new proof, as it is implied by the existing one.
   153  			action = "discard-weak"
   154  			return
   155  		}
   156  	}
   157  	action = "added"
   158  	p.proofs[idx] = append(p.proofs[idx], proof{a, b, reason})
   159  }
   160  
// SetTeamLinkMap records linkMap as the latest link map for teamID, replacing
// any map stored earlier for that team. proof.findLink prefers this map over
// the one carried by a proof term when the leaf is a team or subteam.
// ctx is currently unused.
func (p *proofSetT) SetTeamLinkMap(ctx context.Context, teamID keybase1.TeamID, linkMap linkMapT) {
	p.teamLinkMaps[teamID] = linkMap
}
   165  
   166  func (p *proofSetT) AllProofs() []proof {
   167  	var ret []proof
   168  	for _, v := range p.proofs {
   169  		ret = append(ret, v...)
   170  	}
   171  	sort.Slice(ret, func(i, j int) bool {
   172  		cmp := ret[i].a.leafID.Compare(ret[j].a.leafID)
   173  		if cmp < 0 {
   174  			return true
   175  		}
   176  		if cmp > 0 {
   177  			return false
   178  		}
   179  		cmp = ret[i].b.leafID.Compare(ret[j].b.leafID)
   180  		if cmp < 0 {
   181  			return true
   182  		}
   183  		if cmp > 0 {
   184  			return false
   185  		}
   186  		cs := ret[i].a.sigMeta.SigChainLocation.Seqno - ret[j].a.sigMeta.SigChainLocation.Seqno
   187  		if cs < 0 {
   188  			return true
   189  		}
   190  		if cs > 0 {
   191  			return false
   192  		}
   193  		cs = ret[i].b.sigMeta.SigChainLocation.Seqno - ret[j].b.sigMeta.SigChainLocation.Seqno
   194  		return cs < 0
   195  	})
   196  	return ret
   197  }
   198  
   199  // lookupMerkleTreeChain loads the path up to the merkle tree and back down that corresponds
   200  // to this proof. It will contact the API server.  Returns the sigchain tail on success.
   201  func (p proof) lookupMerkleTreeChain(ctx context.Context, world LoaderContext) (ret *libkb.MerkleTriple, err error) {
   202  	return world.merkleLookupTripleInPast(ctx, p.a.isPublic(), p.a.leafID, p.b.sigMeta.PrevMerkleRootSigned)
   203  }
   204  
   205  // check a single proof. Call to the merkle API endpoint, and then ensure that the
   206  // data that comes back fits the proof and previously checked sigchain links.
   207  func (p proof) check(ctx context.Context, g *libkb.GlobalContext, world LoaderContext, proofSet *proofSetT) (err error) {
   208  	defer func() {
   209  		g.Log.CDebugf(ctx, "TeamLoader proofSet check1(%v) -> %v", p.shortForm(), err)
   210  	}()
   211  
   212  	triple, err := p.lookupMerkleTreeChain(ctx, world)
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	// laterSeqno is the tail of chain A at the time when B was signed
   218  	// earlierSeqno is the tail of chain A at the time when A was signed
   219  	laterSeqno := triple.Seqno
   220  	earlierSeqno := p.a.sigMeta.SigChainLocation.Seqno
   221  	if earlierSeqno > laterSeqno {
   222  		return NewProofError(p, fmt.Sprintf("seqno %d > %d", earlierSeqno, laterSeqno))
   223  	}
   224  
   225  	linkID, err := p.findLink(ctx, g, world, p.a.leafID, laterSeqno, p.a.linkMap, proofSet)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	if !triple.LinkID.Export().Eq(linkID) {
   231  		g.Log.CDebugf(ctx, "proof error: %s", spew.Sdump(p))
   232  		return NewProofError(p, fmt.Sprintf("hash mismatch: %s != %s", triple.LinkID, linkID))
   233  	}
   234  	return nil
   235  }
   236  
   237  // Find the LinkID for the leaf at the seqno.
   238  func (p proof) findLink(ctx context.Context, g *libkb.GlobalContext, world LoaderContext, leafID keybase1.UserOrTeamID, seqno keybase1.Seqno, firstLinkMap linkMapT, proofSet *proofSetT) (linkID keybase1.LinkID, err error) {
   239  	lm := firstLinkMap
   240  
   241  	if leafID.IsTeamOrSubteam() {
   242  		// Pull in the latest link map, instead of the one from the proof object.
   243  		tid := leafID.AsTeamOrBust()
   244  		lm2, ok := proofSet.teamLinkMaps[tid]
   245  		if ok {
   246  			lm = lm2
   247  		}
   248  	}
   249  	if lm == nil {
   250  		return linkID, NewProofError(p, "nil link map")
   251  	}
   252  
   253  	linkID, ok := lm[seqno]
   254  	if ok {
   255  		return linkID, nil
   256  	}
   257  
   258  	// We loaded this user originally to get a sigchain as fresh as a certain key provisioning.
   259  	// In this scenario, we might need a fresher version, so force a poll all the way through
   260  	// the server, and then try again. If we fail the second time, we a force repoll, then
   261  	// we're toast.
   262  	if leafID.IsUser() {
   263  		g.Log.CDebugf(ctx, "proof#findLink: missed load for %s at %d; trying a force repoll", leafID.String(), seqno)
   264  		lm, err := world.forceLinkMapRefreshForUser(ctx, leafID.AsUserOrBust())
   265  		if err != nil {
   266  			return linkID, err
   267  		}
   268  		linkID, ok = lm[seqno]
   269  	}
   270  
   271  	if !ok {
   272  		return linkID, NewProofError(p, fmt.Sprintf("no linkID for seqno %d", seqno))
   273  	}
   274  	return linkID, nil
   275  }
   276  
   277  func (p *proofSetT) checkRequired() bool {
   278  	return len(p.proofs) > 0
   279  }
   280  
   281  // check the entire proof set, failing if any one proof fails.
   282  func (p *proofSetT) check(ctx context.Context, world LoaderContext, parallel bool) (err error) {
   283  	defer p.G().CTrace(ctx, "TeamLoader proofSet check", &err)()
   284  
   285  	if parallel {
   286  		return p.checkParallel(ctx, world)
   287  	}
   288  
   289  	var total int
   290  	for _, v := range p.proofs {
   291  		total += len(v)
   292  	}
   293  
   294  	var i int
   295  	for _, v := range p.proofs {
   296  		for _, proof := range v {
   297  			p.G().Log.CDebugf(ctx, "TeamLoader proofSet check [%v / %v]", i, total)
   298  			err = proof.check(ctx, p.G(), world, p)
   299  			if err != nil {
   300  				return err
   301  			}
   302  			i++
   303  		}
   304  	}
   305  	return nil
   306  }
   307  
   308  // check the entire proof set, failing if any one proof fails. (parallel version)
   309  func (p *proofSetT) checkParallel(ctx context.Context, world LoaderContext) (err error) {
   310  
   311  	var total int
   312  	for _, v := range p.proofs {
   313  		total += len(v)
   314  	}
   315  	p.G().Log.CDebugf(ctx, "TeamLoader proofSet check parallel [%v]", total)
   316  
   317  	queue := make(chan proof)
   318  	go func() {
   319  		for _, v := range p.proofs {
   320  			for _, proof := range v {
   321  				queue <- proof
   322  			}
   323  		}
   324  		close(queue)
   325  	}()
   326  
   327  	group, ctx := errgroup.WithContext(libkb.CopyTagsToBackground(ctx))
   328  	const pipeline = 20
   329  	for i := 0; i < pipeline; i++ {
   330  		group.Go(func() error {
   331  			for {
   332  				select {
   333  				case <-ctx.Done():
   334  					return ctx.Err()
   335  				case proof, ok := <-queue:
   336  					if !ok {
   337  						return nil
   338  					}
   339  					err = proof.check(ctx, p.G(), world, p)
   340  					if err != nil {
   341  						return err
   342  					}
   343  				}
   344  			}
   345  		})
   346  	}
   347  
   348  	return group.Wait()
   349  }