github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/constraints.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"strings"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql/transform"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  )
    25  
    26  // resolveDropConstraint replaces DropConstraint nodes with a concrete type of alter table node as appropriate, or
    27  // throws a constraint not found error if the named constraint isn't found on the table given.
    28  func resolveDropConstraint(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    29  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    30  		dropConstraint, ok := n.(*plan.DropConstraint)
    31  		if !ok {
    32  			return n, transform.SameTree, nil
    33  		}
    34  
    35  		rt, ok := dropConstraint.Child.(*plan.ResolvedTable)
    36  		if !ok {
    37  			return nil, transform.SameTree, ErrInAnalysis.New("Expected a TableNode for ALTER TABLE DROP CONSTRAINT statement")
    38  		}
    39  
    40  		//TODO: handle if a foreign key and check constraint have the same name, it should error saying to use the specific drop
    41  		table := rt.Table
    42  		fkt, ok := table.(sql.ForeignKeyTable)
    43  		if ok {
    44  			decFks, err := fkt.GetDeclaredForeignKeys(ctx)
    45  			if err != nil {
    46  				return nil, transform.SameTree, err
    47  			}
    48  			refFks, err := fkt.GetReferencedForeignKeys(ctx)
    49  			if err != nil {
    50  				return nil, transform.SameTree, err
    51  			}
    52  			for _, fk := range append(decFks, refFks...) {
    53  				if strings.ToLower(fk.Name) == strings.ToLower(dropConstraint.Name) {
    54  					n, err = plan.NewAlterDropForeignKey(rt.SqlDatabase.Name(), rt.Table.Name(), dropConstraint.Name).
    55  						WithDatabaseProvider(a.Catalog.DbProvider)
    56  					return n, transform.NewTree, err
    57  				}
    58  			}
    59  		}
    60  
    61  		ct, ok := table.(sql.CheckTable)
    62  		if !ok {
    63  			return nil, transform.SameTree, plan.ErrNoCheckConstraintSupport.New(table.Name())
    64  		}
    65  
    66  		checks, err := ct.GetChecks(ctx)
    67  		if err != nil {
    68  			return nil, transform.SameTree, err
    69  		}
    70  
    71  		for _, check := range checks {
    72  			if strings.ToLower(check.Name) == strings.ToLower(dropConstraint.Name) {
    73  				return plan.NewAlterDropCheck(rt, check.Name), transform.NewTree, nil
    74  			}
    75  		}
    76  
    77  		return nil, transform.SameTree, sql.ErrUnknownConstraint.New(dropConstraint.Name)
    78  	})
    79  }
    80  
    81  // validateDropConstraint returns an error if the constraint named to be dropped doesn't exist
    82  func validateDropConstraint(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    83  	switch n := n.(type) {
    84  	case *plan.DropCheck:
    85  		rt := n.Table
    86  
    87  		ct, ok := rt.Table.(sql.CheckTable)
    88  		if ok {
    89  			checks, err := ct.GetChecks(ctx)
    90  			if err != nil {
    91  				return nil, transform.SameTree, err
    92  			}
    93  
    94  			for _, check := range checks {
    95  				if strings.ToLower(check.Name) == strings.ToLower(n.Name) {
    96  					return n, transform.SameTree, nil
    97  				}
    98  			}
    99  
   100  			return nil, transform.SameTree, sql.ErrUnknownConstraint.New(n.Name)
   101  		}
   102  
   103  		return nil, transform.SameTree, plan.ErrNoCheckConstraintSupport.New(rt.Table.Name())
   104  	default:
   105  		return n, transform.SameTree, nil
   106  	}
   107  }