github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/auto_increment.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"gopkg.in/src-d/go-errors.v1"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  )
    24  
    25  var (
    26  	// ErrAutoIncrementUnsupported is returned when table does not support AUTO_INCREMENT.
    27  	ErrAutoIncrementUnsupported = errors.NewKind("table %s does not support AUTO_INCREMENT columns")
    28  	// ErrNoAutoIncrementCols is returned when table has no AUTO_INCREMENT columns.
    29  	ErrNoAutoIncrementCols = errors.NewKind("table %s has no AUTO_INCREMENT columns")
    30  )
    31  
    32  // AutoIncrement implements AUTO_INCREMENT
    33  type AutoIncrement struct {
    34  	UnaryExpression
    35  	autoTbl sql.AutoIncrementTable
    36  	autoCol *sql.Column
    37  }
    38  
    39  var _ sql.Expression = (*AutoIncrement)(nil)
    40  var _ sql.CollationCoercible = (*AutoIncrement)(nil)
    41  
    42  // NewAutoIncrement creates a new AutoIncrement expression.
    43  func NewAutoIncrement(ctx *sql.Context, table sql.Table, given sql.Expression) (*AutoIncrement, error) {
    44  	autoTbl, ok := table.(sql.AutoIncrementTable)
    45  	if !ok {
    46  		return nil, ErrAutoIncrementUnsupported.New(table.Name())
    47  	}
    48  
    49  	var autoCol *sql.Column
    50  	for _, c := range autoTbl.Schema() {
    51  		if c.AutoIncrement {
    52  			autoCol = c
    53  			break
    54  		}
    55  	}
    56  	if autoCol == nil {
    57  		return nil, ErrNoAutoIncrementCols.New(table.Name())
    58  	}
    59  
    60  	return &AutoIncrement{
    61  		UnaryExpression{Child: given},
    62  		autoTbl,
    63  		autoCol,
    64  	}, nil
    65  }
    66  
    67  // NewAutoIncrementForColumn creates a new AutoIncrement expression for the column given.
    68  func NewAutoIncrementForColumn(ctx *sql.Context, table sql.Table, autoCol *sql.Column, given sql.Expression) (*AutoIncrement, error) {
    69  	autoTbl, ok := table.(sql.AutoIncrementTable)
    70  	if !ok {
    71  		return nil, ErrAutoIncrementUnsupported.New(table.Name())
    72  	}
    73  
    74  	return &AutoIncrement{
    75  		UnaryExpression{Child: given},
    76  		autoTbl,
    77  		autoCol,
    78  	}, nil
    79  }
    80  
    81  // IsNullable implements the Expression interface.
    82  func (i *AutoIncrement) IsNullable() bool {
    83  	return false
    84  }
    85  
    86  // Type implements the Expression interface.
    87  func (i *AutoIncrement) Type() sql.Type {
    88  	return i.autoCol.Type
    89  }
    90  
    91  // CollationCoercibility implements the interface sql.CollationCoercible.
    92  func (i *AutoIncrement) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    93  	return sql.GetCoercibility(ctx, i.Child)
    94  }
    95  
    96  // Eval implements the Expression interface.
    97  func (i *AutoIncrement) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    98  	// get value provided by INSERT
    99  	given, err := i.Child.Eval(ctx, row)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	// When a row passes in 0 as the auto_increment value it is equivalent to NULL.
   105  	cmp, err := i.Type().Compare(given, i.Type().Zero())
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	if cmp == 0 {
   110  		given = nil
   111  	} else if cmp < 0 {
   112  		// if given is negative, don't do any auto_increment logic
   113  		ret, _, err := i.Type().Convert(given)
   114  		return ret, err
   115  	}
   116  
   117  	// Update integrator AUTO_INCREMENT sequence with our value
   118  	seq, err := i.autoTbl.GetNextAutoIncrementValue(ctx, given)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	// Use sequence value if NULL or 0 were provided
   124  	if given == nil {
   125  		given = seq
   126  	}
   127  
   128  	ret, _, err := i.Type().Convert(given)
   129  	return ret, err
   130  }
   131  
   132  func (i *AutoIncrement) String() string {
   133  	return fmt.Sprintf("AutoIncrement(%s)", i.Child.String())
   134  }
   135  
   136  // WithChildren implements the Expression interface.
   137  func (i *AutoIncrement) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   138  	if len(children) != 1 {
   139  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
   140  	}
   141  	return &AutoIncrement{
   142  		UnaryExpression{Child: children[0]},
   143  		i.autoTbl,
   144  		i.autoCol,
   145  	}, nil
   146  }
   147  
   148  // Children implements the Expression interface.
   149  func (i *AutoIncrement) Children() []sql.Expression {
   150  	return []sql.Expression{i.Child}
   151  }