github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/teams/treeloader.go (about)

     1  package teams
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/keybase/client/go/libkb"
    10  	"github.com/keybase/client/go/protocol/keybase1"
    11  	"golang.org/x/sync/errgroup"
    12  )
    13  
// Treeloader is an ephemeral struct for loading the memberships of `targetUsername` in the partial
// team tree surrounding `targetTeamID`.
// Its behavior is described at protocol/avdl/keybase1/teams.avdl:loadTeamTreeMemberships.
//
// NewTreeloader fills in everything except targetTeamName, which is only set once a load has
// been started (loadAsync assigns it). getPosition compares against targetTeamName, so it is
// only meaningful after that point.
type Treeloader struct {
	// arguments
	// targetUsername, targetTeamID and guid are stored as given to NewTreeloader and are copied
	// into every result and notification this loader produces.
	targetUsername string
	targetTeamID   keybase1.TeamID
	guid           int

	// includeAncestors makes loadRecursive also load the ancestors of the target team (when the
	// target is not a root team), not just its subtree.
	includeAncestors bool

	// initial computed values
	// targetUV is resolved by NewTreeloader from the target user's UPAK; targetTeamName is taken
	// from the loaded target team in loadAsync.
	targetUV       keybase1.UserVersion
	targetTeamName keybase1.TeamName

	// mock for error testing
	// Converter turns a team's sigchain state into a membership result. NewTreeloader points it
	// at the Treeloader itself.
	Converter TreeloaderStateConverter
}
    32  
    33  var _ TreeloaderStateConverter = &Treeloader{}
    34  
    35  func NewTreeloader(mctx libkb.MetaContext, targetUsername string,
    36  	targetTeamID keybase1.TeamID, guid int, includeAncestors bool) (*Treeloader, error) {
    37  
    38  	larg := libkb.NewLoadUserArgWithMetaContext(mctx).WithName(targetUsername).WithForcePoll(true)
    39  	upak, _, err := mctx.G().GetUPAKLoader().LoadV2(larg)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	l := &Treeloader{
    45  		targetUsername:   targetUsername,
    46  		targetTeamID:     targetTeamID,
    47  		guid:             guid,
    48  		targetUV:         upak.Current.ToUserVersion(),
    49  		includeAncestors: includeAncestors,
    50  	}
    51  	l.Converter = l
    52  	return l, nil
    53  }
    54  
    55  // LoadSync requires all loads to succeed, or errors out.
    56  func (l *Treeloader) LoadSync(mctx libkb.MetaContext) (res []keybase1.TeamTreeMembership,
    57  	err error) {
    58  	defer mctx.Trace(fmt.Sprintf("Treeloader.LoadSync(%s, %s)",
    59  		l.targetTeamID, l.targetUV), &err)()
    60  
    61  	ch, cancel, err := l.loadAsync(mctx)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	// Stop load if we error out early
    66  	defer cancel()
    67  	for notification := range ch {
    68  		switch notification.typ {
    69  		case treeloaderNotificationTypePartial:
    70  			s, err := notification.partialNotification.S()
    71  			if err != nil {
    72  				return nil, fmt.Errorf("Treeloader.LoadSync: failed to load subtree; bailing: %s",
    73  					err)
    74  			}
    75  			if s == keybase1.TeamTreeMembershipStatus_ERROR {
    76  				return nil, fmt.Errorf("Treeloader.LoadSync: failed to load subtree; bailing: %s",
    77  					notification.partialNotification.Error().Message)
    78  			}
    79  
    80  			res = append(res, keybase1.TeamTreeMembership{
    81  				Result:         *notification.partialNotification,
    82  				TeamName:       notification.teamName.String(),
    83  				TargetTeamID:   l.targetTeamID,
    84  				TargetUsername: l.targetUsername,
    85  				Guid:           l.guid,
    86  			})
    87  		default:
    88  		}
    89  	}
    90  
    91  	return res, nil
    92  }
    93  
    94  func (l *Treeloader) LoadAsync(mctx libkb.MetaContext) (err error) {
    95  	defer mctx.Trace(fmt.Sprintf("Treeloader.LoadAsync(%s, %s)",
    96  		l.targetTeamID, l.targetUV), &err)()
    97  
    98  	ch, _, err := l.loadAsync(mctx)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	go func() {
   103  		// Because Go channels can be closed, we don't need to check expectedCount against the
   104  		// number of received partials like RPC clients do.
   105  		for notification := range ch {
   106  			switch notification.typ {
   107  			case treeloaderNotificationTypeDone:
   108  				l.notifyDone(mctx, *notification.doneNotification)
   109  			case treeloaderNotificationTypePartial:
   110  				l.notifyPartial(mctx, notification.teamName, *notification.partialNotification)
   111  			}
   112  		}
   113  	}()
   114  
   115  	return nil
   116  }
   117  
   118  type treeloaderNotificationType int
   119  
   120  const (
   121  	treeloaderNotificationTypeDone    = 0
   122  	treeloaderNotificationTypePartial = 1
   123  )
   124  
// notification is the message sent over the channel returned by loadAsync. typ says which of
// the two payload pointers is populated.
type notification struct {
	// typ is treeloaderNotificationTypePartial or treeloaderNotificationTypeDone.
	typ                 treeloaderNotificationType
	// teamName is the team the message was created for.
	teamName            keybase1.TeamName
	// partialNotification is set by newPartialNotification and holds one team's result.
	partialNotification *keybase1.TeamTreeMembershipResult
	// doneNotification is set by newDoneNotification and holds the expected count.
	doneNotification    *int
}
   131  
   132  func newPartialNotification(teamName keybase1.TeamName,
   133  	result keybase1.TeamTreeMembershipResult) notification {
   134  	return notification{
   135  		typ:                 treeloaderNotificationTypePartial,
   136  		teamName:            teamName,
   137  		partialNotification: &result,
   138  	}
   139  }
   140  func newDoneNotification(teamName keybase1.TeamName, expectedCount int) notification {
   141  	return notification{
   142  		typ:              treeloaderNotificationTypeDone,
   143  		teamName:         teamName,
   144  		doneNotification: &expectedCount,
   145  	}
   146  }
   147  
   148  func (l *Treeloader) loadAsync(mctx libkb.MetaContext) (ch chan notification,
   149  	cancel context.CancelFunc, err error) {
   150  	defer mctx.Trace(fmt.Sprintf("Treeloader.loadAsync(%s, %s)",
   151  		l.targetTeamID, l.targetUV), nil)()
   152  
   153  	start := time.Now()
   154  
   155  	target, err := GetForTeamManagementByTeamID(mctx.Ctx(), mctx.G(),
   156  		l.targetTeamID, true /* needAdmin */)
   157  	if err != nil {
   158  		return nil, nil, err
   159  	}
   160  
   161  	l.targetTeamName = target.Name()
   162  
   163  	ch = make(chan notification)
   164  	imctx, cancel := mctx.BackgroundWithLogTags().WithContextCancel()
   165  
   166  	// Load rest of team tree asynchronously.
   167  	go func(imctx libkb.MetaContext, start time.Time, targetChainState *TeamSigChainState,
   168  		ch chan notification) {
   169  		expectedCount := l.loadRecursive(imctx, l.targetTeamID, l.targetTeamName,
   170  			targetChainState, ch)
   171  		ch <- newDoneNotification(l.targetTeamName, int(expectedCount))
   172  		close(ch)
   173  		imctx.G().RuntimeStats.PushPerfEvent(keybase1.PerfEvent{
   174  			EventType: keybase1.PerfEventType_TEAMTREELOAD,
   175  			Message: fmt.Sprintf("Loaded %d teams in tree for %s",
   176  				expectedCount, l.targetTeamName),
   177  			Ctime: keybase1.ToTime(start),
   178  		})
   179  	}(imctx, start, target.chain(), ch)
   180  
   181  	return ch, cancel, nil
   182  }
   183  
// loadRecursive computes the membership result for teamName, sends it on ch as a partial
// notification, and then concurrently loads the team's ancestors (when applicable) and each of
// its subteams.
//
// targetChainState, when non-nil, is used directly instead of loading the team by teamID; the
// initial call from loadAsync passes the already-loaded target team's chain and the recursive
// calls below pass nil.
//
// The returned expectedCount is 1 for this team plus the counts reported by loadAncestors and by
// the recursive subteam loads. Failures are never returned as errors: they are sent on ch as
// error results, and an error result for this team stops the descent below it.
func (l *Treeloader) loadRecursive(mctx libkb.MetaContext, teamID keybase1.TeamID,
	teamName keybase1.TeamName, targetChainState *TeamSigChainState,
	ch chan notification) (expectedCount int32) {
	defer mctx.Trace(fmt.Sprintf("Treeloader.loadRecursive(%s, %s, %s)",
		l.targetTeamName, l.targetUV, l.targetUsername), nil)()

	var result keybase1.TeamTreeMembershipResult
	var subteams []keybase1.TeamIDAndName
	select {
	case <-mctx.Ctx().Done():
		// The load was cancelled: report the context error for this team instead of loading it.
		result = l.NewErrorResult(mctx.Ctx().Err(), teamName)
	default:
		// If it is the initial call, the caller passes in the sigchain state so we don't need to reload
		// it. Otherwise, do a team load.
		if targetChainState != nil {
			result = l.Converter.ProcessSigchainState(mctx, teamName, &targetChainState.inner)
			subteams = targetChainState.ListSubteams()
		} else {
			team, err := GetForTeamManagementByTeamID(mctx.Ctx(), mctx.G(), teamID, true /* needAdmin */)
			if err != nil {
				result = l.NewErrorResult(err, teamName)
			} else {
				result = l.Converter.ProcessSigchainState(mctx, teamName, &team.chain().inner)
				subteams = team.chain().ListSubteams()
			}
		}
	}
	// This team always accounts for exactly one partial notification, successful or not.
	expectedCount = 1
	ch <- newPartialNotification(teamName, result)
	s, _ := result.S()
	if s == keybase1.TeamTreeMembershipStatus_ERROR {
		mctx.Debug("Treeloader.loadRecursive: short-circuiting load due to failure: %+v", result)
		return expectedCount
	}

	np := l.getPosition(teamName)

	// From here on mctx carries the errgroup's context, which the goroutines below inherit.
	eg, ctx := errgroup.WithContext(mctx.Ctx())
	mctx = mctx.WithContext(ctx)
	// Load ancestors
	// Only done from the target node itself, and only when it has ancestors to load.
	if l.includeAncestors && np == nodePositionTarget && !teamName.IsRootTeam() {
		eg.Go(func() error {
			incr := l.loadAncestors(mctx, teamID, teamName, ch)
			mctx.Debug("Treeloader.loadRecursive: loaded %d teams from ancestors", incr)
			// Goroutines in the group run concurrently, so the shared counter is updated atomically.
			atomic.AddInt32(&expectedCount, incr)
			return nil
		})
	}
	// Load subtree
	// Because we load parents before children, the child's load can use the cached parent's team
	// so we only make one team/get per team.
	for _, idAndName := range subteams {
		// Per-iteration copy so each goroutine captures its own subteam.
		idAndName := idAndName
		// This is unbounded but assuming subteam spread isn't too high, should be ok.
		eg.Go(func() error {
			incr := l.loadRecursive(mctx, idAndName.Id, idAndName.Name, nil, ch)
			mctx.Debug("Treeloader.loadRecursive: loaded %d teams from subtree", incr)
			atomic.AddInt32(&expectedCount, incr)
			return nil
		})
	}
	// Should not return any errors since we did error handling ourselves
	// Wait also ensures every atomic add above has happened before expectedCount is read.
	_ = eg.Wait()

	return expectedCount
}
   250  
   251  func (l *Treeloader) loadAncestors(mctx libkb.MetaContext, teamID keybase1.TeamID,
   252  	teamName keybase1.TeamName, ch chan notification) (expectedCount int32) {
   253  	defer mctx.Trace("Treeloader.loadAncestors", nil)()
   254  
   255  	handleAncestor := func(t keybase1.TeamSigChainState, ancestorTeamName keybase1.TeamName) error {
   256  		result := l.Converter.ProcessSigchainState(mctx, ancestorTeamName, &t)
   257  		s, err := result.S()
   258  		if err != nil {
   259  			return fmt.Errorf("failed to get result status: %w", err)
   260  		}
   261  		// Short-circuit ancestor load if this resulted in an error, to keep symmetry with behavior
   262  		// if the ancestor team load failed. We can continue if the result was HIDDEN. The switch
   263  		// statement below will catch the error and send a notification.
   264  		if s == keybase1.TeamTreeMembershipStatus_ERROR {
   265  			return fmt.Errorf("failed to load ancestor: %s", result.Error().Message)
   266  		}
   267  		ch <- newPartialNotification(ancestorTeamName, result)
   268  		return nil
   269  	}
   270  	err := mctx.G().GetTeamLoader().MapTeamAncestors(
   271  		mctx.Ctx(), handleAncestor, teamID, "Treeloader.Load", nil)
   272  
   273  	switch e := err.(type) {
   274  	case nil:
   275  		expectedCount = int32(teamName.Depth()) - 1
   276  	case *MapAncestorsError:
   277  		mctx.Debug("loadTeamAncestorsMemberships: map failed: %s at idx %d", e,
   278  			e.failedLoadingAtAncestorIdx)
   279  
   280  		// Calculate the team name of the team it failed at
   281  		// e.g. if failedLoadingAtAncestorIdx was 1 and the target team was A.B.C,
   282  		// maxInt(0, 3 - 1 - 1) = 1, and A.B.C[:1+1] = A.B
   283  		idx := maxInt(0, teamName.Depth()-1-e.failedLoadingAtAncestorIdx)
   284  		nameFailedAt := keybase1.TeamName{Parts: teamName.Parts[:idx+1]}
   285  
   286  		expectedCount = maxInt32(0, int32(e.failedLoadingAtAncestorIdx))
   287  		ch <- newPartialNotification(nameFailedAt, l.NewErrorResult(err, nameFailedAt))
   288  	default:
   289  		// Should never happen, since MapTeamAncestors should wrap every error as a
   290  		// MapAncestorsError.
   291  		// Not sure where the error failed: prompt to reload the entire thing
   292  		// Also not sure if it failed at a root for now, so say we didn't to err on the side of
   293  		// caution.
   294  		mctx.Debug("loadTeamAncestorsMemberships: map failed for unknown reason: %s", e)
   295  		ch <- newPartialNotification(teamName, l.NewErrorResult(e, teamName))
   296  	}
   297  	return expectedCount
   298  }
   299  
   300  func (l *Treeloader) ProcessSigchainState(mctx libkb.MetaContext, teamName keybase1.TeamName,
   301  	s *keybase1.TeamSigChainState) keybase1.TeamTreeMembershipResult {
   302  	np := l.getPosition(teamName)
   303  
   304  	if np == nodePositionAncestor {
   305  		meUV, err := mctx.G().GetMeUV(mctx.Ctx())
   306  		// Should never get an error here since we're logged in.
   307  		if err != nil || s.UserRole(meUV) == keybase1.TeamRole_NONE {
   308  			return keybase1.NewTeamTreeMembershipResultWithHidden()
   309  		}
   310  	}
   311  
   312  	role := s.UserRole(l.targetUV)
   313  	var joinTime *keybase1.Time
   314  	if role != keybase1.TeamRole_NONE {
   315  		t, err := s.GetUserLastJoinTime(l.targetUV)
   316  		if err != nil {
   317  			mctx.Debug("Treeloader.ProcessSigchainState: failed to compute join time for %s: %s",
   318  				l.targetUV, err)
   319  		} else {
   320  			joinTime = &t
   321  		}
   322  	}
   323  	return keybase1.NewTeamTreeMembershipResultWithOk(keybase1.TeamTreeMembershipValue{
   324  		Role:     role,
   325  		JoinTime: joinTime,
   326  		TeamID:   s.Id,
   327  	})
   328  }
   329  
   330  func (l *Treeloader) NewErrorResult(err error,
   331  	teamName keybase1.TeamName) keybase1.TeamTreeMembershipResult {
   332  	np := l.getPosition(teamName)
   333  	return keybase1.NewTeamTreeMembershipResultWithError(keybase1.TeamTreeError{
   334  		Message:           fmt.Sprintf("%s", err),
   335  		WillSkipSubtree:   np != nodePositionAncestor,
   336  		WillSkipAncestors: !teamName.IsRootTeam() && np != nodePositionChild,
   337  	})
   338  }
   339  
   340  func (l *Treeloader) getPosition(teamName keybase1.TeamName) nodePosition {
   341  	if l.targetTeamName.Eq(teamName) {
   342  		return nodePositionTarget
   343  	}
   344  	if l.targetTeamName.IsAncestorOf(teamName) {
   345  		return nodePositionChild
   346  	}
   347  	return nodePositionAncestor
   348  }
   349  
   350  func (l *Treeloader) notifyDone(mctx libkb.MetaContext, expectedCount int) {
   351  	doneResult := keybase1.TeamTreeMembershipsDoneResult{
   352  		Guid:           l.guid,
   353  		TargetTeamID:   l.targetTeamID,
   354  		TargetUsername: l.targetUsername,
   355  		ExpectedCount:  expectedCount,
   356  	}
   357  	mctx.G().NotifyRouter.HandleTeamTreeMembershipsDone(mctx.Ctx(), doneResult)
   358  }
   359  
   360  func (l *Treeloader) notifyPartial(mctx libkb.MetaContext, teamName keybase1.TeamName,
   361  	result keybase1.TeamTreeMembershipResult) {
   362  	partial := keybase1.TeamTreeMembership{
   363  		Guid:           l.guid,
   364  		TargetTeamID:   l.targetTeamID,
   365  		TargetUsername: l.targetUsername,
   366  		TeamName:       teamName.String(),
   367  		Result:         result,
   368  	}
   369  	mctx.G().NotifyRouter.HandleTeamTreeMembershipsPartial(mctx.Ctx(), partial)
   370  }
   371  
   372  func maxInt32(a, b int32) int32 {
   373  	if a > b {
   374  		return a
   375  	}
   376  	return b
   377  }
   378  
   379  func maxInt(a, b int) int {
   380  	if a > b {
   381  		return a
   382  	}
   383  	return b
   384  }
   385  
   386  type nodePosition int
   387  
   388  const (
   389  	nodePositionTarget   nodePosition = 0
   390  	nodePositionChild    nodePosition = 1
   391  	nodePositionAncestor nodePosition = 2
   392  )
   393  
// TreeloaderStateConverter converts a team's sigchain state into a membership result.
// Treeloader implements it and uses itself by default; the Converter field exists so that
// another implementation can be substituted (its field comment describes it as a mock for error
// testing).
type TreeloaderStateConverter interface {
	// ProcessSigchainState returns the membership result for the team with the given name and
	// sigchain state.
	ProcessSigchainState(libkb.MetaContext, keybase1.TeamName,
		*keybase1.TeamSigChainState) keybase1.TeamTreeMembershipResult
}