github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/values_derived_table.go (about) 1 package plan 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/dolthub/go-mysql-server/sql" 8 "github.com/dolthub/go-mysql-server/sql/types" 9 ) 10 11 type ValueDerivedTable struct { 12 *Values 13 name string 14 columns []string 15 sch sql.Schema 16 id sql.TableId 17 cols sql.ColSet 18 } 19 20 var _ sql.Node = (*ValueDerivedTable)(nil) 21 var _ sql.CollationCoercible = (*ValueDerivedTable)(nil) 22 var _ TableIdNode = (*ValueDerivedTable)(nil) 23 24 func NewValueDerivedTable(values *Values, name string) *ValueDerivedTable { 25 var s sql.Schema 26 if values.Resolved() && len(values.ExpressionTuples) != 0 { 27 s = getSchema(values.ExpressionTuples) 28 } 29 return &ValueDerivedTable{Values: values, name: name, sch: s} 30 } 31 32 // WithId implements sql.TableIdNode 33 func (v *ValueDerivedTable) WithId(id sql.TableId) TableIdNode { 34 ret := *v 35 ret.id = id 36 return &ret 37 } 38 39 // Id implements sql.TableIdNode 40 func (v *ValueDerivedTable) Id() sql.TableId { 41 return v.id 42 } 43 44 // WithColumns implements sql.TableIdNode 45 func (v *ValueDerivedTable) WithColumns(set sql.ColSet) TableIdNode { 46 ret := *v 47 ret.cols = set 48 return &ret 49 } 50 51 // Columns implements sql.TableIdNode 52 func (v *ValueDerivedTable) Columns() sql.ColSet { 53 return v.cols 54 } 55 56 // Name implements sql.Nameable 57 func (v *ValueDerivedTable) Name() string { 58 return v.name 59 } 60 61 // Schema implements the Node interface. 62 func (v *ValueDerivedTable) Schema() sql.Schema { 63 if len(v.ExpressionTuples) == 0 { 64 return nil 65 } 66 67 schema := make(sql.Schema, len(v.sch)) 68 for i, col := range v.sch { 69 c := *col 70 c.Source = v.name 71 if len(v.columns) > 0 { 72 c.Name = v.columns[i] 73 } else { 74 c.Name = fmt.Sprintf("column_%d", i) 75 } 76 schema[i] = &c 77 } 78 79 return schema 80 } 81 82 // WithChildren implements the Node interface. 83 func (v *ValueDerivedTable) WithChildren(children ...sql.Node) (sql.Node, error) { 84 if len(children) != 0 { 85 return nil, sql.ErrInvalidChildrenNumber.New(v, len(children), 0) 86 } 87 88 return v, nil 89 } 90 91 // WithExpressions implements the Expressioner interface. 92 func (v *ValueDerivedTable) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { 93 newValues, err := v.Values.WithExpressions(exprs...) 94 if err != nil { 95 return nil, err 96 } 97 98 nv := *v 99 nv.Values = newValues.(*Values) 100 if nv.Values.Resolved() && len(nv.Values.ExpressionTuples) != 0 { 101 nv.sch = getSchema(nv.Values.ExpressionTuples) 102 } 103 return &nv, nil 104 } 105 106 func (v *ValueDerivedTable) String() string { 107 children := make([]string, len(v.ExpressionTuples)) 108 for i, tuple := range v.ExpressionTuples { 109 var sb strings.Builder 110 sb.WriteString("Row(\n") 111 for j, e := range tuple { 112 if j > 0 { 113 sb.WriteString(",") 114 } 115 sb.WriteString(e.String()) 116 } 117 sb.WriteRune(')') 118 children[i] = sb.String() 119 } 120 121 tp := sql.NewTreePrinter() 122 _ = tp.WriteNode("Values() as %s", v.name) 123 _ = tp.WriteChildren(children...) 124 125 return tp.String() 126 } 127 128 func (v *ValueDerivedTable) DebugString() string { 129 children := make([]string, len(v.ExpressionTuples)) 130 for i, tuple := range v.ExpressionTuples { 131 var sb strings.Builder 132 sb.WriteString("Row(\n") 133 for j, e := range tuple { 134 if j > 0 { 135 sb.WriteString(",") 136 } 137 sb.WriteString(sql.DebugString(e)) 138 } 139 sb.WriteRune(')') 140 children[i] = sb.String() 141 } 142 143 tp := sql.NewTreePrinter() 144 _ = tp.WriteNode("Values() as %s", v.name) 145 _ = tp.WriteChildren(children...) 146 147 return tp.String() 148 } 149 150 func (v ValueDerivedTable) WithColumNames(columns []string) *ValueDerivedTable { 151 v.columns = columns 152 return &v 153 } 154 155 // getSchema returns schema created with most permissive types by examining all rows. 156 func getSchema(rows [][]sql.Expression) sql.Schema { 157 s := make(sql.Schema, len(rows[0])) 158 159 for _, exprs := range rows { 160 for i, val := range exprs { 161 if s[i] == nil { 162 var name string 163 if n, ok := val.(sql.Nameable); ok { 164 name = n.Name() 165 } else { 166 name = val.String() 167 } 168 169 s[i] = &sql.Column{Name: name, Type: val.Type(), Nullable: val.IsNullable()} 170 } else { 171 s[i].Type = getMostPermissiveType(s[i], val) 172 if !s[i].Nullable { 173 s[i].Nullable = val.IsNullable() 174 } 175 } 176 177 } 178 } 179 180 return s 181 } 182 183 // getMostPermissiveType returns the most permissive type given the current type and the expression type. 184 // The ordering is "other types < uint < int < decimal (float should be interpreted as decimal) < string" 185 func getMostPermissiveType(s *sql.Column, e sql.Expression) sql.Type { 186 if types.IsText(s.Type) { 187 return s.Type 188 } else if types.IsText(e.Type()) { 189 return e.Type() 190 } 191 192 if st, ok := s.Type.(sql.DecimalType); ok { 193 et, eok := e.Type().(sql.DecimalType) 194 if !eok { 195 return s.Type 196 } 197 // if both are decimal types, get the bigger decimaltype 198 frac := st.Scale() 199 whole := st.Precision() - frac 200 if ep := et.Precision() - et.Scale(); ep > whole { 201 whole = ep 202 } 203 if et.Scale() > frac { 204 frac = et.Scale() 205 } 206 return types.MustCreateDecimalType(whole+frac, frac) 207 } else if types.IsDecimal(e.Type()) { 208 return e.Type() 209 } 210 211 // TODO: float type should be interpreted as decimal type 212 if types.IsFloat(s.Type) { 213 return s.Type 214 } else if types.IsFloat(e.Type()) { 215 return types.Float64 216 } 217 218 if types.IsSigned(s.Type) { 219 return s.Type 220 } else if types.IsSigned(e.Type()) { 221 return types.Int64 222 } 223 224 if types.IsUnsigned(s.Type) { 225 return s.Type 226 } else if types.IsUnsigned(e.Type()) { 227 return types.Uint64 228 } 229 230 return s.Type 231 }