github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/process_truncate.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/plan"
    25  	"github.com/dolthub/go-mysql-server/sql/transform"
    26  )
    27  
    28  // processTruncate is a combination of resolving fields in *plan.DeleteFrom and *plan.Truncate, validating the fields,
    29  // and in some cases converting *plan.DeleteFrom -> *plan.Truncate
    30  func processTruncate(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    31  	span, ctx := ctx.Span("processTruncate")
    32  	defer span.End()
    33  
    34  	switch n := node.(type) {
    35  	case *plan.DeleteFrom:
    36  		if !n.Resolved() {
    37  			return n, transform.SameTree, nil
    38  		}
    39  		return deleteToTruncate(ctx, a, n)
    40  	case *plan.Truncate:
    41  		if !n.Resolved() {
    42  			return nil, transform.SameTree, fmt.Errorf("cannot process TRUNCATE as node is expected to be resolved")
    43  		}
    44  		var db sql.Database
    45  		var err error
    46  		if n.DatabaseName() == "" {
    47  			db, err = a.Catalog.Database(ctx, ctx.GetCurrentDatabase())
    48  			if err != nil {
    49  				return nil, transform.SameTree, err
    50  			}
    51  		} else {
    52  			db, err = a.Catalog.Database(ctx, n.DatabaseName())
    53  			if err != nil {
    54  				return nil, transform.SameTree, err
    55  			}
    56  		}
    57  		_, err = validateTruncate(ctx, db, n.Child)
    58  		if err != nil {
    59  			return nil, transform.SameTree, err
    60  		}
    61  		return n, transform.SameTree, nil
    62  	default:
    63  		return n, transform.SameTree, nil
    64  	}
    65  }
    66  
    67  func deleteToTruncate(ctx *sql.Context, a *Analyzer, deletePlan *plan.DeleteFrom) (sql.Node, transform.TreeIdentity, error) {
    68  	tbl, ok := deletePlan.Child.(*plan.ResolvedTable)
    69  	if !ok {
    70  		return deletePlan, transform.SameTree, nil
    71  	}
    72  	tblName := strings.ToLower(tbl.Name())
    73  
    74  	// auto_increment behaves differently for TRUNCATE and DELETE
    75  	for _, col := range tbl.Schema() {
    76  		if col.AutoIncrement {
    77  			return deletePlan, transform.SameTree, nil
    78  		}
    79  	}
    80  
    81  	currentDb, err := a.Catalog.Database(ctx, ctx.GetCurrentDatabase())
    82  	if err != nil {
    83  		return nil, transform.SameTree, err
    84  	}
    85  	dbTblNames, err := currentDb.GetTableNames(ctx)
    86  	if err != nil {
    87  		return nil, transform.SameTree, err
    88  	}
    89  	tblFound := false
    90  	for _, dbTblName := range dbTblNames {
    91  		if strings.ToLower(dbTblName) == tblName {
    92  			if tblFound == false {
    93  				tblFound = true
    94  			} else {
    95  				return deletePlan, transform.SameTree, nil
    96  			}
    97  		}
    98  	}
    99  	if !tblFound {
   100  		return deletePlan, transform.SameTree, nil
   101  	}
   102  
   103  	triggers, err := loadTriggersFromDb(ctx, a, currentDb)
   104  	if err != nil {
   105  		return nil, transform.SameTree, err
   106  	}
   107  	for _, trigger := range triggers {
   108  		if trigger.TriggerEvent != sqlparser.DeleteStr {
   109  			continue
   110  		}
   111  		var triggerTblName string
   112  		switch trigger.Table.(type) {
   113  		case *plan.UnresolvedTable, *plan.ResolvedTable:
   114  			triggerTblName = trigger.Table.(sql.NameableNode).Name()
   115  		default:
   116  			// If we can't determine the name of the table that the trigger is on, we just abort to be safe
   117  			// TODO error?
   118  			return deletePlan, transform.SameTree, nil
   119  		}
   120  		if strings.ToLower(triggerTblName) == tblName {
   121  			// An ON DELETE trigger is present so we can't use TRUNCATE
   122  			return deletePlan, transform.SameTree, nil
   123  		}
   124  	}
   125  
   126  	if ok, err := validateTruncate(ctx, currentDb, tbl); ok {
   127  		// We only check err if ok is true, as some errors won't apply to us attempting to convert from a DELETE
   128  		if err != nil {
   129  			return nil, transform.SameTree, err
   130  		}
   131  		return plan.NewTruncate(ctx.GetCurrentDatabase(), tbl), transform.NewTree, nil
   132  	}
   133  	return deletePlan, transform.SameTree, nil
   134  }
   135  
   136  // validateTruncate returns whether the truncate operation adheres to the limitations as specified in
   137  // https://dev.mysql.com/doc/refman/8.0/en/truncate-table.html. In the case of checking if a DELETE may be converted
   138  // to a TRUNCATE operation, check the bool first. If false, then the error should be ignored (such as if the table does
   139  // not support TRUNCATE). If true is returned along with an error, then the error is not expected to happen under
   140  // normal circumstances and should be dealt with.
   141  func validateTruncate(ctx *sql.Context, db sql.Database, tbl sql.Node) (bool, error) {
   142  	truncatable, err := plan.GetTruncatable(tbl)
   143  	if err != nil {
   144  		return false, err // false as any caller besides Truncate would not care for this error
   145  	}
   146  	tableName := strings.ToLower(truncatable.Name())
   147  
   148  	tableNames, err := db.GetTableNames(ctx)
   149  	if err != nil {
   150  		return true, err // true as this should not error under normal circumstances
   151  	}
   152  	for _, tableNameToCheck := range tableNames {
   153  		if strings.ToLower(tableNameToCheck) == tableName {
   154  			continue
   155  		}
   156  		tableToCheck, ok, err := db.GetTableInsensitive(ctx, tableNameToCheck)
   157  		if err != nil {
   158  			return true, err // should not error under normal circumstances
   159  		}
   160  		if !ok {
   161  			return true, sql.ErrTableNotFound.New(tableNameToCheck)
   162  		}
   163  		fkTable, ok := tableToCheck.(sql.ForeignKeyTable)
   164  		if ok {
   165  			fks, err := fkTable.GetDeclaredForeignKeys(ctx)
   166  			if err != nil {
   167  				return true, err
   168  			}
   169  
   170  			fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
   171  			if err != nil {
   172  				return true, err
   173  			}
   174  
   175  			if fkChecks.(int8) == 1 {
   176  				for _, fk := range fks {
   177  					if strings.ToLower(fk.ParentTable) == tableName {
   178  						return false, sql.ErrTruncateReferencedFromForeignKey.New(tableName, fk.Name, tableNameToCheck)
   179  					}
   180  				}
   181  			}
   182  		}
   183  	}
   184  	//TODO: check for an active table lock and error if one is found for the target table
   185  	return true, nil
   186  }