github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/get_field.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	errors "gopkg.in/src-d/go-errors.v1"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  )
    25  
// GetField is an expression to get the field of a table.
//
// Eval and Eval2 return the value found at position fieldIndex of the input
// row; the remaining fields are metadata exposed through accessors and used
// when rendering the expression with String and DebugString.
type GetField struct {
	// db is the database name reported by Database; NewGetField leaves it empty.
	db string
	// table is the table name, which may be an alias. When empty, String and
	// DebugString render the field without a table qualifier.
	table string
	// fieldIndex is the row position read by Eval and Eval2.
	fieldIndex int
	// exprId lets the lifecycle of getFields be idempotent. We can re-index
	// or re-apply scope/caching optimizations without worrying about losing
	// the reference to the unique id.
	exprId sql.ColumnId
	// tableId is the table id reported by TableId and TableID.
	tableId sql.TableId
	// name is the column name.
	name string
	// fieldType is the column's type, reported by Type.
	fieldType sql.Type
	// fieldType2 is fieldType viewed as a sql.Type2, or nil when the
	// constructor found that fieldType does not implement sql.Type2.
	fieldType2 sql.Type2
	// nullable is reported by IsNullable; DebugString marks non-nullable
	// fields with a "!null" suffix.
	nullable bool

	// backTickNames makes String wrap an unqualified field name in backticks.
	backTickNames bool
}
    43  
    44  var _ sql.Expression = (*GetField)(nil)
    45  var _ sql.Expression2 = (*GetField)(nil)
    46  var _ sql.CollationCoercible = (*GetField)(nil)
    47  var _ sql.IdExpression = (*GetField)(nil)
    48  
    49  // NewGetField creates a GetField expression.
    50  func NewGetField(index int, fieldType sql.Type, fieldName string, nullable bool) *GetField {
    51  	return NewGetFieldWithTable(index, 0, fieldType, "", "", fieldName, nullable)
    52  }
    53  
    54  // NewGetFieldWithTable creates a GetField expression with table name. The table name may be an alias.
    55  func NewGetFieldWithTable(index, tableId int, fieldType sql.Type, db, table, fieldName string, nullable bool) *GetField {
    56  	fieldType2, _ := fieldType.(sql.Type2)
    57  	return &GetField{
    58  		db:         db,
    59  		table:      table,
    60  		fieldIndex: index,
    61  		fieldType:  fieldType,
    62  		fieldType2: fieldType2,
    63  		name:       fieldName,
    64  		nullable:   nullable,
    65  		exprId:     sql.ColumnId(index),
    66  		tableId:    sql.TableId(tableId),
    67  	}
    68  }
    69  
    70  // Index returns the index where the GetField will look for the value from a sql.Row.
    71  func (p *GetField) Index() int { return p.fieldIndex }
    72  
    73  func (p *GetField) Id() sql.ColumnId { return p.exprId }
    74  
    75  func (p *GetField) WithId(id sql.ColumnId) sql.IdExpression {
    76  	ret := *p
    77  	ret.exprId = id
    78  	return &ret
    79  }
    80  
    81  func (p *GetField) TableId() sql.TableId { return p.tableId }
    82  
    83  func (p *GetField) Database() string { return p.db }
    84  
    85  // Children implements the Expression interface.
    86  func (*GetField) Children() []sql.Expression {
    87  	return nil
    88  }
    89  
    90  // Table returns the name of the field table.
    91  func (p *GetField) Table() string { return p.table }
    92  
    93  func (p *GetField) TableID() sql.TableId {
    94  	return p.tableId
    95  }
    96  
    97  // WithTable returns a copy of this expression with the table given
    98  func (p *GetField) WithTable(table string) *GetField {
    99  	p2 := *p
   100  	p2.table = table
   101  	return &p2
   102  }
   103  
   104  // WithName returns a copy of this expression with the field name given.
   105  func (p *GetField) WithName(name string) *GetField {
   106  	p2 := *p
   107  	p2.name = name
   108  	return &p2
   109  }
   110  
   111  // Resolved implements the Expression interface.
   112  func (p *GetField) Resolved() bool {
   113  	return true
   114  }
   115  
   116  // Name implements the Nameable interface.
   117  func (p *GetField) Name() string {
   118  	return p.name
   119  }
   120  
   121  // IsNullable returns whether the field is nullable or not.
   122  func (p *GetField) IsNullable() bool {
   123  	return p.nullable
   124  }
   125  
   126  // Type returns the type of the field.
   127  func (p *GetField) Type() sql.Type {
   128  	return p.fieldType
   129  }
   130  
   131  // Type2 returns the type of the field, if this field has a sql.Type2.
   132  func (p *GetField) Type2() sql.Type2 {
   133  	return p.fieldType2
   134  }
   135  
   136  // ErrIndexOutOfBounds is returned when the field index is out of the bounds.
   137  var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns")
   138  
   139  // Eval implements the Expression interface.
   140  func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   141  	if p.fieldIndex < 0 || p.fieldIndex >= len(row) {
   142  		return nil, ErrIndexOutOfBounds.New(p.fieldIndex, len(row))
   143  	}
   144  	return row[p.fieldIndex], nil
   145  }
   146  
   147  func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
   148  	if p.fieldIndex < 0 || p.fieldIndex >= row.Len() {
   149  		return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len())
   150  	}
   151  
   152  	return row.GetField(p.fieldIndex), nil
   153  }
   154  
   155  // WithChildren implements the Expression interface.
   156  func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   157  	if len(children) != 0 {
   158  		return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0)
   159  	}
   160  	return p, nil
   161  }
   162  
   163  func (p *GetField) String() string {
   164  	if p.table == "" {
   165  		if p.backTickNames {
   166  			return fmt.Sprintf("`%s`", p.name)
   167  		}
   168  		return p.name
   169  	}
   170  	return fmt.Sprintf("%s.%s", p.table, p.name)
   171  }
   172  
   173  func (p *GetField) DebugString() string {
   174  	var notNull string
   175  	if !p.nullable {
   176  		notNull = "!null"
   177  	}
   178  	if p.table == "" {
   179  		return fmt.Sprintf("%s:%d%s", p.name, p.fieldIndex, notNull)
   180  	}
   181  	return fmt.Sprintf("%s.%s:%d%s", p.table, p.name, p.fieldIndex, notNull)
   182  }
   183  
   184  // WithIndex returns this same GetField with a new index.
   185  func (p *GetField) WithIndex(n int) sql.Expression {
   186  	p2 := *p
   187  	p2.fieldIndex = n
   188  	return &p2
   189  }
   190  
   191  // WithBackTickNames returns a copy of this expression with the backtick names flag set to the given value.
   192  func (p *GetField) WithBackTickNames(backtick bool) *GetField {
   193  	p2 := *p
   194  	p2.backTickNames = backtick
   195  	return &p2
   196  }
   197  
   198  // IsBackTickNames returns whether the field name should be quoted with backticks.
   199  func (p *GetField) IsBackTickNames() bool {
   200  	return p.backTickNames
   201  }
   202  
   203  // CollationCoercibility implements the interface sql.CollationCoercible.
   204  func (p *GetField) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   205  	collation, _ = p.fieldType.CollationCoercibility(ctx)
   206  	return collation, 2
   207  }
   208  
   209  // SchemaToGetFields takes a schema and returns an expression array of
   210  // GetFields. If |columns| is provided, each get field will get the
   211  // appropriate expression id.
   212  func SchemaToGetFields(s sql.Schema, columns sql.ColSet) []sql.Expression {
   213  	ret := make([]sql.Expression, len(s))
   214  
   215  	var offset sql.ColumnId
   216  	if !columns.Empty() {
   217  		offset, _ = columns.Next(1)
   218  	}
   219  	for i, col := range s {
   220  		// 0 id represents the dual table column
   221  		id := i
   222  		if offset > 0 {
   223  			id += int(offset)
   224  		}
   225  		ret[i] = NewGetFieldWithTable(id, 0, col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable)
   226  	}
   227  
   228  	return ret
   229  }
   230  
   231  // ExtractGetField returns the inner GetField expression from another expression. If there are multiple GetField
   232  // expressions that are not the same, then none of the GetField expressions are returned.
   233  func ExtractGetField(e sql.Expression) *GetField {
   234  	var field *GetField
   235  	multipleFields := false
   236  	sql.Inspect(e, func(expr sql.Expression) bool {
   237  		if f, ok := expr.(*GetField); ok {
   238  			if field == nil {
   239  				field = f
   240  			} else if strings.ToLower(field.table) != strings.ToLower(f.table) ||
   241  				strings.ToLower(field.name) != strings.ToLower(f.name) {
   242  				multipleFields = true
   243  				return false
   244  			}
   245  			return true
   246  		}
   247  		return true
   248  	})
   249  
   250  	if multipleFields {
   251  		return nil
   252  	}
   253  	return field
   254  }