github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/statspro/configure.go (about) 1 // Copyright 2024 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package statspro 16 17 import ( 18 "context" 19 "fmt" 20 "strings" 21 "time" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 types2 "github.com/dolthub/go-mysql-server/sql/types" 25 26 "github.com/dolthub/dolt/go/libraries/doltcore/env" 27 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 28 "github.com/dolthub/dolt/go/libraries/utils/filesys" 29 ) 30 31 func (p *Provider) Configure(ctx context.Context, ctxFactory func(ctx context.Context) (*sql.Context, error), bThreads *sql.BackgroundThreads, dbs []dsess.SqlDatabase) error { 32 p.SetStarter(NewStatsInitDatabaseHook(p, ctxFactory, bThreads)) 33 34 if _, disabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsMemoryOnly); disabled == int8(1) { 35 return nil 36 } 37 38 loadCtx, err := ctxFactory(ctx) 39 if err != nil { 40 return err 41 } 42 43 branches := p.getStatsBranches(loadCtx) 44 45 var autoEnabled bool 46 var intervalSec time.Duration 47 var thresholdf64 float64 48 if _, enabled, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshEnabled); enabled == int8(1) { 49 autoEnabled = true 50 _, threshold, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshThreshold) 51 _, interval, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsAutoRefreshInterval) 52 interval64, _, _ := types2.Int64.Convert(interval) 53 intervalSec = time.Second * time.Duration(interval64.(int64)) 54 thresholdf64 = threshold.(float64) 55 56 p.pro.InitDatabaseHooks = append(p.pro.InitDatabaseHooks, NewStatsInitDatabaseHook(p, ctxFactory, bThreads)) 57 p.pro.DropDatabaseHooks = append(p.pro.DropDatabaseHooks, NewStatsDropDatabaseHook(p)) 58 } 59 60 eg, ctx := loadCtx.NewErrgroup() 61 for _, db := range dbs { 62 // copy closure variables 63 db := db 64 eg.Go(func() (err error) { 65 defer func() { 66 if r := recover(); r != nil { 67 if str, ok := r.(fmt.Stringer); ok { 68 err = fmt.Errorf("%w: %s", ErrFailedToLoad, str.String()) 69 } else { 70 err = fmt.Errorf("%w: %v", ErrFailedToLoad, r) 71 } 72 73 return 74 } 75 }() 76 77 fs, err := p.pro.FileSystemForDatabase(db.Name()) 78 if err != nil { 79 return err 80 } 81 82 if p.Load(loadCtx, fs, db, branches); err != nil { 83 return err 84 } 85 if autoEnabled { 86 return p.InitAutoRefreshWithParams(ctxFactory, db.Name(), bThreads, intervalSec, thresholdf64, branches) 87 } 88 return nil 89 }) 90 } 91 return eg.Wait() 92 } 93 94 // getStatsBranches returns the set of branches whose statistics are tracked. 95 // The order of precedence is (1) global variable, (2) session current branch, 96 // (3) engine default branch. 97 func (p *Provider) getStatsBranches(ctx *sql.Context) []string { 98 dSess := dsess.DSessFromSess(ctx.Session) 99 var branches []string 100 if _, bs, _ := sql.SystemVariables.GetGlobal(dsess.DoltStatsBranches); bs == "" { 101 defaultBranch, _ := dSess.GetBranch() 102 if defaultBranch != "" { 103 branches = append(branches, defaultBranch) 104 } 105 } else { 106 for _, branch := range strings.Split(bs.(string), ",") { 107 branches = append(branches, strings.TrimSpace(branch)) 108 } 109 } 110 111 if branches == nil { 112 branches = []string{p.pro.DefaultBranch()} 113 } 114 return branches 115 } 116 117 func (p *Provider) LoadStats(ctx *sql.Context, db, branch string) error { 118 if statDb, ok := p.getStatDb(db); ok { 119 return statDb.LoadBranchStats(ctx, branch) 120 } 121 return nil 122 } 123 124 // Load scans the statistics tables, populating the |stats| attribute. 125 // Statistics are not available for reading until we've finished loading. 126 func (p *Provider) Load(ctx *sql.Context, fs filesys.Filesys, db dsess.SqlDatabase, branches []string) { 127 // |statPath| is either file://./stat or mem://stat 128 statsDb, err := p.sf.Init(ctx, db, p.pro, fs, env.GetCurrentUserHomeDir) 129 if err != nil { 130 ctx.Warn(0, err.Error()) 131 return 132 } 133 134 for _, branch := range branches { 135 err = statsDb.LoadBranchStats(ctx, branch) 136 if err != nil { 137 // if branch name is invalid, continue loading rest 138 // TODO: differentiate bad branch name from other errors 139 ctx.Warn(0, err.Error()) 140 continue 141 } 142 } 143 144 p.mu.Lock() 145 defer p.mu.Unlock() 146 p.setStatDb(strings.ToLower(db.Name()), statsDb) 147 return 148 }