github.com/benhoyt/goawk@v1.8.1/interp/functions.go (about) 1 // Evaluate builtin and user-defined function calls 2 3 package interp 4 5 import ( 6 "bytes" 7 "errors" 8 "fmt" 9 "io" 10 "math" 11 "os/exec" 12 "reflect" 13 "sort" 14 "strconv" 15 "strings" 16 "syscall" 17 "time" 18 "unicode/utf8" 19 20 . "github.com/benhoyt/goawk/internal/ast" 21 . "github.com/benhoyt/goawk/lexer" 22 ) 23 24 // Call builtin function specified by "op" with given args 25 func (p *interp) callBuiltin(op Token, argExprs []Expr) (value, error) { 26 // split() has an array arg (not evaluated) and [g]sub() have an 27 // lvalue arg, so handle them as special cases 28 switch op { 29 case F_SPLIT: 30 strValue, err := p.eval(argExprs[0]) 31 if err != nil { 32 return null(), err 33 } 34 str := p.toString(strValue) 35 var fieldSep string 36 if len(argExprs) == 3 { 37 sepValue, err := p.eval(argExprs[2]) 38 if err != nil { 39 return null(), err 40 } 41 fieldSep = p.toString(sepValue) 42 } else { 43 fieldSep = p.fieldSep 44 } 45 arrayExpr := argExprs[1].(*ArrayExpr) 46 n, err := p.split(str, arrayExpr.Scope, arrayExpr.Index, fieldSep) 47 if err != nil { 48 return null(), err 49 } 50 return num(float64(n)), nil 51 52 case F_SUB, F_GSUB: 53 regexValue, err := p.eval(argExprs[0]) 54 if err != nil { 55 return null(), err 56 } 57 regex := p.toString(regexValue) 58 replValue, err := p.eval(argExprs[1]) 59 if err != nil { 60 return null(), err 61 } 62 repl := p.toString(replValue) 63 var in string 64 if len(argExprs) == 3 { 65 inValue, err := p.eval(argExprs[2]) 66 if err != nil { 67 return null(), err 68 } 69 in = p.toString(inValue) 70 } else { 71 in = p.line 72 } 73 out, n, err := p.sub(regex, repl, in, op == F_GSUB) 74 if err != nil { 75 return null(), err 76 } 77 if len(argExprs) == 3 { 78 err := p.assign(argExprs[2], str(out)) 79 if err != nil { 80 return null(), err 81 } 82 } else { 83 p.setLine(out) 84 } 85 return num(float64(n)), nil 86 } 87 88 // Now evaluate the argExprs (calls with up to 7 args 
don't 89 // require heap allocation) 90 args := make([]value, 0, 7) 91 for _, a := range argExprs { 92 arg, err := p.eval(a) 93 if err != nil { 94 return null(), err 95 } 96 args = append(args, arg) 97 } 98 99 // Then switch on the function for the ordinary functions 100 switch op { 101 case F_LENGTH: 102 switch len(args) { 103 case 0: 104 return num(float64(len(p.line))), nil 105 default: 106 return num(float64(len(p.toString(args[0])))), nil 107 } 108 109 case F_MATCH: 110 re, err := p.compileRegex(p.toString(args[1])) 111 if err != nil { 112 return null(), err 113 } 114 loc := re.FindStringIndex(p.toString(args[0])) 115 if loc == nil { 116 p.matchStart = 0 117 p.matchLength = -1 118 return num(0), nil 119 } 120 p.matchStart = loc[0] + 1 121 p.matchLength = loc[1] - loc[0] 122 return num(float64(p.matchStart)), nil 123 124 case F_SUBSTR: 125 s := p.toString(args[0]) 126 pos := int(args[1].num()) 127 if pos > len(s) { 128 pos = len(s) + 1 129 } 130 if pos < 1 { 131 pos = 1 132 } 133 maxLength := len(s) - pos + 1 134 length := maxLength 135 if len(args) == 3 { 136 length = int(args[2].num()) 137 if length < 0 { 138 length = 0 139 } 140 if length > maxLength { 141 length = maxLength 142 } 143 } 144 return str(s[pos-1 : pos-1+length]), nil 145 146 case F_SPRINTF: 147 s, err := p.sprintf(p.toString(args[0]), args[1:]) 148 if err != nil { 149 return null(), err 150 } 151 return str(s), nil 152 153 case F_INDEX: 154 s := p.toString(args[0]) 155 substr := p.toString(args[1]) 156 return num(float64(strings.Index(s, substr) + 1)), nil 157 158 case F_TOLOWER: 159 return str(strings.ToLower(p.toString(args[0]))), nil 160 case F_TOUPPER: 161 return str(strings.ToUpper(p.toString(args[0]))), nil 162 163 case F_ATAN2: 164 return num(math.Atan2(args[0].num(), args[1].num())), nil 165 case F_COS: 166 return num(math.Cos(args[0].num())), nil 167 case F_EXP: 168 return num(math.Exp(args[0].num())), nil 169 case F_INT: 170 return num(float64(int(args[0].num()))), nil 171 case F_LOG: 
172 return num(math.Log(args[0].num())), nil 173 case F_SQRT: 174 return num(math.Sqrt(args[0].num())), nil 175 case F_RAND: 176 return num(p.random.Float64()), nil 177 case F_SIN: 178 return num(math.Sin(args[0].num())), nil 179 180 case F_SRAND: 181 prevSeed := p.randSeed 182 switch len(args) { 183 case 0: 184 p.random.Seed(time.Now().UnixNano()) 185 case 1: 186 p.randSeed = args[0].num() 187 p.random.Seed(int64(math.Float64bits(p.randSeed))) 188 } 189 return num(prevSeed), nil 190 191 case F_SYSTEM: 192 if p.noExec { 193 return null(), newError("can't call system() due to NoExec") 194 } 195 cmdline := p.toString(args[0]) 196 cmd := exec.Command("sh", "-c", cmdline) 197 cmd.Stdout = p.output 198 cmd.Stderr = p.errorOutput 199 err := cmd.Start() 200 if err != nil { 201 fmt.Fprintln(p.errorOutput, err) 202 return num(-1), nil 203 } 204 err = cmd.Wait() 205 if err != nil { 206 if exitErr, ok := err.(*exec.ExitError); ok { 207 if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { 208 return num(float64(status.ExitStatus())), nil 209 } else { 210 fmt.Fprintf(p.errorOutput, "couldn't get exit status for %q: %v\n", cmdline, err) 211 return num(-1), nil 212 } 213 } else { 214 fmt.Fprintf(p.errorOutput, "unexpected error running command %q: %v\n", cmdline, err) 215 return num(-1), nil 216 } 217 } 218 return num(0), nil 219 220 case F_CLOSE: 221 name := p.toString(args[0]) 222 var c io.Closer = p.inputStreams[name] 223 if c != nil { 224 // Close input stream 225 delete(p.inputStreams, name) 226 err := c.Close() 227 if err != nil { 228 return num(-1), nil 229 } 230 return num(0), nil 231 } 232 c = p.outputStreams[name] 233 if c != nil { 234 // Close output stream 235 delete(p.outputStreams, name) 236 err := c.Close() 237 if err != nil { 238 return num(-1), nil 239 } 240 return num(0), nil 241 } 242 // Nothing to close 243 return num(-1), nil 244 245 case F_FFLUSH: 246 var name string 247 if len(args) > 0 { 248 name = p.toString(args[0]) 249 } 250 var ok bool 251 if name 
!= "" { 252 // Flush a single, named output stream 253 ok = p.flushStream(name) 254 } else { 255 // fflush() or fflush("") flushes all output streams 256 ok = p.flushAll() 257 } 258 if !ok { 259 return num(-1), nil 260 } 261 return num(0), nil 262 263 default: 264 // Shouldn't happen 265 panic(fmt.Sprintf("unexpected function: %s", op)) 266 } 267 } 268 269 // Call user-defined function with given index and arguments, return 270 // its return value (or null value if it doesn't return anything) 271 func (p *interp) callUser(index int, args []Expr) (value, error) { 272 f := p.program.Functions[index] 273 274 if p.callDepth >= maxCallDepth { 275 return null(), newError("calling %q exceeded maximum call depth of %d", f.Name, maxCallDepth) 276 } 277 278 // Evaluate the arguments and push them onto the locals stack 279 oldFrame := p.frame 280 newFrameStart := len(p.stack) 281 var arrays []int 282 for i, arg := range args { 283 if f.Arrays[i] { 284 a := arg.(*VarExpr) 285 arrays = append(arrays, p.getArrayIndex(a.Scope, a.Index)) 286 } else { 287 argValue, err := p.eval(arg) 288 if err != nil { 289 return null(), err 290 } 291 p.stack = append(p.stack, argValue) 292 } 293 } 294 // Push zero value for any additional parameters (it's valid to 295 // call a function with fewer arguments than it has parameters) 296 oldArraysLen := len(p.arrays) 297 for i := len(args); i < len(f.Params); i++ { 298 if f.Arrays[i] { 299 arrays = append(arrays, len(p.arrays)) 300 p.arrays = append(p.arrays, make(map[string]value)) 301 } else { 302 p.stack = append(p.stack, null()) 303 } 304 } 305 p.frame = p.stack[newFrameStart:] 306 p.localArrays = append(p.localArrays, arrays) 307 308 // Execute the function! 
309 p.callDepth++ 310 err := p.executes(f.Body) 311 p.callDepth-- 312 313 // Pop the locals off the stack 314 p.stack = p.stack[:newFrameStart] 315 p.frame = oldFrame 316 p.localArrays = p.localArrays[:len(p.localArrays)-1] 317 p.arrays = p.arrays[:oldArraysLen] 318 319 if r, ok := err.(returnValue); ok { 320 return r.Value, nil 321 } 322 if err != nil { 323 return null(), err 324 } 325 return null(), nil 326 } 327 328 // Call native-defined function with given name and arguments, return 329 // return value (or null value if it doesn't return anything). 330 func (p *interp) callNative(index int, args []Expr) (value, error) { 331 f := p.nativeFuncs[index] 332 minIn := len(f.in) // Mininum number of args we should pass 333 var variadicType reflect.Type 334 if f.isVariadic { 335 variadicType = f.in[len(f.in)-1].Elem() 336 minIn-- 337 } 338 339 // Build list of args to pass to function 340 values := make([]reflect.Value, 0, 7) // up to 7 args won't require heap allocation 341 for i, arg := range args { 342 a, err := p.eval(arg) 343 if err != nil { 344 return null(), err 345 } 346 var argType reflect.Type 347 if !f.isVariadic || i < len(f.in)-1 { 348 argType = f.in[i] 349 } else { 350 // Final arg(s) when calling a variadic are all of this type 351 argType = variadicType 352 } 353 values = append(values, p.toNative(a, argType)) 354 } 355 // Use zero value for any unspecified args 356 for i := len(args); i < minIn; i++ { 357 values = append(values, reflect.Zero(f.in[i])) 358 } 359 360 // Call Go function, determine return value 361 outs := f.value.Call(values) 362 switch len(outs) { 363 case 0: 364 // No return value, return null value to AWK 365 return null(), nil 366 case 1: 367 // Single return value 368 return fromNative(outs[0]), nil 369 case 2: 370 // Two-valued return of (scalar, error) 371 if !outs[1].IsNil() { 372 return null(), outs[1].Interface().(error) 373 } 374 return fromNative(outs[0]), nil 375 default: 376 // Should never happen (checked at parse time) 
377 panic(fmt.Sprintf("unexpected number of return values: %d", len(outs))) 378 } 379 } 380 381 // Convert from an AWK value to a native Go value 382 func (p *interp) toNative(v value, typ reflect.Type) reflect.Value { 383 switch typ.Kind() { 384 case reflect.Bool: 385 return reflect.ValueOf(v.boolean()) 386 case reflect.Int: 387 return reflect.ValueOf(int(v.num())) 388 case reflect.Int8: 389 return reflect.ValueOf(int8(v.num())) 390 case reflect.Int16: 391 return reflect.ValueOf(int16(v.num())) 392 case reflect.Int32: 393 return reflect.ValueOf(int32(v.num())) 394 case reflect.Int64: 395 return reflect.ValueOf(int64(v.num())) 396 case reflect.Uint: 397 return reflect.ValueOf(uint(v.num())) 398 case reflect.Uint8: 399 return reflect.ValueOf(uint8(v.num())) 400 case reflect.Uint16: 401 return reflect.ValueOf(uint16(v.num())) 402 case reflect.Uint32: 403 return reflect.ValueOf(uint32(v.num())) 404 case reflect.Uint64: 405 return reflect.ValueOf(uint64(v.num())) 406 case reflect.Float32: 407 return reflect.ValueOf(float32(v.num())) 408 case reflect.Float64: 409 return reflect.ValueOf(v.num()) 410 case reflect.String: 411 return reflect.ValueOf(p.toString(v)) 412 case reflect.Slice: 413 if typ.Elem().Kind() != reflect.Uint8 { 414 // Shouldn't happen: prevented by checkNativeFunc 415 panic(fmt.Sprintf("unexpected argument slice: %s", typ.Elem().Kind())) 416 } 417 return reflect.ValueOf([]byte(p.toString(v))) 418 default: 419 // Shouldn't happen: prevented by checkNativeFunc 420 panic(fmt.Sprintf("unexpected argument type: %s", typ.Kind())) 421 } 422 } 423 424 // Convert from a native Go value to an AWK value 425 func fromNative(v reflect.Value) value { 426 switch v.Kind() { 427 case reflect.Bool: 428 return boolean(v.Bool()) 429 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 430 return num(float64(v.Int())) 431 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 432 return num(float64(v.Uint())) 433 case 
reflect.Float32, reflect.Float64: 434 return num(v.Float()) 435 case reflect.String: 436 return str(v.String()) 437 case reflect.Slice: 438 if b, ok := v.Interface().([]byte); ok { 439 return str(string(b)) 440 } 441 // Shouldn't happen: prevented by checkNativeFunc 442 panic(fmt.Sprintf("unexpected return slice: %s", v.Type().Elem().Kind())) 443 default: 444 // Shouldn't happen: prevented by checkNativeFunc 445 panic(fmt.Sprintf("unexpected return type: %s", v.Kind())) 446 } 447 } 448 449 // Used for caching native function type information on init 450 type nativeFunc struct { 451 isVariadic bool 452 in []reflect.Type 453 value reflect.Value 454 } 455 456 // Check and initialize native functions 457 func (p *interp) initNativeFuncs(funcs map[string]interface{}) error { 458 for name, f := range funcs { 459 err := checkNativeFunc(name, f) 460 if err != nil { 461 return err 462 } 463 } 464 465 // Sort functions by name, then use those indexes to build slice 466 // (this has to match how the parser sets the indexes). 467 names := make([]string, 0, len(funcs)) 468 for name := range funcs { 469 names = append(names, name) 470 } 471 sort.Strings(names) 472 p.nativeFuncs = make([]nativeFunc, len(names)) 473 for i, name := range names { 474 f := funcs[name] 475 typ := reflect.TypeOf(f) 476 in := make([]reflect.Type, typ.NumIn()) 477 for j := 0; j < len(in); j++ { 478 in[j] = typ.In(j) 479 } 480 p.nativeFuncs[i] = nativeFunc{ 481 isVariadic: typ.IsVariadic(), 482 in: in, 483 value: reflect.ValueOf(f), 484 } 485 } 486 return nil 487 } 488 489 // Got this trick from the Go stdlib text/template source 490 var errorType = reflect.TypeOf((*error)(nil)).Elem() 491 492 // Check that native function with given name is okay to call from 493 // AWK, return a *interp.Error if not. This checks that f is actually 494 // a function, and that its parameter and return types are good. 
495 func checkNativeFunc(name string, f interface{}) error { 496 if KeywordToken(name) != ILLEGAL { 497 return newError("can't use keyword %q as native function name", name) 498 } 499 500 typ := reflect.TypeOf(f) 501 if typ.Kind() != reflect.Func { 502 return newError("native function %q is not a function", name) 503 } 504 for i := 0; i < typ.NumIn(); i++ { 505 param := typ.In(i) 506 if typ.IsVariadic() && i == typ.NumIn()-1 { 507 param = param.Elem() 508 } 509 if !validNativeType(param) { 510 return newError("native function %q param %d is not int or string", name, i) 511 } 512 } 513 514 switch typ.NumOut() { 515 case 0: 516 // No return value is fine 517 case 1: 518 // Single scalar return value is fine 519 if !validNativeType(typ.Out(0)) { 520 return newError("native function %q return value is not int or string", name) 521 } 522 case 2: 523 // Returning (scalar, error) is handled too 524 if !validNativeType(typ.Out(0)) { 525 return newError("native function %q first return value is not int or string", name) 526 } 527 if typ.Out(1) != errorType { 528 return newError("native function %q second return value is not an error", name) 529 } 530 default: 531 return newError("native function %q returns more than two values", name) 532 } 533 return nil 534 } 535 536 // Return true if typ is a valid parameter or return type. 
func validNativeType(typ reflect.Type) bool {
	switch typ.Kind() {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64,
		reflect.String:
		return true
	case reflect.Slice:
		// Only allow []byte (convert to string in AWK)
		return typ.Elem().Kind() == reflect.Uint8
	default:
		return false
	}
}

// split is the guts of the split() function. It splits s on field
// separator fs ("" of " " splits on runs of whitespace, a single
// character splits literally, anything longer is a regex), stores the
// parts 1..n in the given array (replacing its contents), and returns n.
func (p *interp) split(s string, scope VarScope, index int, fs string) (int, error) {
	var parts []string
	switch {
	case fs == " ":
		parts = strings.Fields(s)
	case s == "":
		// NF should be 0 on empty line
	case utf8.RuneCountInString(fs) <= 1:
		parts = strings.Split(s, fs)
	default:
		re, err := p.compileRegex(fs)
		if err != nil {
			return 0, err
		}
		parts = re.Split(s, -1)
	}
	array := make(map[string]value, len(parts))
	for i, part := range parts {
		array[strconv.Itoa(i+1)] = numStr(part)
	}
	p.arrays[p.getArrayIndex(scope, index)] = array
	return len(array), nil
}

// sub is the guts of the sub() and gsub() functions. It replaces the
// first match (global == false) or every match (global == true) of regex
// in "in" with repl, and returns the new string and the number of
// replacements made.
func (p *interp) sub(regex, repl, in string, global bool) (string, int, error) {
	re, err := p.compileRegex(regex)
	if err != nil {
		return "", 0, err
	}
	if !global {
		// sub() only replaces the leftmost match, so there's no need to
		// scan the rest of the string.
		loc := re.FindStringIndex(in)
		if loc == nil {
			return in, 0, nil
		}
		return in[:loc[0]] + expandAmpersands(repl, in[loc[0]:loc[1]]) + in[loc[1]:], 1, nil
	}
	count := 0
	out := re.ReplaceAllStringFunc(in, func(matched string) string {
		count++
		return expandAmpersands(repl, matched)
	})
	return out, count, nil
}

// expandAmpersands builds the replacement text for a single sub()/gsub()
// match: each "&" in repl becomes the matched text, "\&" becomes a literal
// "&", "\\" becomes one backslash, and any other backslash (including a
// trailing one) is kept as-is.
func expandAmpersands(repl, matched string) string {
	r := make([]byte, 0, 64) // Up to 64 byte replacement won't require heap allocation
	for i := 0; i < len(repl); i++ {
		switch repl[i] {
		case '&':
			r = append(r, matched...)
		case '\\':
			i++
			if i < len(repl) {
				switch repl[i] {
				case '&':
					r = append(r, '&')
				case '\\':
					r = append(r, '\\')
				default:
					r = append(r, '\\', repl[i])
				}
			} else {
				r = append(r, '\\')
			}
		default:
			r = append(r, repl[i])
		}
	}
	return string(r)
}

type cachedFormat struct {
	format string
	types  []byte
}

// fmtFlagChars are the flag, width and precision characters allowed
// between "%" and the verb in a format (built once, not per iteration).
var fmtFlagChars = []byte(" .-+*#0123456789")

// parseFmtTypes parses the given sprintf format string into a Go format
// string, along with type conversion specifiers ('s', 'd', 'f', 'u' or
// 'c', plus a 'd' for each "*" width/precision). Output is memoized in a
// simple cache for performance.
func (p *interp) parseFmtTypes(s string) (format string, types []byte, err error) {
	if item, ok := p.formatCache[s]; ok {
		return item.format, item.types, nil
	}

	out := []byte(s)
	for i := 0; i < len(s); i++ {
		if s[i] == '%' {
			i++
			if i >= len(s) {
				return "", nil, errors.New("expected type specifier after %")
			}
			if s[i] == '%' {
				continue
			}
			for i < len(s) && bytes.IndexByte(fmtFlagChars, s[i]) >= 0 {
				if s[i] == '*' {
					types = append(types, 'd')
				}
				i++
			}
			if i >= len(s) {
				return "", nil, errors.New("expected type specifier after %")
			}
			var t byte
			switch s[i] {
			case 's':
				t = 's'
			case 'd', 'o', 'x', 'X':
				t = 'd'
			case 'i':
				// Go's fmt has no %i verb; C's %i is the same as %d
				t = 'd'
				out[i] = 'd'
			case 'f', 'e', 'E', 'g', 'G':
				t = 'f'
			case 'u':
				t = 'u'
				out[i] = 'd'
			case 'c':
				t = 'c'
				out[i] = 's'
			default:
				return "", nil, fmt.Errorf("invalid format type %q", s[i])
			}
			types = append(types, t)
		}
	}

	// Dumb, non-LRU cache: just cache the first N formats
	format = string(out)
	if len(p.formatCache) < maxCachedFormats {
		p.formatCache[s] = cachedFormat{format, types}
	}
	return format, types, nil
}

// sprintf is the guts of the sprintf() function (also used by the
// "printf" statement). It converts each arg according to the format's
// type specifiers and formats them with fmt.Sprintf.
func (p *interp) sprintf(format string, args []value) (string, error) {
	format, types, err := p.parseFmtTypes(format)
	if err != nil {
		return "", newError("format error: %s", err)
	}
	if len(types) > len(args) {
		return "", newError("format error: got %d args, expected %d", len(args), len(types))
	}
	converted := make([]interface{}, len(types))
	for i, t := range types {
		a := args[i]
		var v interface{}
		switch t {
		case 's':
			v = p.toString(a)
		case 'd':
			v = int(a.num())
		case 'f':
			v = a.num()
		case 'u':
			v = uint32(a.num())
		case 'c':
			var c []byte
			if a.isTrueStr() {
				s := p.toString(a)
				if len(s) > 0 {
					c = []byte{s[0]}
				} else {
					c = []byte{0}
				}
			} else {
				// Follow the behaviour of awk and mawk, where %c
				// operates on bytes (0-255), not Unicode codepoints
				c = []byte{byte(a.num())}
			}
			v = c
		}
		converted[i] = v
	}
	return fmt.Sprintf(format, converted...), nil
}