src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/eval/node_utils.go (about)

     1  package eval
     2  
     3  import (
     4  	"src.elv.sh/pkg/diag"
     5  	"src.elv.sh/pkg/parse"
     6  	"src.elv.sh/pkg/parse/cmpd"
     7  )
     8  
     9  // Utilities for working with nodes.
    10  
    11  func stringLiteralOrError(cp *compiler, n *parse.Compound, what string) string {
    12  	s, err := cmpd.StringLiteralOrError(n, what)
    13  	if err != nil {
    14  		cp.errorpf(n, "%v", err)
    15  	}
    16  	return s
    17  }
    18  
    19  type argsGetter struct {
    20  	cp *compiler
    21  	fn *parse.Form
    22  	ok bool
    23  	n  int
    24  }
    25  
    26  func getArgs(cp *compiler, fn *parse.Form) *argsGetter {
    27  	return &argsGetter{cp, fn, true, 0}
    28  }
    29  
    30  func (ag *argsGetter) errorpf(r diag.Ranger, format string, args ...any) {
    31  	if ag.ok {
    32  		ag.cp.errorpf(r, format, args...)
    33  		ag.ok = false
    34  	}
    35  }
    36  
    37  func (ag *argsGetter) get(i int, what string) *argAsserter {
    38  	if ag.n < i+1 {
    39  		ag.n = i + 1
    40  	}
    41  	if i >= len(ag.fn.Args) {
    42  		ag.errorpf(diag.PointRanging(ag.fn.To), "need %s", what)
    43  		return &argAsserter{ag, what, nil}
    44  	}
    45  	return &argAsserter{ag, what, ag.fn.Args[i]}
    46  }
    47  
    48  func (ag *argsGetter) has(i int) bool { return i < len(ag.fn.Args) }
    49  
    50  func (ag *argsGetter) hasKeyword(i int, kw string) bool {
    51  	if i < len(ag.fn.Args) {
    52  		s, ok := cmpd.StringLiteral(ag.fn.Args[i])
    53  		return ok && s == kw
    54  	}
    55  	return false
    56  }
    57  
    58  func (ag *argsGetter) optionalKeywordBody(i int, kw string) *parse.Primary {
    59  	if ag.has(i+1) && ag.hasKeyword(i, kw) {
    60  		return ag.get(i+1, kw+" body").thunk()
    61  	}
    62  	return nil
    63  }
    64  
    65  func (ag *argsGetter) finish() bool {
    66  	if ag.n < len(ag.fn.Args) {
    67  		ag.errorpf(
    68  			diag.Ranging{From: ag.fn.Args[ag.n].Range().From, To: ag.fn.To},
    69  			"superfluous arguments")
    70  	}
    71  	return ag.ok
    72  }
    73  
    74  type argAsserter struct {
    75  	ag   *argsGetter
    76  	what string
    77  	node *parse.Compound
    78  }
    79  
    80  func (aa *argAsserter) any() *parse.Compound {
    81  	return aa.node
    82  }
    83  
    84  func (aa *argAsserter) stringLiteral() string {
    85  	if aa.node == nil {
    86  		return ""
    87  	}
    88  	s, err := cmpd.StringLiteralOrError(aa.node, aa.what)
    89  	if err != nil {
    90  		aa.ag.errorpf(aa.node, "%v", err)
    91  		return ""
    92  	}
    93  	return s
    94  }
    95  
    96  func (aa *argAsserter) lambda() *parse.Primary {
    97  	if aa.node == nil {
    98  		return nil
    99  	}
   100  	lambda, ok := cmpd.Lambda(aa.node)
   101  	if !ok {
   102  		aa.ag.errorpf(aa.node,
   103  			"%s must be lambda, found %s", aa.what, cmpd.Shape(aa.node))
   104  		return nil
   105  	}
   106  	return lambda
   107  }
   108  
   109  func (aa *argAsserter) thunk() *parse.Primary {
   110  	lambda := aa.lambda()
   111  	if lambda == nil {
   112  		return nil
   113  	}
   114  	if len(lambda.Elements) > 0 {
   115  		aa.ag.errorpf(lambda, "%s must not have arguments", aa.what)
   116  		return nil
   117  	}
   118  	if len(lambda.MapPairs) > 0 {
   119  		aa.ag.errorpf(lambda, "%s must not have options", aa.what)
   120  		return nil
   121  	}
   122  	return lambda
   123  }