github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/values_derived_table.go (about)

     1  package plan
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/dolthub/go-mysql-server/sql"
     8  	"github.com/dolthub/go-mysql-server/sql/types"
     9  )
    10  
    11  type ValueDerivedTable struct {
    12  	*Values
    13  	name    string
    14  	columns []string
    15  	sch     sql.Schema
    16  	id      sql.TableId
    17  	cols    sql.ColSet
    18  }
    19  
    20  var _ sql.Node = (*ValueDerivedTable)(nil)
    21  var _ sql.CollationCoercible = (*ValueDerivedTable)(nil)
    22  var _ TableIdNode = (*ValueDerivedTable)(nil)
    23  
    24  func NewValueDerivedTable(values *Values, name string) *ValueDerivedTable {
    25  	var s sql.Schema
    26  	if values.Resolved() && len(values.ExpressionTuples) != 0 {
    27  		s = getSchema(values.ExpressionTuples)
    28  	}
    29  	return &ValueDerivedTable{Values: values, name: name, sch: s}
    30  }
    31  
    32  // WithId implements sql.TableIdNode
    33  func (v *ValueDerivedTable) WithId(id sql.TableId) TableIdNode {
    34  	ret := *v
    35  	ret.id = id
    36  	return &ret
    37  }
    38  
    39  // Id implements sql.TableIdNode
    40  func (v *ValueDerivedTable) Id() sql.TableId {
    41  	return v.id
    42  }
    43  
    44  // WithColumns implements sql.TableIdNode
    45  func (v *ValueDerivedTable) WithColumns(set sql.ColSet) TableIdNode {
    46  	ret := *v
    47  	ret.cols = set
    48  	return &ret
    49  }
    50  
    51  // Columns implements sql.TableIdNode
    52  func (v *ValueDerivedTable) Columns() sql.ColSet {
    53  	return v.cols
    54  }
    55  
    56  // Name implements sql.Nameable
    57  func (v *ValueDerivedTable) Name() string {
    58  	return v.name
    59  }
    60  
    61  // Schema implements the Node interface.
    62  func (v *ValueDerivedTable) Schema() sql.Schema {
    63  	if len(v.ExpressionTuples) == 0 {
    64  		return nil
    65  	}
    66  
    67  	schema := make(sql.Schema, len(v.sch))
    68  	for i, col := range v.sch {
    69  		c := *col
    70  		c.Source = v.name
    71  		if len(v.columns) > 0 {
    72  			c.Name = v.columns[i]
    73  		} else {
    74  			c.Name = fmt.Sprintf("column_%d", i)
    75  		}
    76  		schema[i] = &c
    77  	}
    78  
    79  	return schema
    80  }
    81  
    82  // WithChildren implements the Node interface.
    83  func (v *ValueDerivedTable) WithChildren(children ...sql.Node) (sql.Node, error) {
    84  	if len(children) != 0 {
    85  		return nil, sql.ErrInvalidChildrenNumber.New(v, len(children), 0)
    86  	}
    87  
    88  	return v, nil
    89  }
    90  
    91  // WithExpressions implements the Expressioner interface.
    92  func (v *ValueDerivedTable) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
    93  	newValues, err := v.Values.WithExpressions(exprs...)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	nv := *v
    99  	nv.Values = newValues.(*Values)
   100  	if nv.Values.Resolved() && len(nv.Values.ExpressionTuples) != 0 {
   101  		nv.sch = getSchema(nv.Values.ExpressionTuples)
   102  	}
   103  	return &nv, nil
   104  }
   105  
   106  func (v *ValueDerivedTable) String() string {
   107  	children := make([]string, len(v.ExpressionTuples))
   108  	for i, tuple := range v.ExpressionTuples {
   109  		var sb strings.Builder
   110  		sb.WriteString("Row(\n")
   111  		for j, e := range tuple {
   112  			if j > 0 {
   113  				sb.WriteString(",")
   114  			}
   115  			sb.WriteString(e.String())
   116  		}
   117  		sb.WriteRune(')')
   118  		children[i] = sb.String()
   119  	}
   120  
   121  	tp := sql.NewTreePrinter()
   122  	_ = tp.WriteNode("Values() as %s", v.name)
   123  	_ = tp.WriteChildren(children...)
   124  
   125  	return tp.String()
   126  }
   127  
   128  func (v *ValueDerivedTable) DebugString() string {
   129  	children := make([]string, len(v.ExpressionTuples))
   130  	for i, tuple := range v.ExpressionTuples {
   131  		var sb strings.Builder
   132  		sb.WriteString("Row(\n")
   133  		for j, e := range tuple {
   134  			if j > 0 {
   135  				sb.WriteString(",")
   136  			}
   137  			sb.WriteString(sql.DebugString(e))
   138  		}
   139  		sb.WriteRune(')')
   140  		children[i] = sb.String()
   141  	}
   142  
   143  	tp := sql.NewTreePrinter()
   144  	_ = tp.WriteNode("Values() as %s", v.name)
   145  	_ = tp.WriteChildren(children...)
   146  
   147  	return tp.String()
   148  }
   149  
   150  func (v ValueDerivedTable) WithColumNames(columns []string) *ValueDerivedTable {
   151  	v.columns = columns
   152  	return &v
   153  }
   154  
   155  // getSchema returns schema created with most permissive types by examining all rows.
   156  func getSchema(rows [][]sql.Expression) sql.Schema {
   157  	s := make(sql.Schema, len(rows[0]))
   158  
   159  	for _, exprs := range rows {
   160  		for i, val := range exprs {
   161  			if s[i] == nil {
   162  				var name string
   163  				if n, ok := val.(sql.Nameable); ok {
   164  					name = n.Name()
   165  				} else {
   166  					name = val.String()
   167  				}
   168  
   169  				s[i] = &sql.Column{Name: name, Type: val.Type(), Nullable: val.IsNullable()}
   170  			} else {
   171  				s[i].Type = getMostPermissiveType(s[i], val)
   172  				if !s[i].Nullable {
   173  					s[i].Nullable = val.IsNullable()
   174  				}
   175  			}
   176  
   177  		}
   178  	}
   179  
   180  	return s
   181  }
   182  
   183  // getMostPermissiveType returns the most permissive type given the current type and the expression type.
   184  // The ordering is "other types < uint < int < decimal (float should be interpreted as decimal) < string"
   185  func getMostPermissiveType(s *sql.Column, e sql.Expression) sql.Type {
   186  	if types.IsText(s.Type) {
   187  		return s.Type
   188  	} else if types.IsText(e.Type()) {
   189  		return e.Type()
   190  	}
   191  
   192  	if st, ok := s.Type.(sql.DecimalType); ok {
   193  		et, eok := e.Type().(sql.DecimalType)
   194  		if !eok {
   195  			return s.Type
   196  		}
   197  		// if both are decimal types, get the bigger decimaltype
   198  		frac := st.Scale()
   199  		whole := st.Precision() - frac
   200  		if ep := et.Precision() - et.Scale(); ep > whole {
   201  			whole = ep
   202  		}
   203  		if et.Scale() > frac {
   204  			frac = et.Scale()
   205  		}
   206  		return types.MustCreateDecimalType(whole+frac, frac)
   207  	} else if types.IsDecimal(e.Type()) {
   208  		return e.Type()
   209  	}
   210  
   211  	// TODO: float type should be interpreted as decimal type
   212  	if types.IsFloat(s.Type) {
   213  		return s.Type
   214  	} else if types.IsFloat(e.Type()) {
   215  		return types.Float64
   216  	}
   217  
   218  	if types.IsSigned(s.Type) {
   219  		return s.Type
   220  	} else if types.IsSigned(e.Type()) {
   221  		return types.Int64
   222  	}
   223  
   224  	if types.IsUnsigned(s.Type) {
   225  		return s.Type
   226  	} else if types.IsUnsigned(e.Type()) {
   227  		return types.Uint64
   228  	}
   229  
   230  	return s.Type
   231  }