github.com/mithrandie/csvq@v1.18.1/lib/query/user_defined_function.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"github.com/mithrandie/csvq/lib/parser"
     9  	"github.com/mithrandie/csvq/lib/value"
    10  )
    11  
    12  type UserDefinedFunctionMap struct {
    13  	*SyncMap
    14  }
    15  
    16  func NewUserDefinedFunctionMap() UserDefinedFunctionMap {
    17  	return UserDefinedFunctionMap{
    18  		NewSyncMap(),
    19  	}
    20  }
    21  
    22  func (m UserDefinedFunctionMap) IsEmpty() bool {
    23  	return m.SyncMap == nil
    24  }
    25  
    26  func (m UserDefinedFunctionMap) Store(name string, val *UserDefinedFunction) {
    27  	m.store(strings.ToUpper(name), val)
    28  }
    29  
    30  func (m UserDefinedFunctionMap) LoadDirect(name string) (interface{}, bool) {
    31  	return m.load(strings.ToUpper(name))
    32  }
    33  
    34  func (m UserDefinedFunctionMap) Load(name string) (*UserDefinedFunction, bool) {
    35  	if v, ok := m.load(strings.ToUpper(name)); ok {
    36  		return v.(*UserDefinedFunction), true
    37  	}
    38  	return nil, false
    39  }
    40  
    41  func (m UserDefinedFunctionMap) Delete(name string) {
    42  	m.delete(strings.ToUpper(name))
    43  }
    44  
    45  func (m UserDefinedFunctionMap) Exists(name string) bool {
    46  	return m.exists(strings.ToUpper(name))
    47  }
    48  
    49  func (m UserDefinedFunctionMap) Declare(expr parser.FunctionDeclaration) error {
    50  	if err := m.CheckDuplicate(expr.Name); err != nil {
    51  		return err
    52  	}
    53  
    54  	parameters, defaults, required, err := m.parseParameters(expr.Parameters)
    55  	if err != nil {
    56  		return err
    57  	}
    58  
    59  	m.Store(expr.Name.Literal, &UserDefinedFunction{
    60  		Name:         expr.Name,
    61  		Statements:   expr.Statements,
    62  		Parameters:   parameters,
    63  		Defaults:     defaults,
    64  		RequiredArgs: required,
    65  	})
    66  	return nil
    67  }
    68  
    69  func (m UserDefinedFunctionMap) DeclareAggregate(expr parser.AggregateDeclaration) error {
    70  	if err := m.CheckDuplicate(expr.Name); err != nil {
    71  		return err
    72  	}
    73  
    74  	parameters, defaults, required, err := m.parseParameters(expr.Parameters)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	m.Store(expr.Name.Literal, &UserDefinedFunction{
    80  		Name:         expr.Name,
    81  		Statements:   expr.Statements,
    82  		Parameters:   parameters,
    83  		Defaults:     defaults,
    84  		RequiredArgs: required,
    85  		IsAggregate:  true,
    86  		Cursor:       expr.Cursor,
    87  	})
    88  	return nil
    89  }
    90  
    91  func (m UserDefinedFunctionMap) parseParameters(parameters []parser.VariableAssignment) ([]parser.Variable, map[string]parser.QueryExpression, int, error) {
    92  	var isDuplicate = func(variable parser.Variable, variables []parser.Variable) bool {
    93  		for _, v := range variables {
    94  			if variable.Name == v.Name {
    95  				return true
    96  			}
    97  		}
    98  		return false
    99  	}
   100  
   101  	variables := make([]parser.Variable, len(parameters))
   102  	defaults := make(map[string]parser.QueryExpression)
   103  
   104  	required := 0
   105  	for i, assignment := range parameters {
   106  		if isDuplicate(assignment.Variable, variables) {
   107  			return nil, nil, 0, NewDuplicateParameterError(assignment.Variable)
   108  		}
   109  
   110  		variables[i] = assignment.Variable
   111  		if assignment.Value == nil {
   112  			required = i + 1
   113  		} else {
   114  			defaults[assignment.Variable.Name] = assignment.Value
   115  		}
   116  	}
   117  	return variables, defaults, required, nil
   118  }
   119  
   120  func (m UserDefinedFunctionMap) CheckDuplicate(name parser.Identifier) error {
   121  	uname := strings.ToUpper(name.Literal)
   122  
   123  	if _, ok := Functions[uname]; ok || uname == "CALL" || uname == "NOW" || uname == "JSON_OBJECT" {
   124  		return NewBuiltInFunctionDeclaredError(name)
   125  	}
   126  	if _, ok := AggregateFunctions[uname]; ok {
   127  		return NewBuiltInFunctionDeclaredError(name)
   128  	}
   129  	if _, ok := AnalyticFunctions[uname]; ok {
   130  		return NewBuiltInFunctionDeclaredError(name)
   131  	}
   132  	if m.Exists(uname) {
   133  		return NewFunctionRedeclaredError(name)
   134  	}
   135  	return nil
   136  }
   137  
   138  func (m UserDefinedFunctionMap) Get(name string) (*UserDefinedFunction, bool) {
   139  	if fn, ok := m.Load(name); ok {
   140  		return fn, true
   141  	}
   142  	return nil, false
   143  }
   144  
   145  func (m UserDefinedFunctionMap) Dispose(name parser.Identifier) bool {
   146  	if m.Exists(name.Literal) {
   147  		m.Delete(name.Literal)
   148  		return true
   149  	}
   150  	return false
   151  }
   152  
   153  type UserDefinedFunction struct {
   154  	Name         parser.Identifier
   155  	Statements   []parser.Statement
   156  	Parameters   []parser.Variable
   157  	Defaults     map[string]parser.QueryExpression
   158  	RequiredArgs int
   159  
   160  	IsAggregate bool
   161  	Cursor      parser.Identifier // For Aggregate Functions
   162  }
   163  
   164  func (fn *UserDefinedFunction) Execute(ctx context.Context, scope *ReferenceScope, args []value.Primary) (value.Primary, error) {
   165  	childScope := scope.CreateChild()
   166  	defer childScope.CloseCurrentBlock()
   167  
   168  	return fn.execute(ctx, childScope, args)
   169  }
   170  
   171  func (fn *UserDefinedFunction) ExecuteAggregate(ctx context.Context, scope *ReferenceScope, values []value.Primary, args []value.Primary) (value.Primary, error) {
   172  	childScope := scope.CreateChild()
   173  	defer childScope.CloseCurrentBlock()
   174  
   175  	if err := childScope.AddPseudoCursor(fn.Cursor, values); err != nil {
   176  		return nil, err
   177  	}
   178  	return fn.execute(ctx, childScope, args)
   179  }
   180  
   181  func (fn *UserDefinedFunction) CheckArgsLen(expr parser.QueryExpression, name string, argsLen int) error {
   182  	parametersLen := len(fn.Parameters)
   183  	requiredLen := fn.RequiredArgs
   184  	if fn.IsAggregate {
   185  		parametersLen++
   186  		requiredLen++
   187  	}
   188  
   189  	if len(fn.Defaults) < 1 {
   190  		if argsLen != len(fn.Parameters) {
   191  			return NewFunctionArgumentLengthError(expr, name, []int{parametersLen})
   192  		}
   193  	} else if argsLen < fn.RequiredArgs {
   194  		return NewFunctionArgumentLengthErrorWithCustomArgs(expr, name, fmt.Sprintf("at least %s", FormatCount(requiredLen, "argument")))
   195  	} else if len(fn.Parameters) < argsLen {
   196  		return NewFunctionArgumentLengthErrorWithCustomArgs(expr, name, fmt.Sprintf("at most %s", FormatCount(parametersLen, "argument")))
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  func (fn *UserDefinedFunction) execute(ctx context.Context, scope *ReferenceScope, args []value.Primary) (value.Primary, error) {
   203  	if err := fn.CheckArgsLen(fn.Name, fn.Name.Literal, len(args)); err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	for i, v := range fn.Parameters {
   208  		if i < len(args) {
   209  			if err := scope.Blocks[0].Variables.Add(v, args[i]); err != nil {
   210  				return nil, err
   211  			}
   212  		} else {
   213  			defaultValue, _ := fn.Defaults[v.Name]
   214  			val, err := Evaluate(ctx, scope, defaultValue)
   215  			if err != nil {
   216  				return nil, err
   217  			}
   218  			if err = scope.DeclareVariableDirectly(v, val); err != nil {
   219  				return nil, err
   220  			}
   221  		}
   222  	}
   223  
   224  	proc := NewProcessorWithScope(scope.Tx, scope)
   225  	if _, err := proc.execute(ctx, fn.Statements); err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	ret := proc.returnVal
   230  	if ret == nil {
   231  		ret = value.NewNull()
   232  	}
   233  
   234  	return ret, nil
   235  }