github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_count_commits.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dprocedures 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 24 "github.com/dolthub/dolt/go/cmd/dolt/cli" 25 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 26 "github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk" 27 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 28 "github.com/dolthub/dolt/go/store/hash" 29 ) 30 31 func doltCountCommits(ctx *sql.Context, args ...string) (sql.RowIter, error) { 32 ahead, behind, err := countCommits(ctx, args...) 33 if err != nil { 34 return nil, err 35 } 36 return sql.RowsToRowIter(sql.Row{ahead, behind}), nil 37 } 38 39 func countCommits(ctx *sql.Context, args ...string) (ahead uint64, behind uint64, err error) { 40 dbName := ctx.GetCurrentDatabase() 41 if len(dbName) == 0 { 42 return 0, 0, fmt.Errorf("empty database name") 43 } 44 45 sess := dsess.DSessFromSess(ctx.Session) 46 apr, err := cli.CreateCountCommitsArgParser().Parse(args) 47 if err != nil { 48 return 0, 0, err 49 } 50 fromRef, ok := apr.GetValue("from") 51 if !ok { 52 return 0, 0, fmt.Errorf("missing from ref") 53 } 54 if len(fromRef) == 0 { 55 return 0, 0, fmt.Errorf("empty from ref") 56 } 57 toRef, ok := apr.GetValue("to") 58 if !ok { 59 return 0, 0, fmt.Errorf("missing to ref") 60 } 61 if len(toRef) == 0 { 62 return 0, 0, fmt.Errorf("empty to ref") 63 } 64 65 dbData, ok := sess.GetDbData(ctx, dbName) 66 if !ok { 67 return 0, 0, fmt.Errorf("could not load database %s", dbName) 68 } 69 ddb := dbData.Ddb 70 rsr := dbData.Rsr 71 72 fromSpec, err := doltdb.NewCommitSpec(fromRef) 73 if err != nil { 74 return 0, 0, err 75 } 76 headRef, err := rsr.CWBHeadRef() 77 if err != nil { 78 return 0, 0, err 79 } 80 optCmt, err := ddb.Resolve(ctx, fromSpec, headRef) 81 if err != nil { 82 return 0, 0, err 83 } 84 fromCommit, ok := optCmt.ToCommit() 85 if !ok { 86 return 0, 0, doltdb.ErrGhostCommitEncountered 87 } 88 89 fromHash, err := fromCommit.HashOf() 90 if err != nil { 91 return 0, 0, err 92 } 93 94 toSpec, err := doltdb.NewCommitSpec(toRef) 95 if err != nil { 96 return 0, 0, err 97 } 98 optCmt, err = ddb.Resolve(ctx, toSpec, headRef) 99 if err != nil { 100 return 0, 0, err 101 } 102 toCommit, ok := optCmt.ToCommit() 103 if !ok { 104 return 0, 0, doltdb.ErrGhostCommitEncountered 105 } 106 107 toHash, err := toCommit.HashOf() 108 if err != nil { 109 return 0, 0, err 110 } 111 112 optCmt, err = doltdb.GetCommitAncestor(ctx, fromCommit, toCommit) 113 if err != nil { 114 return 0, 0, err 115 } 116 ancestor, ok := optCmt.ToCommit() 117 if !ok { 118 return 0, 0, doltdb.ErrGhostCommitEncountered 119 } 120 121 ancestorHash, err := ancestor.HashOf() 122 if err != nil { 123 return 0, 0, err 124 } 125 126 if fromHash != toHash { 127 behind, err = countCommitsInRange(ctx, ddb, toHash, ancestorHash) 128 if err != nil { 129 return 0, 0, err 130 } 131 ahead, err = countCommitsInRange(ctx, ddb, fromHash, ancestorHash) 132 if err != nil { 133 return 0, 0, err 134 } 135 } 136 137 return ahead, behind, nil 138 } 139 140 // countCommitsInRange returns the number of commits between the given starting point to trace back to the given target point. 141 // The starting commit must be a descendant of the target commit. Target commit must be a common ancestor commit. 142 func countCommitsInRange(ctx context.Context, ddb *doltdb.DoltDB, startCommitHash, targetCommitHash hash.Hash) (uint64, error) { 143 itr, iErr := commitwalk.GetTopologicalOrderIterator(ctx, ddb, []hash.Hash{startCommitHash}, nil) 144 if iErr != nil { 145 return 0, iErr 146 } 147 count := 0 148 for { 149 nextHash, _, err := itr.Next(ctx) 150 if err == io.EOF { 151 return 0, fmt.Errorf("no match found to ancestor commit") 152 } else if err != nil { 153 return 0, err 154 } 155 156 if nextHash == targetCommitHash { 157 break 158 } 159 count += 1 160 } 161 162 return uint64(count), nil 163 }