github.com/mithrandie/csvq@v1.18.1/lib/query/user_defined_function.go (about) 1 package query 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 8 "github.com/mithrandie/csvq/lib/parser" 9 "github.com/mithrandie/csvq/lib/value" 10 ) 11 12 type UserDefinedFunctionMap struct { 13 *SyncMap 14 } 15 16 func NewUserDefinedFunctionMap() UserDefinedFunctionMap { 17 return UserDefinedFunctionMap{ 18 NewSyncMap(), 19 } 20 } 21 22 func (m UserDefinedFunctionMap) IsEmpty() bool { 23 return m.SyncMap == nil 24 } 25 26 func (m UserDefinedFunctionMap) Store(name string, val *UserDefinedFunction) { 27 m.store(strings.ToUpper(name), val) 28 } 29 30 func (m UserDefinedFunctionMap) LoadDirect(name string) (interface{}, bool) { 31 return m.load(strings.ToUpper(name)) 32 } 33 34 func (m UserDefinedFunctionMap) Load(name string) (*UserDefinedFunction, bool) { 35 if v, ok := m.load(strings.ToUpper(name)); ok { 36 return v.(*UserDefinedFunction), true 37 } 38 return nil, false 39 } 40 41 func (m UserDefinedFunctionMap) Delete(name string) { 42 m.delete(strings.ToUpper(name)) 43 } 44 45 func (m UserDefinedFunctionMap) Exists(name string) bool { 46 return m.exists(strings.ToUpper(name)) 47 } 48 49 func (m UserDefinedFunctionMap) Declare(expr parser.FunctionDeclaration) error { 50 if err := m.CheckDuplicate(expr.Name); err != nil { 51 return err 52 } 53 54 parameters, defaults, required, err := m.parseParameters(expr.Parameters) 55 if err != nil { 56 return err 57 } 58 59 m.Store(expr.Name.Literal, &UserDefinedFunction{ 60 Name: expr.Name, 61 Statements: expr.Statements, 62 Parameters: parameters, 63 Defaults: defaults, 64 RequiredArgs: required, 65 }) 66 return nil 67 } 68 69 func (m UserDefinedFunctionMap) DeclareAggregate(expr parser.AggregateDeclaration) error { 70 if err := m.CheckDuplicate(expr.Name); err != nil { 71 return err 72 } 73 74 parameters, defaults, required, err := m.parseParameters(expr.Parameters) 75 if err != nil { 76 return err 77 } 78 79 m.Store(expr.Name.Literal, &UserDefinedFunction{ 80 Name: expr.Name, 81 Statements: expr.Statements, 82 Parameters: parameters, 83 Defaults: defaults, 84 RequiredArgs: required, 85 IsAggregate: true, 86 Cursor: expr.Cursor, 87 }) 88 return nil 89 } 90 91 func (m UserDefinedFunctionMap) parseParameters(parameters []parser.VariableAssignment) ([]parser.Variable, map[string]parser.QueryExpression, int, error) { 92 var isDuplicate = func(variable parser.Variable, variables []parser.Variable) bool { 93 for _, v := range variables { 94 if variable.Name == v.Name { 95 return true 96 } 97 } 98 return false 99 } 100 101 variables := make([]parser.Variable, len(parameters)) 102 defaults := make(map[string]parser.QueryExpression) 103 104 required := 0 105 for i, assignment := range parameters { 106 if isDuplicate(assignment.Variable, variables) { 107 return nil, nil, 0, NewDuplicateParameterError(assignment.Variable) 108 } 109 110 variables[i] = assignment.Variable 111 if assignment.Value == nil { 112 required = i + 1 113 } else { 114 defaults[assignment.Variable.Name] = assignment.Value 115 } 116 } 117 return variables, defaults, required, nil 118 } 119 120 func (m UserDefinedFunctionMap) CheckDuplicate(name parser.Identifier) error { 121 uname := strings.ToUpper(name.Literal) 122 123 if _, ok := Functions[uname]; ok || uname == "CALL" || uname == "NOW" || uname == "JSON_OBJECT" { 124 return NewBuiltInFunctionDeclaredError(name) 125 } 126 if _, ok := AggregateFunctions[uname]; ok { 127 return NewBuiltInFunctionDeclaredError(name) 128 } 129 if _, ok := AnalyticFunctions[uname]; ok { 130 return NewBuiltInFunctionDeclaredError(name) 131 } 132 if m.Exists(uname) { 133 return NewFunctionRedeclaredError(name) 134 } 135 return nil 136 } 137 138 func (m UserDefinedFunctionMap) Get(name string) (*UserDefinedFunction, bool) { 139 if fn, ok := m.Load(name); ok { 140 return fn, true 141 } 142 return nil, false 143 } 144 145 func (m UserDefinedFunctionMap) Dispose(name parser.Identifier) bool { 146 if m.Exists(name.Literal) { 147 m.Delete(name.Literal) 148 return true 149 } 150 return false 151 } 152 153 type UserDefinedFunction struct { 154 Name parser.Identifier 155 Statements []parser.Statement 156 Parameters []parser.Variable 157 Defaults map[string]parser.QueryExpression 158 RequiredArgs int 159 160 IsAggregate bool 161 Cursor parser.Identifier // For Aggregate Functions 162 } 163 164 func (fn *UserDefinedFunction) Execute(ctx context.Context, scope *ReferenceScope, args []value.Primary) (value.Primary, error) { 165 childScope := scope.CreateChild() 166 defer childScope.CloseCurrentBlock() 167 168 return fn.execute(ctx, childScope, args) 169 } 170 171 func (fn *UserDefinedFunction) ExecuteAggregate(ctx context.Context, scope *ReferenceScope, values []value.Primary, args []value.Primary) (value.Primary, error) { 172 childScope := scope.CreateChild() 173 defer childScope.CloseCurrentBlock() 174 175 if err := childScope.AddPseudoCursor(fn.Cursor, values); err != nil { 176 return nil, err 177 } 178 return fn.execute(ctx, childScope, args) 179 } 180 181 func (fn *UserDefinedFunction) CheckArgsLen(expr parser.QueryExpression, name string, argsLen int) error { 182 parametersLen := len(fn.Parameters) 183 requiredLen := fn.RequiredArgs 184 if fn.IsAggregate { 185 parametersLen++ 186 requiredLen++ 187 } 188 189 if len(fn.Defaults) < 1 { 190 if argsLen != len(fn.Parameters) { 191 return NewFunctionArgumentLengthError(expr, name, []int{parametersLen}) 192 } 193 } else if argsLen < fn.RequiredArgs { 194 return NewFunctionArgumentLengthErrorWithCustomArgs(expr, name, fmt.Sprintf("at least %s", FormatCount(requiredLen, "argument"))) 195 } else if len(fn.Parameters) < argsLen { 196 return NewFunctionArgumentLengthErrorWithCustomArgs(expr, name, fmt.Sprintf("at most %s", FormatCount(parametersLen, "argument"))) 197 } 198 199 return nil 200 } 201 202 func (fn *UserDefinedFunction) execute(ctx context.Context, scope *ReferenceScope, args []value.Primary) (value.Primary, error) { 203 if err := fn.CheckArgsLen(fn.Name, fn.Name.Literal, len(args)); err != nil { 204 return nil, err 205 } 206 207 for i, v := range fn.Parameters { 208 if i < len(args) { 209 if err := scope.Blocks[0].Variables.Add(v, args[i]); err != nil { 210 return nil, err 211 } 212 } else { 213 defaultValue, _ := fn.Defaults[v.Name] 214 val, err := Evaluate(ctx, scope, defaultValue) 215 if err != nil { 216 return nil, err 217 } 218 if err = scope.DeclareVariableDirectly(v, val); err != nil { 219 return nil, err 220 } 221 } 222 } 223 224 proc := NewProcessorWithScope(scope.Tx, scope) 225 if _, err := proc.execute(ctx, fn.Statements); err != nil { 226 return nil, err 227 } 228 229 ret := proc.returnVal 230 if ret == nil { 231 ret = value.NewNull() 232 } 233 234 return ret, nil 235 }