github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/statspro/configure.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package statspro
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	types2 "github.com/dolthub/go-mysql-server/sql/types"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    28  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    29  )
    30  
    31  func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Context) (*sql.Context, error), bThreads *sql.BackgroundThreads, dbs []dsess.SqlDatabase) error {
    32  	p.SetStarter(NewStatsInitDatabaseHook(p, ctxFactory, bThreads))
    33  
    34  	if _, disabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly); disabled == int8(1) {
    35  		return nil
    36  	}
    37  
    38  	loadCtx, err := ctxFactory(ctx)
    39  	if err != nil {
    40  		return err
    41  	}
    42  
    43  	branches := p.getStatsBranches(loadCtx)
    44  
    45  	var autoEnabled bool
    46  	var intervalSec time.Duration
    47  	var thresholdf64 float64
    48  	if _, enabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshEnabled); enabled == int8(1) {
    49  		autoEnabled = true
    50  		_, threshold, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshThreshold)
    51  		_, interval, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshInterval)
    52  		interval64, _, _ := types2.Int64.Convert(interval)
    53  		intervalSec = time.Second * time.Duration(interval64.(int64))
    54  		thresholdf64 = threshold.(float64)
    55  
    56  		p.pro.InitDatabaseHooks = append(p.pro.InitDatabaseHooks, NewStatsInitDatabaseHook(p, ctxFactory, bThreads))
    57  		p.pro.DropDatabaseHooks = append(p.pro.DropDatabaseHooks, NewStatsDropDatabaseHook(p))
    58  	}
    59  
    60  	eg, ctx := loadCtx.NewErrgroup()
    61  	for _, db := range dbs {
    62  		// copy closure variables
    63  		db := db
    64  		eg.Go(func() (err error) {
    65  			defer func() {
    66  				if r := recover(); r != nil {
    67  					if str, ok := r.(fmt.Stringer); ok {
    68  						err = fmt.Errorf("%w: %s", ErrFailedToLoad, str.String())
    69  					} else {
    70  						err = fmt.Errorf("%w: %v", ErrFailedToLoad, r)
    71  					}
    72  
    73  					return
    74  				}
    75  			}()
    76  
    77  			fs, err := p.pro.FileSystemForDatabase(db.Name())
    78  			if err != nil {
    79  				return err
    80  			}
    81  
    82  			if p.Load(loadCtx, fs, db, branches); err != nil {
    83  				return err
    84  			}
    85  			if autoEnabled {
    86  				return p.InitAutoRefreshWithParams(ctxFactory, db.Name(), bThreads, intervalSec, thresholdf64, branches)
    87  			}
    88  			return nil
    89  		})
    90  	}
    91  	return eg.Wait()
    92  }
    93  
    94  // getStatsBranches returns the set of branches whose statistics are tracked.
    95  // The order of precedence is (1) global variable, (2) session current branch,
    96  // (3) engine default branch.
    97  func (p *Provider) getStatsBranches(ctx *sql.Context) []string {
    98  	dSess := dsess.DSessFromSess(ctx.Session)
    99  	var branches []string
   100  	if _, bs, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsBranches); bs == "" {
   101  		defaultBranch, _ := dSess.GetBranch()
   102  		if defaultBranch != "" {
   103  			branches = append(branches, defaultBranch)
   104  		}
   105  	} else {
   106  		for _, branch := range strings.Split(bs.(string), ",") {
   107  			branches = append(branches, strings.TrimSpace(branch))
   108  		}
   109  	}
   110  
   111  	if branches == nil {
   112  		branches = []string{p.pro.DefaultBranch()}
   113  	}
   114  	return branches
   115  }
   116  
   117  func (p *Provider) LoadStats(ctx *sql.Context, db, branch string) error {
   118  	if statDb, ok := p.getStatDb(db); ok {
   119  		return statDb.LoadBranchStats(ctx, branch)
   120  	}
   121  	return nil
   122  }
   123  
   124  // Load scans the statistics tables, populating the |stats| attribute.
   125  // Statistics are not available for reading until we've finished loading.
   126  func (p *Provider) Load(ctx *sql.Context, fs filesys.Filesys, db dsess.SqlDatabase, branches []string) {
   127  	// |statPath| is either file://./stat or mem://stat
   128  	statsDb, err := p.sf.Init(ctx, db, p.pro, fs, env.GetCurrentUserHomeDir)
   129  	if err != nil {
   130  		ctx.Warn(0, err.Error())
   131  		return
   132  	}
   133  
   134  	for _, branch := range branches {
   135  		err = statsDb.LoadBranchStats(ctx, branch)
   136  		if err != nil {
   137  			// if branch name is invalid, continue loading rest
   138  			// TODO: differentiate bad branch name from other errors
   139  			ctx.Warn(0, err.Error())
   140  			continue
   141  		}
   142  	}
   143  
   144  	p.mu.Lock()
   145  	defer p.mu.Unlock()
   146  	p.setStatDb(strings.ToLower(db.Name()), statsDb)
   147  	return
   148  }