github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/tobase64_frombase64.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"encoding/base64"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql/encodings"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  )
    29  
    30  // ToBase64 is a function to encode a string to the Base64 format
    31  // using the same dialect that MySQL's TO_BASE64 uses
    32  type ToBase64 struct {
    33  	expression.UnaryExpression
    34  }
    35  
    36  var _ sql.FunctionExpression = (*ToBase64)(nil)
    37  var _ sql.CollationCoercible = (*ToBase64)(nil)
    38  
    39  // NewToBase64 creates a new ToBase64 expression.
    40  func NewToBase64(e sql.Expression) sql.Expression {
    41  	return &ToBase64{expression.UnaryExpression{Child: e}}
    42  }
    43  
    44  // FunctionName implements sql.FunctionExpression
    45  func (t *ToBase64) FunctionName() string {
    46  	return "to_base64"
    47  }
    48  
    49  // Description implements sql.FunctionExpression
    50  func (t *ToBase64) Description() string {
    51  	return "encodes the string str in base64 format."
    52  }
    53  
    54  // Eval implements the Expression interface.
    55  func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    56  	val, err := t.Child.Eval(ctx, row)
    57  
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	if val == nil {
    63  		return nil, nil
    64  	}
    65  
    66  	var strBytes []byte
    67  	if types.IsTextOnly(t.Child.Type()) {
    68  		val, _, err = t.Child.Type().Convert(val)
    69  		if err != nil {
    70  			return nil, sql.ErrInvalidType.New(reflect.TypeOf(val))
    71  		}
    72  		// For string types we need to re-encode the internal string so that we get the correct base64 output
    73  		encoder := t.Child.Type().(sql.StringType).Collation().CharacterSet().Encoder()
    74  		encodedBytes, ok := encoder.Encode(encodings.StringToBytes(val.(string)))
    75  		if !ok {
    76  			return nil, fmt.Errorf("unable to re-encode string for TO_BASE64 function")
    77  		}
    78  		strBytes = encodedBytes
    79  	} else {
    80  		val, _, err = types.LongText.Convert(val)
    81  		if err != nil {
    82  			return nil, sql.ErrInvalidType.New(reflect.TypeOf(val))
    83  		}
    84  		strBytes = []byte(val.(string))
    85  	}
    86  
    87  	encoded := base64.StdEncoding.EncodeToString(strBytes)
    88  
    89  	lenEncoded := len(encoded)
    90  	if lenEncoded <= 76 {
    91  		return encoded, nil
    92  	}
    93  
    94  	// Split into max 76 chars lines
    95  	var out strings.Builder
    96  	start := 0
    97  	end := 76
    98  	for {
    99  		out.WriteString(encoded[start:end] + "\n")
   100  		start += 76
   101  		end += 76
   102  		if end >= lenEncoded {
   103  			out.WriteString(encoded[start:lenEncoded])
   104  			break
   105  		}
   106  	}
   107  
   108  	return out.String(), nil
   109  }
   110  
   111  // String implements the fmt.Stringer interface.
   112  func (t *ToBase64) String() string {
   113  	return fmt.Sprintf("%s(%s)", t.FunctionName(), t.Child)
   114  }
   115  
   116  // IsNullable implements the Expression interface.
   117  func (t *ToBase64) IsNullable() bool {
   118  	return t.Child.IsNullable()
   119  }
   120  
   121  // WithChildren implements the Expression interface.
   122  func (t *ToBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   123  	if len(children) != 1 {
   124  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   125  	}
   126  	return NewToBase64(children[0]), nil
   127  }
   128  
   129  // Type implements the Expression interface.
   130  func (t *ToBase64) Type() sql.Type {
   131  	return types.LongText
   132  }
   133  
   134  // CollationCoercibility implements the interface sql.CollationCoercible.
   135  func (*ToBase64) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   136  	return ctx.GetCollation(), 4
   137  }
   138  
   139  // FromBase64 is a function to decode a Base64-formatted string
   140  // using the same dialect that MySQL's FROM_BASE64 uses
   141  type FromBase64 struct {
   142  	expression.UnaryExpression
   143  }
   144  
   145  var _ sql.FunctionExpression = (*FromBase64)(nil)
   146  var _ sql.CollationCoercible = (*FromBase64)(nil)
   147  
   148  // NewFromBase64 creates a new FromBase64 expression.
   149  func NewFromBase64(e sql.Expression) sql.Expression {
   150  	return &FromBase64{expression.UnaryExpression{Child: e}}
   151  }
   152  
   153  // FunctionName implements sql.FunctionExpression
   154  func (t *FromBase64) FunctionName() string {
   155  	return "from_base64"
   156  }
   157  
   158  // Description implements sql.FunctionExpression
   159  func (t *FromBase64) Description() string {
   160  	return "decodes the base64-encoded string str."
   161  }
   162  
   163  // Eval implements the Expression interface.
   164  func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   165  	str, err := t.Child.Eval(ctx, row)
   166  
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	if str == nil {
   172  		return nil, nil
   173  	}
   174  
   175  	str, _, err = types.LongText.Convert(str)
   176  	if err != nil {
   177  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
   178  	}
   179  
   180  	decoded, err := base64.StdEncoding.DecodeString(str.(string))
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	return decoded, nil
   186  }
   187  
   188  // String implements the fmt.Stringer interface.
   189  func (t *FromBase64) String() string {
   190  	return fmt.Sprintf("%s(%s)", t.FunctionName(), t.Child)
   191  }
   192  
   193  // IsNullable implements the Expression interface.
   194  func (t *FromBase64) IsNullable() bool {
   195  	return t.Child.IsNullable()
   196  }
   197  
   198  // WithChildren implements the Expression interface.
   199  func (t *FromBase64) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   200  	if len(children) != 1 {
   201  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   202  	}
   203  	return NewFromBase64(children[0]), nil
   204  }
   205  
   206  // Type implements the Expression interface.
   207  func (t *FromBase64) Type() sql.Type {
   208  	return types.LongBlob
   209  }
   210  
   211  // CollationCoercibility implements the interface sql.CollationCoercible.
   212  func (*FromBase64) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   213  	return sql.Collation_binary, 4
   214  }