github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dfunctions/has_ancestor.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dfunctions
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/types"
    22  
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    25  )
    26  
    27  const HasAncestorFuncName = "has_ancestor"
    28  
    29  type HasAncestor struct {
    30  	reference sql.Expression
    31  	ancestor  sql.Expression
    32  }
    33  
    34  var _ sql.FunctionExpression = (*HasAncestor)(nil)
    35  
    36  // NewHasAncestor creates a new HasAncestor expression.
    37  func NewHasAncestor(head, anc sql.Expression) sql.Expression {
    38  	return &HasAncestor{reference: head, ancestor: anc}
    39  }
    40  
    41  // Eval implements the Expression interface.
    42  func (a *HasAncestor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    43  	if !types.IsText(a.reference.Type()) {
    44  		return nil, sql.ErrInvalidArgumentDetails.New(a, a.reference)
    45  	}
    46  	if !types.IsText(a.ancestor.Type()) {
    47  		return nil, sql.ErrInvalidArgumentDetails.New(a, a.ancestor)
    48  	}
    49  
    50  	// TODO analysis should embed a database the same way as table functions
    51  	sess := dsess.DSessFromSess(ctx.Session)
    52  	db := sess.GetCurrentDatabase()
    53  	dbd, ok := sess.GetDbData(ctx, db)
    54  	if !ok {
    55  		return nil, fmt.Errorf("error during has_ancestor check: database not found '%s'", db)
    56  	}
    57  	ddb := dbd.Ddb
    58  
    59  	// this errors for non-branch refs
    60  	// ddb.Resolve will error if combination of head and commit are invalid
    61  	headRef, _ := sess.CWBHeadRef(ctx, db)
    62  	var headCommit *doltdb.Commit
    63  	{
    64  		headIf, err := a.reference.Eval(ctx, row)
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  		headStr, _, err := types.Text.Convert(headIf)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  
    73  		cs, err := doltdb.NewCommitSpec(headStr.(string))
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  		optCmt, err := ddb.Resolve(ctx, cs, headRef)
    78  		if err != nil {
    79  			return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", headStr)
    80  		}
    81  		headCommit, ok = optCmt.ToCommit()
    82  		if !ok {
    83  			return nil, doltdb.ErrGhostCommitEncountered
    84  		}
    85  	}
    86  
    87  	var ancCommit *doltdb.Commit
    88  	{
    89  		ancIf, err := a.ancestor.Eval(ctx, row)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  		ancStr, _, err := types.Text.Convert(ancIf)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		cs, err := doltdb.NewCommitSpec(ancStr.(string))
    98  		if err != nil {
    99  			return nil, err
   100  		}
   101  		optCmt, err := ddb.Resolve(ctx, cs, headRef)
   102  		if err != nil {
   103  			return nil, fmt.Errorf("error during has_ancestor check: ref not found '%s'", ancStr)
   104  		}
   105  		ancCommit, ok = optCmt.ToCommit()
   106  		if !ok {
   107  			return nil, doltdb.ErrGhostCommitEncountered
   108  		}
   109  
   110  	}
   111  
   112  	headHash, err := headCommit.HashOf()
   113  	if err != nil {
   114  		return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error())
   115  	}
   116  
   117  	ancHash, err := ancCommit.HashOf()
   118  	if err != nil {
   119  		return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error())
   120  	}
   121  	if headHash == ancHash {
   122  		return true, nil
   123  	}
   124  
   125  	cc, err := headCommit.GetCommitClosure(ctx)
   126  	if err != nil {
   127  		return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error())
   128  	}
   129  	ancHeight, err := ancCommit.Height()
   130  	if err != nil {
   131  		return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error())
   132  	}
   133  
   134  	isAncestor, err := cc.ContainsKey(ctx, ancHash, ancHeight)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("error during has_ancestor check: %s", err.Error())
   137  	}
   138  
   139  	return isAncestor, nil
   140  }
   141  
   142  func (a *HasAncestor) Resolved() bool {
   143  	return a.reference.Resolved() && a.ancestor.Resolved()
   144  }
   145  
   146  func (a *HasAncestor) Children() []sql.Expression {
   147  	return []sql.Expression{a.reference, a.ancestor}
   148  }
   149  
   150  // String implements the Stringer interface.
   151  func (a *HasAncestor) String() string {
   152  	return fmt.Sprintf("HAS_ANCESTOR(%s, %s)", a.reference, a.ancestor)
   153  }
   154  
   155  // FunctionName implements the FunctionExpression interface
   156  func (a *HasAncestor) FunctionName() string {
   157  	return HasAncestorFuncName
   158  }
   159  
   160  // Description implements the FunctionExpression interface
   161  func (a *HasAncestor) Description() string {
   162  	return "returns whether a reference commit's ancestor graph contains a target commit"
   163  }
   164  
   165  // IsNullable implements the Expression interface.
   166  func (a *HasAncestor) IsNullable() bool {
   167  	return a.reference.IsNullable() || a.ancestor.IsNullable()
   168  }
   169  
   170  // WithChildren implements the Expression interface.
   171  func (a *HasAncestor) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   172  	if len(children) != 2 {
   173  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2)
   174  	}
   175  	return NewHasAncestor(children[0], children[1]), nil
   176  }
   177  
   178  // Type implements the Expression interface.
   179  func (a *HasAncestor) Type() sql.Type {
   180  	return types.Boolean
   181  }