github.com/nevalang/neva@v0.23.1-0.20240507185603-7696a9bb8dda/internal/compiler/sourcecode/typesystem/resolver.go (about)

     1  package typesystem
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"github.com/nevalang/neva/internal/compiler/sourcecode/core"
     8  )
     9  
    10  var (
    11  	ErrInvalidExpr        = errors.New("expression must be valid in order to be resolved")
    12  	ErrScope              = errors.New("can't get type def from scope by ref")
    13  	ErrScopeUpdate        = errors.New("scope update")
    14  	ErrInstArgsCount      = errors.New("Wrong number of type arguments")
    15  	ErrIncompatArg        = errors.New("argument is not subtype of the parameter's contraint")
    16  	ErrUnresolvedArg      = errors.New("can't resolve argument")
    17  	ErrConstr             = errors.New("can't resolve constraint")
    18  	ErrArrType            = errors.New("could not resolve array type")
    19  	ErrUnionUnresolvedEl  = errors.New("can't resolve union element")
    20  	ErrRecFieldUnresolved = errors.New("can't resolve struct field")
    21  	ErrInvalidDef         = errors.New("invalid definition")
    22  	ErrTerminator         = errors.New("recursion terminator")
    23  )
    24  
// Resolver transforms an expression into a form where every reference it
// contains points to a resolved expression.
//
// Build it with MustNewResolver, which panics if any dependency is nil;
// every resolving method calls into all three of them.
type Resolver struct {
	validator  exprValidator       // Rejects invalid expressions/definitions before they are resolved
	checker    subtypeChecker      // Compares resolved arguments with resolved parameter constraints
	terminator recursionTerminator // Decides when to stop descending, so resolution doesn't get stuck in a loop
}
    31  
//go:generate mockgen -source $GOFILE -destination mocks_test.go -package ${GOPACKAGE}_test

// Dependencies of Resolver plus the Scope abstraction it resolves against.
type (
	// exprValidator is consulted before resolution: Resolver fails with
	// ErrInvalidExpr / ErrInvalidDef when Validate / ValidateDef return an error.
	exprValidator interface {
		Validate(expr Expr) error
		ValidateDef(def Def) error
	}
	// subtypeChecker compares two resolved expressions; Resolver treats a
	// non-nil error from Check as "sub is not a subtype of sup".
	subtypeChecker interface {
		Check(sub Expr, sup Expr, params TerminatorParams) error
	}
	// recursionTerminator is asked, for every instantiation, whether resolution
	// should stop at the current trace; when it returns true Resolver returns the
	// expression unresolved. NOTE(review): the exact stopping policy lives in the
	// implementation, not here.
	recursionTerminator interface {
		ShouldTerminate(trace Trace, scope Scope) (bool, error)
	}
	// Scope is the global context in which type definitions are looked up by reference.
	Scope interface {
		// GetType returns the definition referenced by ref together with the scope
		// it was found in; Resolver resolves the definition's body in that scope.
		GetType(ref core.EntityRef) (Def, Scope, error)
		// IsTopType reports whether expr is the top type.
		// NOTE(review): not used by Resolver itself; presumably used by the subtype checker — confirm.
		IsTopType(expr Expr) bool
	}
)
    49  
    50  // ResolveExpr resolves given expression using only global scope.
    51  func (r Resolver) ResolveExpr(expr Expr, scope Scope) (Expr, error) {
    52  	return r.resolveExpr(expr, scope, map[string]Def{}, nil)
    53  }
    54  
    55  // ResolveExprWithFrame works like ResolveExpr but allows to pass local scope.
    56  func (r Resolver) ResolveExprWithFrame(
    57  	expr Expr,
    58  	frame map[string]Def,
    59  	scope Scope,
    60  ) (
    61  	Expr,
    62  	error,
    63  ) {
    64  	return r.resolveExpr(expr, scope, frame, nil)
    65  }
    66  
    67  // ResolveExprWithFrame works like ResolveExprWithFrame but for list of expressions.
    68  func (r Resolver) ResolveExprsWithFrame(
    69  	exprs []Expr,
    70  	frame map[string]Def,
    71  	scope Scope,
    72  ) (
    73  	[]Expr,
    74  	error,
    75  ) {
    76  	resolvedExprs := make([]Expr, 0, len(exprs))
    77  	for _, expr := range exprs {
    78  		resolvedExpr, err := r.resolveExpr(expr, scope, frame, nil)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  		resolvedExprs = append(resolvedExprs, resolvedExpr)
    83  	}
    84  	return resolvedExprs, nil
    85  }
    86  
    87  // ResolveParams resolves every constraint in given parameter list.
    88  func (r Resolver) ResolveParams(
    89  	params []Param,
    90  	scope Scope,
    91  ) (
    92  	[]Param, // resolved parameters
    93  	map[string]Def, // resolved frame `paramName:resolvedConstr`
    94  	error,
    95  ) {
    96  	result := make([]Param, 0, len(params))
    97  	frame := make(map[string]Def, len(params))
    98  	for _, param := range params {
    99  		resolved, err := r.resolveExpr(param.Constr, scope, frame, nil)
   100  		if err != nil {
   101  			return nil, nil, fmt.Errorf("resolve expr: %w", err)
   102  		}
   103  		frame[param.Name] = Def{BodyExpr: &resolved}
   104  		result = append(result, Param{
   105  			Name:   param.Name,
   106  			Constr: resolved,
   107  		})
   108  	}
   109  	return result, frame, nil
   110  }
   111  
   112  // IsSubtypeOf resolves both `sub` and `sup` expressions
   113  // and returns error if `sub` is not subtype of `sup`.
   114  func (r Resolver) IsSubtypeOf(sub, sup Expr, scope Scope) error {
   115  	resolvedSub, err := r.resolveExpr(sub, scope, nil, nil)
   116  	if err != nil {
   117  		return fmt.Errorf("resolve sub expr: %w", err)
   118  	}
   119  	resolvedSup, err := r.resolveExpr(sup, scope, nil, nil)
   120  	if err != nil {
   121  		return fmt.Errorf("resolve sup expr: %w", err)
   122  	}
   123  	return r.checker.Check(
   124  		resolvedSub,
   125  		resolvedSup,
   126  		TerminatorParams{Scope: scope},
   127  	)
   128  }
   129  
   130  // CheckArgsCompatibility resolves args
   131  // and params and then checks their compatibility.
   132  func (r Resolver) CheckArgsCompatibility(args []Expr, params []Param, scope Scope) error {
   133  	if len(args) != len(params) {
   134  		return fmt.Errorf(
   135  			"count of arguments mismatch count of parameters, want %d got %d",
   136  			len(params),
   137  			len(args),
   138  		)
   139  	}
   140  
   141  	for i := range params {
   142  		arg := args[i]
   143  		param := params[i]
   144  
   145  		resolvedSub, err := r.resolveExpr(arg, scope, nil, nil)
   146  		if err != nil {
   147  			return fmt.Errorf("resolve arg expr: %w", err)
   148  		}
   149  
   150  		resolvedSup, err := r.resolveExpr(param.Constr, scope, nil, nil)
   151  		if err != nil {
   152  			return fmt.Errorf("resolve param constr expr: %w", err)
   153  		}
   154  
   155  		if err := r.checker.Check(
   156  			resolvedSub,
   157  			resolvedSup,
   158  			TerminatorParams{Scope: scope},
   159  		); err != nil {
   160  			return err
   161  		}
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  // resolveExpr turn one expression into another where all references points to native types.
   168  // It's a recursive process where each step starts with validation. Invalid expression always leads to error.
   169  // For inst expr it checks compatibility between args and params and returns error if some constraint isn't satisfied.
   170  // Then it updates scope by adding params of ref type with resolved args as values to allow substitution later.
   171  // Then it checks whether base type of current ref type is native type to terminate with nil err and resolved expr.
   172  // For non-native types process starts from the beginning with updated scope. New scope will contain values for params.
   173  // For lit exprs logic is the this: for enum do nothing (it's valid and not composite, there's nothing to resolveExpr),
   174  // for array resolveExpr it's type, for struct and union apply recursion for it's every field/element.
   175  func (r Resolver) resolveExpr( //nolint:funlen,gocognit
   176  	expr Expr, // expression to be resolved
   177  	scope Scope, // global scope
   178  	frame map[string]Def, // local scope
   179  	trace *Trace, // how did we get here
   180  ) (Expr, error) {
   181  	if err := r.validator.Validate(expr); err != nil {
   182  		return Expr{}, fmt.Errorf("%w: %v", ErrInvalidExpr, err)
   183  	}
   184  
   185  	if expr.Lit != nil {
   186  		switch expr.Lit.Type() {
   187  		case EnumLitType:
   188  			return expr, nil
   189  		case UnionLitType:
   190  			resolvedUnion := make([]Expr, 0, len(expr.Lit.Union))
   191  			for _, unionEl := range expr.Lit.Union {
   192  				resolvedEl, err := r.resolveExpr(unionEl, scope, frame, trace)
   193  				if err != nil {
   194  					return Expr{}, fmt.Errorf("%w: %v", ErrUnionUnresolvedEl, err)
   195  				}
   196  				resolvedUnion = append(resolvedUnion, resolvedEl)
   197  			}
   198  			return Expr{
   199  				Lit: &LitExpr{Union: resolvedUnion},
   200  			}, nil
   201  		case StructLitType:
   202  			resolvedStruct := make(map[string]Expr, len(expr.Lit.Struct))
   203  			for field, fieldExpr := range expr.Lit.Struct {
   204  				// we create new trace with virtual ref "struct" (it's safe because it's reserved word)
   205  				// otherwise expressions like `error struct {child maybe<error>}` will be direct recursive for terminator
   206  				newTrace := Trace{
   207  					prev: trace,
   208  					cur:  core.EntityRef{Name: "struct"},
   209  				}
   210  				resolvedFieldExpr, err := r.resolveExpr(
   211  					fieldExpr,
   212  					scope,
   213  					frame,
   214  					&newTrace,
   215  				)
   216  				if err != nil {
   217  					return Expr{}, fmt.Errorf(
   218  						"%w: %v: %v",
   219  						ErrRecFieldUnresolved,
   220  						field,
   221  						err,
   222  					)
   223  				}
   224  				resolvedStruct[field] = resolvedFieldExpr
   225  			}
   226  			return Expr{
   227  				Lit: &LitExpr{Struct: resolvedStruct},
   228  			}, nil
   229  		}
   230  	}
   231  
   232  	def, scopeWhereDefFound, err := r.getDef(expr.Inst.Ref, frame, scope)
   233  	if err != nil {
   234  		return Expr{}, err
   235  	}
   236  
   237  	if err := r.validator.ValidateDef(def); err != nil {
   238  		return Expr{}, errors.Join(ErrInvalidDef, err)
   239  	}
   240  
   241  	if len(def.Params) != len(expr.Inst.Args) { // args must not be > than params to avoid bad case with constraint
   242  		return Expr{}, fmt.Errorf(
   243  			"%w for '%v': want %d, got %d",
   244  			ErrInstArgsCount,
   245  			expr.Inst.Ref,
   246  			len(def.Params),
   247  			len(expr.Inst.Args),
   248  		)
   249  	}
   250  
   251  	newTrace := Trace{
   252  		prev: trace,
   253  		cur:  expr.Inst.Ref, // FIXME t1
   254  	}
   255  
   256  	shouldReturn, err := r.terminator.ShouldTerminate(newTrace, scope)
   257  	if err != nil {
   258  		return Expr{}, fmt.Errorf("%w: %v", ErrTerminator, err)
   259  	} else if shouldReturn {
   260  		return expr, nil
   261  	}
   262  
   263  	newFrame := make(map[string]Def, len(def.Params))
   264  	resolvedArgs := make([]Expr, 0, len(expr.Inst.Args))
   265  	for i, param := range def.Params { // resolve args and constrs and check their compatibility
   266  		resolvedArg, err := r.resolveExpr(expr.Inst.Args[i], scope, frame, &newTrace)
   267  		if err != nil {
   268  			return Expr{}, fmt.Errorf("%w: %v", ErrUnresolvedArg, err)
   269  		}
   270  
   271  		newFrame[param.Name] = Def{BodyExpr: &resolvedArg} // no params for generics
   272  		resolvedArgs = append(resolvedArgs, resolvedArg)
   273  
   274  		// we pass newFrame because constr can refer to type parameters
   275  		resolvedConstr, err := r.resolveExpr(
   276  			param.Constr,
   277  			scope,
   278  			newFrame,
   279  			&newTrace,
   280  		)
   281  		if err != nil {
   282  			return Expr{}, fmt.Errorf("%w: %v", ErrConstr, err)
   283  		}
   284  
   285  		params := TerminatorParams{
   286  			Scope:          scope,
   287  			SubtypeTrace:   newTrace,
   288  			SupertypeTrace: newTrace,
   289  		}
   290  
   291  		if err := r.checker.Check(resolvedArg, resolvedConstr, params); err != nil {
   292  			return Expr{}, fmt.Errorf(" %w: %v", ErrIncompatArg, err)
   293  		}
   294  	}
   295  
   296  	if def.BodyExpr == nil {
   297  		return Expr{
   298  			Inst: &InstExpr{
   299  				Ref:  expr.Inst.Ref,
   300  				Args: resolvedArgs,
   301  			},
   302  		}, nil
   303  	}
   304  
   305  	return r.resolveExpr(*def.BodyExpr, scopeWhereDefFound, newFrame, &newTrace)
   306  }
   307  
   308  func (Resolver) getDef(
   309  	ref core.EntityRef,
   310  	frame map[string]Def,
   311  	scope Scope,
   312  ) (Def, Scope, error) {
   313  	strRef := ref.String()
   314  	def, exist := frame[strRef]
   315  	if exist {
   316  		return def, scope, nil
   317  	}
   318  
   319  	def, scope, err := scope.GetType(ref)
   320  	if err != nil {
   321  		return Def{}, nil, fmt.Errorf("%w: %v", ErrScope, err)
   322  	}
   323  
   324  	return def, scope, nil
   325  }
   326  
   327  func MustNewResolver(validator exprValidator, checker subtypeChecker, terminator recursionTerminator) Resolver {
   328  	if validator == nil || checker == nil || terminator == nil {
   329  		panic("all arguments must be not nil")
   330  	}
   331  	return Resolver{validator, checker, terminator}
   332  }