github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/read_replica_database.go

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"strings"
    22  	"sync"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    31  	"github.com/dolthub/dolt/go/store/datas"
    32  	"github.com/dolthub/dolt/go/store/hash"
    33  	"github.com/dolthub/dolt/go/store/types"
    34  )
    35  
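// ReadReplicaDatabase is a Database that pulls from a configured remote so that
// local reads reflect the remote's state. srcDB is a handle to the remote's
// DoltDB, tmpDir is the directory used for fetched table files, and limiter
// serializes concurrent pull work per key.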
    36  type ReadReplicaDatabase struct {
    37  	Database
    38  	remote  env.Remote
    39  	srcDB   *doltdb.DoltDB
    40  	tmpDir  string
    41  	limiter *limiter
    42  }
    43  
    44  var _ dsess.SqlDatabase = ReadReplicaDatabase{}
    45  var _ sql.VersionedDatabase = ReadReplicaDatabase{}
    46  var _ sql.TableDropper = ReadReplicaDatabase{}
    47  var _ sql.TableCreator = ReadReplicaDatabase{}
    48  var _ sql.TemporaryTableCreator = ReadReplicaDatabase{}
    49  var _ sql.TableRenamer = ReadReplicaDatabase{}
    50  var _ sql.TriggerDatabase = &ReadReplicaDatabase{}
    51  var _ sql.StoredProcedureDatabase = ReadReplicaDatabase{}
    52  var _ dsess.RemoteReadReplicaDatabase = ReadReplicaDatabase{}
    53  
    54  var ErrFailedToLoadReplicaDB = errors.New("failed to load replica database")
    55  
    56  var EmptyReadReplica = ReadReplicaDatabase{}
    57  
    58  func NewReadReplicaDatabase(ctx context.Context, db Database, remoteName string, dEnv *env.DoltEnv) (ReadReplicaDatabase, error) {
    59  	remotes, err := dEnv.GetRemotes()
    60  	if err != nil {
    61  		return EmptyReadReplica, err
    62  	}
    63  
    64  	remote, ok := remotes.Get(remoteName)
    65  	if !ok {
    66  		return EmptyReadReplica, fmt.Errorf("%w: '%s'", env.ErrRemoteNotFound, remoteName)
    67  	}
    68  
    69  	srcDB, err := remote.GetRemoteDB(ctx, types.Format_Default, dEnv)
    70  	if err != nil {
    71  		return EmptyReadReplica, err
    72  	}
    73  
    74  	tmpDir, err := dEnv.TempTableFilesDir()
    75  	if err != nil {
    76  		return EmptyReadReplica, err
    77  	}
    78  
    79  	return ReadReplicaDatabase{
    80  		Database: db,
    81  		remote:   remote,
    82  		tmpDir:   tmpDir,
    83  		srcDB:    srcDB,
    84  		limiter:  newLimiter(),
    85  	}, nil
    86  }
    87  
    88  func (rrd ReadReplicaDatabase) WithBranchRevision(requestedName string, branchSpec dsess.SessionDatabaseBranchSpec) (dsess.SqlDatabase, error) {
    89  	rrd.rsr, rrd.rsw = branchSpec.RepoState, branchSpec.RepoState
    90  	rrd.revision = branchSpec.Branch
    91  	rrd.revType = dsess.RevisionTypeBranch
    92  	rrd.requestedName = requestedName
    93  
    94  	return rrd, nil
    95  }
    96  
    97  func (rrd ReadReplicaDatabase) ValidReplicaState(ctx *sql.Context) bool {
    98  	// srcDB will be nil if the remote was specified incorrectly and startup errors were suppressed
    99  	return rrd.srcDB != nil
   100  }
   101  
   102  // InitialDBState implements dsess.SessionDatabase
   103  // This seems like a pointless override from the embedded Database implementation, but it's necessary to pass the
   104  // correct pointer type to the session initializer.
   105  func (rrd ReadReplicaDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) {
   106  	return initialDBState(ctx, rrd, rrd.revision)
   107  }
   108  
   109  func (rrd ReadReplicaDatabase) DoltDatabases() []*doltdb.DoltDB {
   110  	return []*doltdb.DoltDB{rrd.ddb, rrd.srcDB}
   111  }
   112  
   113  func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error {
   114  	ctx.GetLogger().Tracef("pulling from remote %s for database %s", rrd.remote.Name, rrd.Name())
   115  
   116  	_, headsArg, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateHeads)
   117  	if !ok {
   118  		return sql.ErrUnknownSystemVariable.New(dsess.ReplicateHeads)
   119  	}
   120  
   121  	_, allHeads, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateAllHeads)
   122  	if !ok {
   123  		return sql.ErrUnknownSystemVariable.New(dsess.ReplicateAllHeads)
   124  	}
   125  
   126  	behavior := pullBehaviorFastForward
   127  	if ReadReplicaForcePull() {
   128  		behavior = pullBehaviorForcePull
   129  	}
   130  
   131  	err := rrd.srcDB.Rebase(ctx)
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	remoteRefs, localRefs, toDelete, err := getReplicationRefs(ctx, rrd)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
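	// Exactly one of @@dolt_replicate_heads (an explicit, possibly wildcarded branch
	// list) and @@dolt_replicate_all_heads should be set; the cases below warn and
	// skip replication for this call when both or neither are set.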
   141  	switch {
   142  	case headsArg != "" && allHeads == dsess.SysVarTrue:
   143  		ctx.GetLogger().Warnf("cannot set both @@dolt_replicate_heads and @@dolt_replicate_all_heads, replication disabled")
   144  		return nil
   145  	case headsArg != "":
   146  		heads, ok := headsArg.(string)
   147  		if !ok {
   148  			return sql.ErrInvalidSystemVariableValue.New(dsess.ReplicateHeads)
   149  		}
   150  
   151  		branchesToPull := make(map[string]bool)
   152  		for _, branch := range strings.Split(heads, ",") {
   153  			if !containsWildcards(branch) {
   154  				branchesToPull[branch] = true
   155  			} else {
   156  				expandedBranches, err := rrd.expandWildcardBranchPattern(ctx, branch)
   157  				if err != nil {
   158  					return err
   159  				}
   160  
   161  				for _, expandedBranch := range expandedBranches {
   162  					branchesToPull[expandedBranch] = true
   163  				}
   164  
   165  				if len(expandedBranches) == 0 {
   166  					ctx.GetLogger().Warnf("branch pattern '%s' did not match any branches", branch)
   167  				} else {
   168  					ctx.GetLogger().Debugf("expanded '%s' to: %s", branch, strings.Join(expandedBranches, ","))
   169  				}
   170  			}
   171  		}
   172  
   173  		// Reduce the remote branch list to only the ones configured to replicate
    174  		prunedRefs := make([]doltdb.RefWithHash, 0, len(branchesToPull))
    175  		for _, remoteBranch := range remoteRefs {
    176  			if remoteBranch.Ref.GetType() == ref.BranchRefType && branchesToPull[remoteBranch.Ref.GetPath()] {
    177  				prunedRefs = append(prunedRefs, remoteBranch)
    178  			}
    179  			delete(branchesToPull, remoteBranch.Ref.GetPath())
    180  		}
   183  
   184  		if len(branchesToPull) > 0 {
   185  			// just use the first not-found branch as the error string
   186  			var branch string
   187  			for b := range branchesToPull {
   188  				branch = b
   189  				break
   190  			}
   191  
    192  			return fmt.Errorf("unable to find %q on %q; branch not found", branch, rrd.remote.Name)
   196  		}
   197  
   198  		remoteRefs = prunedRefs
   199  		_, err = pullBranches(ctx, rrd, remoteRefs, localRefs, behavior)
   200  		if err != nil {
   201  			return err
   202  		}
   203  
    204  	case allHeads == dsess.SysVarTrue:
   205  		_, err = pullBranches(ctx, rrd, remoteRefs, localRefs, behavior)
   206  		if err != nil {
   207  			return err
   208  		}
   209  
   210  		err = deleteBranches(ctx, rrd, toDelete)
   211  		if err != nil {
   212  			return err
   213  		}
   214  	default:
   215  		ctx.GetLogger().Warnf("must set either @@dolt_replicate_heads or @@dolt_replicate_all_heads, replication disabled")
   216  		return nil
   217  	}
   218  
   219  	return nil
   220  }
   221  
   222  // CreateLocalBranchFromRemote pulls the given branch from the remote database and creates a local tracking branch for
   223  // it. This is only used for initializing a new local branch being pulled from a remote during connection
   224  // initialization, and doesn't do the full work of remote synchronization that happens on transaction start.
   225  func (rrd ReadReplicaDatabase) CreateLocalBranchFromRemote(ctx *sql.Context, branchRef ref.BranchRef) error {
   226  	_, err := rrd.limiter.Run(ctx, "pullNewBranch", func() (any, error) {
   227  		// because several clients can queue up waiting to create the same local branch, double check to see if this
   228  		// work was already done and bail early if so
   229  		_, branchExists, err := rrd.ddb.HasBranch(ctx, branchRef.GetPath())
   230  		if err != nil {
   231  			return nil, err
   232  		}
   233  
   234  		if branchExists {
   235  			return nil, nil
   236  		}
   237  
   238  		cm, err := actions.FetchRemoteBranch(ctx, rrd.tmpDir, rrd.remote, rrd.srcDB, rrd.ddb, branchRef, actions.NoopRunProgFuncs, actions.NoopStopProgFuncs)
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  
   243  		cmHash, err := cm.HashOf()
   244  		if err != nil {
   245  			return nil, err
   246  		}
   247  
   248  		// create refs/heads/branch dataset
   249  		err = rrd.ddb.NewBranchAtCommit(ctx, branchRef, cm, nil)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  
   254  		err = rrd.srcDB.Rebase(ctx)
   255  		if err != nil {
   256  			return nil, err
   257  		}
   258  
   259  		_, err = pullBranches(ctx, rrd, []doltdb.RefWithHash{{
   260  			Ref:  branchRef,
   261  			Hash: cmHash,
   262  		}}, nil, pullBehaviorFastForward)
   263  		if err != nil {
   264  			return nil, err
   265  		}
   266  
   267  		return nil, err
   268  	})
   269  
   270  	return err
   271  }
   272  
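// pullBehavior controls how local branches are updated from the remote: a
// fast-forward by default, or a forced head update when ReadReplicaForcePull()
// reports that forced pulls are enabled.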
   273  type pullBehavior bool
   274  
   275  const pullBehaviorFastForward pullBehavior = false
   276  const pullBehaviorForcePull pullBehavior = true
   277  
    278  // pullBranches pulls the named remote branches into the local database and returns a map of the remote refs keyed by branch path.
   279  func pullBranches(
   280  	ctx *sql.Context,
   281  	rrd ReadReplicaDatabase,
   282  	remoteRefs []doltdb.RefWithHash,
   283  	localRefs []doltdb.RefWithHash,
   284  	behavior pullBehavior,
   285  ) (map[string]doltdb.RefWithHash, error) {
   286  	localRefsByPath := make(map[string]doltdb.RefWithHash)
   287  	remoteRefsByPath := make(map[string]doltdb.RefWithHash)
   288  	remoteHashes := make([]hash.Hash, len(remoteRefs))
   289  
   290  	for i, b := range remoteRefs {
   291  		remoteRefsByPath[b.Ref.GetPath()] = b
   292  		remoteHashes[i] = b.Hash
   293  	}
   294  
   295  	for _, b := range localRefs {
   296  		localRefsByPath[b.Ref.GetPath()] = b
   297  	}
   298  
   299  	// XXX: Our view of which remote branches to pull and what to set the
   300  	// local branches to was computed outside of the limiter, concurrently
   301  	// with other possible attempts to pull from the remote. Now we are
   302  	// applying changes based on that view. This seems capable of rolling
   303  	// back changes which were applied from another thread.
   304  
   305  	_, err := rrd.limiter.Run(ctx, "-all", func() (any, error) {
   306  		pullErr := rrd.ddb.PullChunks(ctx, rrd.tmpDir, rrd.srcDB, remoteHashes, nil, nil)
   307  		if pullErr != nil {
   308  			return nil, pullErr
   309  		}
   310  
   311  	REFS: // every successful pass through the loop below must end with `continue REFS` to get out of the retry loop
   312  		for _, remoteRef := range remoteRefs {
   313  			trackingRef := ref.NewRemoteRef(rrd.remote.Name, remoteRef.Ref.GetPath())
   314  			localRef, localRefExists := localRefsByPath[remoteRef.Ref.GetPath()]
   315  
   316  			// loop on optimistic lock failures
   317  		OPTIMISTIC_RETRY:
   318  			for {
    319  				if localRefExists {
   321  
   322  					if localRef.Ref.GetType() == ref.BranchRefType {
   323  						pulled, err := rrd.pullLocalBranch(ctx, localRef, remoteRef, trackingRef, behavior)
   324  						if errors.Is(err, datas.ErrOptimisticLockFailed) {
   325  							continue OPTIMISTIC_RETRY
   326  						} else if err != nil {
   327  							return nil, err
   328  						}
   329  
   330  						// If we pulled this branch, we need to also update its corresponding working set
    331  						// TODO: the ErrOptimisticLockFailed below will cause the working set to not be updated the next time through
   332  						//  the loop, since pullLocalBranch will return false (branch already up to date)
   333  						//  A better solution would be to update both the working set and the branch in the same noms transaction,
   334  						//  but that's difficult with the current structure
   335  						if pulled {
   336  							err = rrd.updateWorkingSet(ctx, localRef, behavior)
   337  							if errors.Is(err, datas.ErrOptimisticLockFailed) {
   338  								continue OPTIMISTIC_RETRY
   339  							} else if err != nil {
   340  								return nil, err
   341  							}
   342  						}
   343  					}
   344  
   345  					continue REFS
   346  				} else {
   347  					switch remoteRef.Ref.GetType() {
   348  					case ref.BranchRefType:
   349  						// CreateNewBranch also creates its corresponding working set
   350  						err := rrd.createNewBranchFromRemote(ctx, remoteRef, trackingRef)
   351  						if errors.Is(err, datas.ErrOptimisticLockFailed) {
   352  							continue OPTIMISTIC_RETRY
   353  						} else if err != nil {
   354  							return nil, err
   355  						}
   356  
   357  						// TODO: Establish upstream tracking for this new branch
   358  						continue REFS
   359  					case ref.TagRefType:
   360  						err := rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash)
   361  						if errors.Is(err, datas.ErrOptimisticLockFailed) {
   362  							continue OPTIMISTIC_RETRY
   363  						} else if err != nil {
   364  							return nil, err
   365  						}
   366  
   367  						continue REFS
   368  					default:
   369  						ctx.GetLogger().Debugf("skipping replication for unhandled remote ref %s", remoteRef.Ref.String())
   370  						continue REFS
   371  					}
   372  				}
   373  			}
   374  		}
   375  		return nil, nil
   376  	})
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	return remoteRefsByPath, nil
   382  }
   383  
   384  // expandWildcardBranchPattern evaluates |pattern| and returns a list of branch names from the source database that
   385  // match the branch name pattern. The '*' wildcard may be used in the pattern to match zero or more characters.
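// For example, a pattern of 'release/*' would match branches named 'release/1.0'
// and 'release/2.0', but not 'main'.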
   386  func (rrd ReadReplicaDatabase) expandWildcardBranchPattern(ctx context.Context, pattern string) ([]string, error) {
   387  	sourceBranches, err := rrd.srcDB.GetBranches(ctx)
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  	expandedBranches := make([]string, 0)
   392  	for _, sourceBranch := range sourceBranches {
   393  		if matchWildcardPattern(pattern, sourceBranch.GetPath()) {
   394  			expandedBranches = append(expandedBranches, sourceBranch.GetPath())
   395  		}
   396  	}
   397  	return expandedBranches, nil
   398  }
   399  
   400  func (rrd ReadReplicaDatabase) createNewBranchFromRemote(ctx *sql.Context, remoteRef doltdb.RefWithHash, trackingRef ref.RemoteRef) error {
   401  	ctx.GetLogger().Tracef("creating local branch %s", remoteRef.Ref.GetPath())
   402  
   403  	// If a local branch isn't present for the remote branch, create a new branch for it. We need to use
   404  	// NewBranchAtCommit so that the branch has its associated working set created at the same time. Creating
    405  	// branch refs without associated working sets causes errors in other places.
   406  	spec, err := doltdb.NewCommitSpec(remoteRef.Hash.String())
   407  	if err != nil {
   408  		return err
   409  	}
   410  
   411  	optCmt, err := rrd.ddb.Resolve(ctx, spec, nil)
   412  	if err != nil {
   413  		return err
   414  	}
   415  	cm, ok := optCmt.ToCommit()
   416  	if !ok {
    417  		return doltdb.ErrGhostCommitEncountered // TODO: needs test coverage
   418  	}
   419  
    420  	err = rrd.ddb.NewBranchAtCommit(ctx, remoteRef.Ref, cm, nil)
         	if err != nil {
         		return err
         	}
    421  	return rrd.ddb.SetHead(ctx, trackingRef, remoteRef.Hash)
   422  }
   423  
    424  // pullLocalBranch pulls the remote branch into the local branch if they differ and returns whether any work was done.
   425  // Sets the head directly if pullBehaviorForcePull is provided, otherwise attempts a fast-forward.
   426  func (rrd ReadReplicaDatabase) pullLocalBranch(ctx *sql.Context, localRef doltdb.RefWithHash, remoteRef doltdb.RefWithHash, trackingRef ref.RemoteRef, behavior pullBehavior) (bool, error) {
   427  	if localRef.Hash != remoteRef.Hash {
   428  		if behavior == pullBehaviorForcePull {
   429  			err := rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash)
   430  			if err != nil {
   431  				return false, err
   432  			}
   433  		} else {
   434  			err := rrd.ddb.FastForwardToHash(ctx, remoteRef.Ref, remoteRef.Hash)
   435  			if err != nil {
   436  				return false, err
   437  			}
   438  		}
   439  
   440  		err := rrd.ddb.SetHead(ctx, trackingRef, remoteRef.Hash)
   441  		if err != nil {
   442  			return false, err
   443  		}
   444  
   445  		return true, nil
   446  	}
   447  
   448  	return false, nil
   449  }
   450  
    451  // updateWorkingSet updates the working set for the given branch ref to the root value of its head commit
   452  func (rrd ReadReplicaDatabase) updateWorkingSet(ctx *sql.Context, localRef doltdb.RefWithHash, behavior pullBehavior) error {
   453  	wsRef, err := ref.WorkingSetRefForHead(localRef.Ref)
   454  	if err != nil {
   455  		return err
   456  	}
   457  
   458  	var wsHash hash.Hash
   459  	ws, err := rrd.ddb.ResolveWorkingSet(ctx, wsRef)
   460  	if err == doltdb.ErrWorkingSetNotFound {
   461  		// ignore, we'll create from scratch
   462  	} else if err != nil {
   463  		return err
   464  	} else {
   465  		wsHash, err = ws.HashOf()
   466  		if err != nil {
   467  			return err
   468  		}
   469  	}
   470  
   471  	cm, err := rrd.ddb.ResolveCommitRef(ctx, localRef.Ref)
   472  	if err != nil {
   473  		return err
   474  	}
   475  	rv, err := cm.GetRootValue(ctx)
   476  	if err != nil {
   477  		return err
   478  	}
   479  
   480  	wsMeta := doltdb.TodoWorkingSetMeta()
   481  	if dtx, ok := ctx.GetTransaction().(*dsess.DoltTransaction); ok {
   482  		wsMeta = dtx.WorkingSetMeta(ctx)
   483  	}
   484  
   485  	newWs := doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(rv).WithStagedRoot(rv)
   486  	return rrd.ddb.UpdateWorkingSet(ctx, wsRef, newWs, wsHash, wsMeta, nil)
   487  }
   488  
   489  func getReplicationRefs(ctx *sql.Context, rrd ReadReplicaDatabase) (
   490  	remoteRefs []doltdb.RefWithHash,
   491  	localRefs []doltdb.RefWithHash,
   492  	deletedRefs []doltdb.RefWithHash,
   493  	err error,
   494  ) {
   495  	remoteRefs, err = rrd.srcDB.GetRefsWithHashes(ctx)
   496  	if err != nil {
   497  		return nil, nil, nil, err
   498  	}
   499  
   500  	localRefs, err = rrd.Database.ddb.GetRefsWithHashes(ctx)
   501  	if err != nil {
   502  		return nil, nil, nil, err
   503  	}
   504  
   505  	deletedRefs = refsToDelete(remoteRefs, localRefs)
   506  	return remoteRefs, localRefs, deletedRefs, nil
   507  }
   508  
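// refsToDelete returns the local refs whose paths are not present in the remote
// refs, using a two-pointer merge that relies on both slices being sorted by ref
// path. For example, given remote paths [main] and local paths [feature, main],
// only 'feature' is returned for deletion.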
   509  func refsToDelete(remRefs, localRefs []doltdb.RefWithHash) []doltdb.RefWithHash {
   510  	toDelete := make([]doltdb.RefWithHash, 0, len(localRefs))
   511  	var i, j int
   512  	for i < len(remRefs) && j < len(localRefs) {
   513  		rem := remRefs[i].Ref.GetPath()
   514  		local := localRefs[j].Ref.GetPath()
   515  		if rem == local {
   516  			i++
   517  			j++
   518  		} else if rem < local {
   519  			i++
   520  		} else {
   521  			toDelete = append(toDelete, localRefs[j])
   522  			j++
   523  		}
   524  	}
   525  	for j < len(localRefs) {
   526  		toDelete = append(toDelete, localRefs[j])
   527  		j++
   528  	}
   529  	return toDelete
   530  }
   531  
   532  func deleteBranches(ctx *sql.Context, rrd ReadReplicaDatabase, branches []doltdb.RefWithHash) error {
   533  	for _, b := range branches {
   534  		err := rrd.ddb.DeleteBranch(ctx, b.Ref, nil)
   535  		if errors.Is(err, doltdb.ErrBranchNotFound) {
   536  			continue
   537  		} else if err != nil {
   538  			return err
   539  		}
   540  	}
   541  	return nil
   542  }
   543  
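// res carries the result of a single invocation of |f| back to waiting callers.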
   544  type res struct {
   545  	v   any
   546  	err error
   547  }
   548  
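// blocked tracks work queued up behind an in-flight call for a key: the next |f|
// to run once the current call finishes, and the channels of callers waiting on
// that result.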
   549  type blocked struct {
   550  	f       func() (any, error)
   551  	waiters []chan res
   552  }
   553  
   554  func newLimiter() *limiter {
   555  	return &limiter{
   556  		running: make(map[string]*blocked),
   557  	}
   558  }
   559  
    560  // limiter allows callers to avoid performing duplicate concurrent work for a given string key.
   561  type limiter struct {
   562  	mu      sync.Mutex
   563  	running map[string]*blocked
   564  }
   565  
   566  // |Run| invokes |f|, returning its result. It does not allow two |f|s
    567  // submitted with the same |s| to be running concurrently.
    568  // Only one of the |f|s that arrive with the same |s| while another |f| with
   569  // that key is running will ultimately be run. The result of invoking that |f|
   570  // will be returned from the |Run| call to all blockers on that key.
   571  //
    572  // 1) A caller provides a string key, |s|, and an |f func() (any, error)| which will
   573  // perform the work when invoked.
   574  //
   575  // 2) If there is no outstanding call for the key, |f| is invoked and the
   576  // result is returned.
   577  //
   578  // 3) Otherwise, the caller blocks until the outstanding call is completed.
   579  // When the outstanding call completes, one of the blocked |f|s that was
   580  // provided for that key is run. The result of that invocation is returned to
   581  // all blocked callers.
   582  //
    583  // A caller's |Run| invocation can return early if the context is canceled.
    584  // If the |f| captures a context, and that context is canceled, and the |f|
    585  // allows the error from that context cancellation to escape, then multiple
    586  // callers will see context.Canceled / context.DeadlineExceeded, even if their
    587  // contexts are not canceled.
   588  //
    589  // This implementation is very naive and is not optimized for high
   590  // contention on |l.running|/|l.mu|.
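//
// A minimal usage sketch (illustrative only; the key string and the work
// function below are hypothetical):
//
//	l := newLimiter()
//	v, err := l.Run(ctx, "pull-main", func() (any, error) {
//		return doExpensiveFetch(ctx) // hypothetical work; concurrent callers share its result
//	})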
   591  func (l *limiter) Run(ctx context.Context, s string, f func() (any, error)) (any, error) {
   592  	l.mu.Lock()
   593  	if b, ok := l.running[s]; ok {
   594  		// Something is already running; add ourselves to waiters.
   595  		ch := make(chan res)
   596  		if b.f == nil {
    597  			// We are the first waiter; we decide which |f| will be invoked.
   598  			b.f = f
   599  		}
   600  		b.waiters = append(b.waiters, ch)
   601  		l.mu.Unlock()
   602  		select {
   603  		case r := <-ch:
   604  			return r.v, r.err
   605  		case <-ctx.Done():
   606  			go func() { <-ch }()
   607  			return nil, ctx.Err()
   608  		}
   609  	} else {
   610  		// We can run immediately and return the result of |f|.
   611  		// Register ourselves as running.
   612  		l.running[s] = new(blocked)
   613  		l.mu.Unlock()
   614  	}
   615  
   616  	res, err := f()
   617  	l.finish(s)
   618  	return res, err
   619  }
   620  
   621  // Called anytime work is finished on a given key. Responsible for
   622  // starting any blocked work on |s| and delivering the results to waiters.
   623  func (l *limiter) finish(s string) {
   624  	l.mu.Lock()
   625  	defer l.mu.Unlock()
   626  	b := l.running[s]
   627  	if len(b.waiters) != 0 {
   628  		go func() {
   629  			r, err := b.f()
   630  			for _, ch := range b.waiters {
   631  				ch <- res{r, err}
   632  				close(ch)
   633  			}
   634  			l.finish(s)
   635  		}()
   636  		// Just started work for the existing |*blocked|, make a new
   637  		// |*blocked| for work that arrives from this point forward.
   638  		l.running[s] = new(blocked)
   639  	} else {
   640  		// No work is pending. Delete l.running[s] since nothing is
   641  		// running anymore.
   642  		delete(l.running, s)
   643  	}
   644  }