github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/inet_convert.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"encoding/binary"
    19  	"fmt"
    20  	"net"
    21  	"reflect"
    22  	"strings"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  )
    29  
    30  type InetAton struct {
    31  	expression.UnaryExpression
    32  }
    33  
    34  var _ sql.FunctionExpression = (*InetAton)(nil)
    35  var _ sql.CollationCoercible = (*InetAton)(nil)
    36  
    37  func NewInetAton(val sql.Expression) sql.Expression {
    38  	return &InetAton{expression.UnaryExpression{Child: val}}
    39  }
    40  
    41  // FunctionName implements sql.FunctionExpression
    42  func (i *InetAton) FunctionName() string {
    43  	return "inet_aton"
    44  }
    45  
    46  // Description implements sql.FunctionExpression
    47  func (i *InetAton) Description() string {
    48  	return "returns the numeric value of an IP address."
    49  }
    50  
    51  func (i *InetAton) String() string {
    52  	return fmt.Sprintf("%s(%s)", i.FunctionName(), i.Child.String())
    53  }
    54  
    55  func (i *InetAton) Type() sql.Type {
    56  	return types.Uint32
    57  }
    58  
    59  // CollationCoercibility implements the interface sql.CollationCoercible.
    60  func (*InetAton) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    61  	return sql.Collation_binary, 5
    62  }
    63  
    64  func (i *InetAton) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    65  	if len(children) != 1 {
    66  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
    67  	}
    68  	return NewInetAton(children[0]), nil
    69  }
    70  
    71  func (i *InetAton) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    72  	// Evaluate value
    73  	val, err := i.Child.Eval(ctx, row)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	// Return null if given null
    79  	if val == nil {
    80  		return nil, nil
    81  	}
    82  
    83  	// Expect to receive an IP address, so convert val into string
    84  	ipstr, err := types.ConvertToString(val, types.LongText)
    85  	if err != nil {
    86  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(val).String())
    87  	}
    88  
    89  	// Parse IP address
    90  	ip := net.ParseIP(ipstr)
    91  	if ip == nil {
    92  		// Failed to Parse IP correctly
    93  		ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%s'' for function %s", ipstr, i.FunctionName()))
    94  		return nil, nil
    95  	}
    96  
    97  	// Expect an IPv4 address
    98  	ipv4 := ip.To4()
    99  	if ipv4 == nil {
   100  		// Received invalid IPv4 address (IPv6 address are invalid)
   101  		ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%s'' for function %s", ipstr, i.FunctionName()))
   102  		return nil, nil
   103  	}
   104  
   105  	// Return IPv4 address as uint32
   106  	ipv4int := binary.BigEndian.Uint32(ipv4)
   107  	return ipv4int, nil
   108  }
   109  
   110  type Inet6Aton struct {
   111  	expression.UnaryExpression
   112  }
   113  
   114  var _ sql.FunctionExpression = (*Inet6Aton)(nil)
   115  var _ sql.CollationCoercible = (*Inet6Aton)(nil)
   116  
   117  func NewInet6Aton(val sql.Expression) sql.Expression {
   118  	return &Inet6Aton{expression.UnaryExpression{Child: val}}
   119  }
   120  
   121  // FunctionName implements sql.FunctionExpression
   122  func (i *Inet6Aton) FunctionName() string {
   123  	return "inet6_aton"
   124  }
   125  
   126  // Description implements sql.FunctionExpression
   127  func (i *Inet6Aton) Description() string {
   128  	return "returns the numeric value of an IPv6 address."
   129  }
   130  
   131  func (i *Inet6Aton) String() string {
   132  	return fmt.Sprintf("%s(%s)", i.FunctionName(), i.Child.String())
   133  }
   134  
   135  func (i *Inet6Aton) Type() sql.Type {
   136  	return types.LongBlob
   137  }
   138  
   139  // CollationCoercibility implements the interface sql.CollationCoercible.
   140  func (*Inet6Aton) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   141  	return sql.Collation_binary, 4
   142  }
   143  
   144  func (i *Inet6Aton) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   145  	if len(children) != 1 {
   146  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
   147  	}
   148  	return NewInet6Aton(children[0]), nil
   149  }
   150  
   151  func (i *Inet6Aton) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   152  	// Evaluate value
   153  	val, err := i.Child.Eval(ctx, row)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	// Return null if given null
   159  	if val == nil {
   160  		return nil, nil
   161  	}
   162  
   163  	// Parse IP address
   164  	ipstr := val.(string)
   165  	ip := net.ParseIP(ipstr)
   166  	if ip == nil {
   167  		// Failed to Parse IP correctly
   168  		ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%s'' for function %s", ipstr, i.FunctionName()))
   169  		return nil, nil
   170  	}
   171  
   172  	// if it doesn't contain colons, treat it as ipv4
   173  	if strings.Count(val.(string), ":") < 2 {
   174  		ipv4 := ip.To4()
   175  		return []byte(ipv4), nil
   176  	}
   177  
   178  	// Received IPv6 address
   179  	ipv6 := ip.To16()
   180  	if ipv6 == nil {
   181  		// Invalid IPv6 address
   182  		ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%s'' for function %s", ipstr, i.FunctionName()))
   183  		return nil, nil
   184  	}
   185  
   186  	// Return as []byte
   187  	return []byte(ipv6), nil
   188  }
   189  
   190  type InetNtoa struct {
   191  	expression.UnaryExpression
   192  }
   193  
   194  var _ sql.FunctionExpression = (*InetNtoa)(nil)
   195  var _ sql.CollationCoercible = (*InetNtoa)(nil)
   196  
   197  func NewInetNtoa(val sql.Expression) sql.Expression {
   198  	return &InetNtoa{expression.UnaryExpression{Child: val}}
   199  }
   200  
   201  // FunctionName implements sql.FunctionExpression
   202  func (i *InetNtoa) FunctionName() string {
   203  	return "inet_ntoa"
   204  }
   205  
   206  // Description implements sql.FunctionExpression
   207  func (i *InetNtoa) Description() string {
   208  	return "returns the IP address from a numeric value."
   209  }
   210  
   211  func (i *InetNtoa) String() string {
   212  	return fmt.Sprintf("%s(%s)", i.FunctionName(), i.Child.String())
   213  }
   214  
   215  func (i *InetNtoa) Type() sql.Type {
   216  	return types.LongText
   217  }
   218  
   219  // CollationCoercibility implements the interface sql.CollationCoercible.
   220  func (*InetNtoa) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   221  	return ctx.GetCollation(), 4
   222  }
   223  
   224  func (i *InetNtoa) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   225  	if len(children) != 1 {
   226  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
   227  	}
   228  	return NewInetNtoa(children[0]), nil
   229  }
   230  
   231  func (i *InetNtoa) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   232  	// Evaluate value
   233  	val, err := i.Child.Eval(ctx, row)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	// Return null if given null
   239  	if val == nil {
   240  		return nil, nil
   241  	}
   242  
   243  	// Convert val into int
   244  	ipv4int, _, err := types.Int32.Convert(val)
   245  	if ipv4int != nil && err != nil {
   246  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(val).String())
   247  	}
   248  
   249  	// Received a hex string instead of int
   250  	if ipv4int == nil {
   251  		// Create new IPv4
   252  		var ipv4 net.IP = []byte{0, 0, 0, 0}
   253  		return ipv4.String(), nil
   254  	}
   255  
   256  	// Create new IPv4, and fill with val
   257  	ipv4 := make(net.IP, 4)
   258  	binary.BigEndian.PutUint32(ipv4, uint32(ipv4int.(int32)))
   259  
   260  	return ipv4.String(), nil
   261  }
   262  
   263  type Inet6Ntoa struct {
   264  	expression.UnaryExpression
   265  }
   266  
   267  var _ sql.FunctionExpression = (*Inet6Ntoa)(nil)
   268  var _ sql.CollationCoercible = (*Inet6Ntoa)(nil)
   269  
   270  func NewInet6Ntoa(val sql.Expression) sql.Expression {
   271  	return &Inet6Ntoa{expression.UnaryExpression{Child: val}}
   272  }
   273  
   274  // FunctionName implements sql.FunctionExpression
   275  func (i *Inet6Ntoa) FunctionName() string {
   276  	return "inet6_ntoa"
   277  }
   278  
   279  // Description implements sql.FunctionExpression
   280  func (i *Inet6Ntoa) Description() string {
   281  	return "returns the IPv6 address from a numeric value."
   282  }
   283  
   284  func (i *Inet6Ntoa) String() string {
   285  	return fmt.Sprintf("%s(%s)", i.FunctionName(), i.Child.String())
   286  }
   287  
   288  func (i *Inet6Ntoa) Type() sql.Type {
   289  	return types.LongText
   290  }
   291  
   292  // CollationCoercibility implements the interface sql.CollationCoercible.
   293  func (*Inet6Ntoa) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   294  	return ctx.GetCollation(), 4
   295  }
   296  
   297  func (i *Inet6Ntoa) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   298  	if len(children) != 1 {
   299  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
   300  	}
   301  	return NewInet6Ntoa(children[0]), nil
   302  }
   303  
   304  func (i *Inet6Ntoa) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   305  	// Evaluate value
   306  	val, err := i.Child.Eval(ctx, row)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  
   311  	// Return null if given null
   312  	if val == nil {
   313  		return nil, nil
   314  	}
   315  
   316  	// Only convert if received string as input
   317  	switch val.(type) {
   318  	case []byte:
   319  		ipbytes := val.([]byte)
   320  
   321  		// Exactly 4 bytes, treat as IPv4 address
   322  		if len(ipbytes) == 4 {
   323  			var ipv4 net.IP = ipbytes
   324  			return ipv4.String(), nil
   325  		}
   326  
   327  		// There must be exactly 4 or 16 bytes (len == 4 satisfied above)
   328  		if len(ipbytes) != 16 {
   329  			ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%s'' for function %s", string(val.([]byte)), i.FunctionName()))
   330  			return nil, nil
   331  		}
   332  
   333  		// Check to see if it should be printed as IPv6; non-zero within first 10 bytes
   334  		for _, b := range ipbytes[:10] {
   335  			if b != 0 {
   336  				// Create new IPv6
   337  				var ipv6 net.IP = ipbytes
   338  				return ipv6.String(), nil
   339  			}
   340  		}
   341  
   342  		// IPv4-compatible (12 bytes of 0x00)
   343  		if ipbytes[10] == 0 && ipbytes[11] == 0 && (ipbytes[12] != 0 || ipbytes[13] != 0) {
   344  			var ipv4 net.IP = ipbytes[12:]
   345  			return "::" + ipv4.String(), nil
   346  		}
   347  
   348  		// IPv4-mapped (10 bytes of 0x00 followed by 2 bytes of 0xFF)
   349  		if ipbytes[10] == 0xFF && ipbytes[11] == 0xFF {
   350  			var ipv4 net.IP = ipbytes[12:]
   351  			return "::ffff:" + ipv4.String(), nil
   352  		}
   353  
   354  		// Print as IPv6 by default
   355  		var ipv6 net.IP = ipbytes
   356  		return ipv6.String(), nil
   357  	default:
   358  		ctx.Warn(1411, fmt.Sprintf("Incorrect string value: ''%v'' for function %s", val, i.FunctionName()))
   359  		return nil, nil
   360  	}
   361  }