github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dfunctions/has_ancestor.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dfunctions 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/types" 22 23 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 24 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 25 ) 26 27 const HasAncestorFuncName = "has_ancestor" 28 29 type HasAncestor struct { 30 reference sql.Expression 31 ancestor sql.Expression 32 } 33 34 var _ sql.FunctionExpression = (*HasAncestor)(nil) 35 36 // NewHasAncestor creates a new HasAncestor expression. 37 func NewHasAncestor(head, anc sql.Expression) sql.Expression { 38 return &HasAncestor{reference: head, ancestor: anc} 39 } 40 41 // Eval implements the Expression interface. 42 func (a *HasAncestor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 43 if !types.IsText(a.reference.Type()) { 44 return nil, sql.ErrInvalidArgumentDetails.New(a, a.reference) 45 } 46 if !types.IsText(a.ancestor.Type()) { 47 return nil, sql.ErrInvalidArgumentDetails.New(a, a.ancestor) 48 } 49 50 // TODO analysis should embed a database the same way as table functions 51 sess := dsess.DSessFromSess(ctx.Session) 52 db := sess.GetCurrentDatabase() 53 dbd, ok := sess.GetDbData(ctx, db) 54 if !ok { 55 return nil, fmt.Errorf("error during has_ancestor check: database not found '%s'", db) 56 } 57 ddb := dbd.Ddb 58 59 // this errors for non-branch refs 60 // ddb.Resolve will error if combination of head and commit are invalid 61 headRef, _ := sess.CWBHeadRef(ctx, db) 62 var headCommit *doltdb.Commit 63 { 64 headIf, err := a.reference.Eval(ctx, row) 65 if err != nil { 66 return nil, err 67 } 68 headStr, _, err := types.Text.Convert(headIf) 69 if err != nil { 70 return nil, err 71 } 72 73 cs, err := doltdb.NewCommitSpec(headStr.(string)) 74 if err != nil { 75 return nil, err 76 } 77 optCmt, err := ddb.Resolve(ctx, cs, headRef) 78 if err != nil { 79 return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", headStr) 80 } 81 headCommit, ok = optCmt.ToCommit() 82 if !ok { 83 return nil, doltdb.ErrGhostCommitEncountered 84 } 85 } 86 87 var ancCommit *doltdb.Commit 88 { 89 ancIf, err := a.ancestor.Eval(ctx, row) 90 if err != nil { 91 return nil, err 92 } 93 ancStr, _, err := types.Text.Convert(ancIf) 94 if err != nil { 95 return nil, err 96 } 97 cs, err := doltdb.NewCommitSpec(ancStr.(string)) 98 if err != nil { 99 return nil, err 100 } 101 optCmt, err := ddb.Resolve(ctx, cs, headRef) 102 if err != nil { 103 return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", ancStr) 104 } 105 ancCommit, ok = optCmt.ToCommit() 106 if !ok { 107 return nil, doltdb.ErrGhostCommitEncountered 108 } 109 110 } 111 112 headHash, err := headCommit.HashOf() 113 if err != nil { 114 return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error()) 115 } 116 117 ancHash, err := ancCommit.HashOf() 118 if err != nil { 119 return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error()) 120 } 121 if headHash == ancHash { 122 return true, nil 123 } 124 125 cc, err := headCommit.GetCommitClosure(ctx) 126 if err != nil { 127 return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error()) 128 } 129 ancHeight, err := ancCommit.Height() 130 if err != nil { 131 return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error()) 132 } 133 134 isAncestor, err := cc.ContainsKey(ctx, ancHash, ancHeight) 135 if err != nil { 136 return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error()) 137 } 138 139 return isAncestor, nil 140 } 141 142 func (a *HasAncestor) Resolved() bool { 143 return a.reference.Resolved() && a.ancestor.Resolved() 144 } 145 146 func (a *HasAncestor) Children() []sql.Expression { 147 return []sql.Expression{a.reference, a.ancestor} 148 } 149 150 // String implements the Stringer interface. 151 func (a *HasAncestor) String() string { 152 return fmt.Sprintf("HAS_ANCESTOR(%s, %s)", a.reference, a.ancestor) 153 } 154 155 // FunctionName implements the FunctionExpression interface 156 func (a *HasAncestor) FunctionName() string { 157 return HasAncestorFuncName 158 } 159 160 // Description implements the FunctionExpression interface 161 func (a *HasAncestor) Description() string { 162 return "returns whether a reference commit's ancestor graph contains a target commit" 163 } 164 165 // IsNullable implements the Expression interface. 166 func (a *HasAncestor) IsNullable() bool { 167 return a.reference.IsNullable() || a.ancestor.IsNullable() 168 } 169 170 // WithChildren implements the Expression interface. 171 func (a *HasAncestor) WithChildren(children ...sql.Expression) (sql.Expression, error) { 172 if len(children) != 2 { 173 return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2) 174 } 175 return NewHasAncestor(children[0], children[1]), nil 176 } 177 178 // Type implements the Expression interface. 179 func (a *HasAncestor) Type() sql.Type { 180 return types.Boolean 181 }