github.com/nevalang/neva@v0.23.1-0.20240507185603-7696a9bb8dda/internal/compiler/sourcecode/typesystem/resolver.go (about) 1 package typesystem 2 3 import ( 4 "errors" 5 "fmt" 6 7 "github.com/nevalang/neva/internal/compiler/sourcecode/core" 8 ) 9 10 var ( 11 ErrInvalidExpr = errors.New("expression must be valid in order to be resolved") 12 ErrScope = errors.New("can't get type def from scope by ref") 13 ErrScopeUpdate = errors.New("scope update") 14 ErrInstArgsCount = errors.New("Wrong number of type arguments") 15 ErrIncompatArg = errors.New("argument is not subtype of the parameter's contraint") 16 ErrUnresolvedArg = errors.New("can't resolve argument") 17 ErrConstr = errors.New("can't resolve constraint") 18 ErrArrType = errors.New("could not resolve array type") 19 ErrUnionUnresolvedEl = errors.New("can't resolve union element") 20 ErrRecFieldUnresolved = errors.New("can't resolve struct field") 21 ErrInvalidDef = errors.New("invalid definition") 22 ErrTerminator = errors.New("recursion terminator") 23 ) 24 25 // Resolver transforms expression it into a form where all references it contains points to resolved expressions. 26 type Resolver struct { 27 validator exprValidator // Check if expression invalid before resolving it 28 checker subtypeChecker // Compare arguments with constraints 29 terminator recursionTerminator // Don't stuck in a loop 30 } 31 32 //go:generate mockgen -source $GOFILE -destination mocks_test.go -package ${GOPACKAGE}_test 33 type ( 34 exprValidator interface { 35 Validate(expr Expr) error 36 ValidateDef(def Def) error 37 } 38 subtypeChecker interface { 39 Check(sub Expr, sup Expr, params TerminatorParams) error 40 } 41 recursionTerminator interface { 42 ShouldTerminate(trace Trace, scope Scope) (bool, error) 43 } 44 Scope interface { 45 GetType(ref core.EntityRef) (Def, Scope, error) 46 IsTopType(expr Expr) bool 47 } 48 ) 49 50 // ResolveExpr resolves given expression using only global scope. 51 func (r Resolver) ResolveExpr(expr Expr, scope Scope) (Expr, error) { 52 return r.resolveExpr(expr, scope, map[string]Def{}, nil) 53 } 54 55 // ResolveExprWithFrame works like ResolveExpr but allows to pass local scope. 56 func (r Resolver) ResolveExprWithFrame( 57 expr Expr, 58 frame map[string]Def, 59 scope Scope, 60 ) ( 61 Expr, 62 error, 63 ) { 64 return r.resolveExpr(expr, scope, frame, nil) 65 } 66 67 // ResolveExprWithFrame works like ResolveExprWithFrame but for list of expressions. 68 func (r Resolver) ResolveExprsWithFrame( 69 exprs []Expr, 70 frame map[string]Def, 71 scope Scope, 72 ) ( 73 []Expr, 74 error, 75 ) { 76 resolvedExprs := make([]Expr, 0, len(exprs)) 77 for _, expr := range exprs { 78 resolvedExpr, err := r.resolveExpr(expr, scope, frame, nil) 79 if err != nil { 80 return nil, err 81 } 82 resolvedExprs = append(resolvedExprs, resolvedExpr) 83 } 84 return resolvedExprs, nil 85 } 86 87 // ResolveParams resolves every constraint in given parameter list. 88 func (r Resolver) ResolveParams( 89 params []Param, 90 scope Scope, 91 ) ( 92 []Param, // resolved parameters 93 map[string]Def, // resolved frame `paramName:resolvedConstr` 94 error, 95 ) { 96 result := make([]Param, 0, len(params)) 97 frame := make(map[string]Def, len(params)) 98 for _, param := range params { 99 resolved, err := r.resolveExpr(param.Constr, scope, frame, nil) 100 if err != nil { 101 return nil, nil, fmt.Errorf("resolve expr: %w", err) 102 } 103 frame[param.Name] = Def{BodyExpr: &resolved} 104 result = append(result, Param{ 105 Name: param.Name, 106 Constr: resolved, 107 }) 108 } 109 return result, frame, nil 110 } 111 112 // IsSubtypeOf resolves both `sub` and `sup` expressions 113 // and returns error if `sub` is not subtype of `sup`. 114 func (r Resolver) IsSubtypeOf(sub, sup Expr, scope Scope) error { 115 resolvedSub, err := r.resolveExpr(sub, scope, nil, nil) 116 if err != nil { 117 return fmt.Errorf("resolve sub expr: %w", err) 118 } 119 resolvedSup, err := r.resolveExpr(sup, scope, nil, nil) 120 if err != nil { 121 return fmt.Errorf("resolve sup expr: %w", err) 122 } 123 return r.checker.Check( 124 resolvedSub, 125 resolvedSup, 126 TerminatorParams{Scope: scope}, 127 ) 128 } 129 130 // CheckArgsCompatibility resolves args 131 // and params and then checks their compatibility. 132 func (r Resolver) CheckArgsCompatibility(args []Expr, params []Param, scope Scope) error { 133 if len(args) != len(params) { 134 return fmt.Errorf( 135 "count of arguments mismatch count of parameters, want %d got %d", 136 len(params), 137 len(args), 138 ) 139 } 140 141 for i := range params { 142 arg := args[i] 143 param := params[i] 144 145 resolvedSub, err := r.resolveExpr(arg, scope, nil, nil) 146 if err != nil { 147 return fmt.Errorf("resolve arg expr: %w", err) 148 } 149 150 resolvedSup, err := r.resolveExpr(param.Constr, scope, nil, nil) 151 if err != nil { 152 return fmt.Errorf("resolve param constr expr: %w", err) 153 } 154 155 if err := r.checker.Check( 156 resolvedSub, 157 resolvedSup, 158 TerminatorParams{Scope: scope}, 159 ); err != nil { 160 return err 161 } 162 } 163 164 return nil 165 } 166 167 // resolveExpr turn one expression into another where all references points to native types. 168 // It's a recursive process where each step starts with validation. Invalid expression always leads to error. 169 // For inst expr it checks compatibility between args and params and returns error if some constraint isn't satisfied. 170 // Then it updates scope by adding params of ref type with resolved args as values to allow substitution later. 171 // Then it checks whether base type of current ref type is native type to terminate with nil err and resolved expr. 172 // For non-native types process starts from the beginning with updated scope. New scope will contain values for params. 173 // For lit exprs logic is the this: for enum do nothing (it's valid and not composite, there's nothing to resolveExpr), 174 // for array resolveExpr it's type, for struct and union apply recursion for it's every field/element. 175 func (r Resolver) resolveExpr( //nolint:funlen,gocognit 176 expr Expr, // expression to be resolved 177 scope Scope, // global scope 178 frame map[string]Def, // local scope 179 trace *Trace, // how did we get here 180 ) (Expr, error) { 181 if err := r.validator.Validate(expr); err != nil { 182 return Expr{}, fmt.Errorf("%w: %v", ErrInvalidExpr, err) 183 } 184 185 if expr.Lit != nil { 186 switch expr.Lit.Type() { 187 case EnumLitType: 188 return expr, nil 189 case UnionLitType: 190 resolvedUnion := make([]Expr, 0, len(expr.Lit.Union)) 191 for _, unionEl := range expr.Lit.Union { 192 resolvedEl, err := r.resolveExpr(unionEl, scope, frame, trace) 193 if err != nil { 194 return Expr{}, fmt.Errorf("%w: %v", ErrUnionUnresolvedEl, err) 195 } 196 resolvedUnion = append(resolvedUnion, resolvedEl) 197 } 198 return Expr{ 199 Lit: &LitExpr{Union: resolvedUnion}, 200 }, nil 201 case StructLitType: 202 resolvedStruct := make(map[string]Expr, len(expr.Lit.Struct)) 203 for field, fieldExpr := range expr.Lit.Struct { 204 // we create new trace with virtual ref "struct" (it's safe because it's reserved word) 205 // otherwise expressions like `error struct {child maybe<error>}` will be direct recursive for terminator 206 newTrace := Trace{ 207 prev: trace, 208 cur: core.EntityRef{Name: "struct"}, 209 } 210 resolvedFieldExpr, err := r.resolveExpr( 211 fieldExpr, 212 scope, 213 frame, 214 &newTrace, 215 ) 216 if err != nil { 217 return Expr{}, fmt.Errorf( 218 "%w: %v: %v", 219 ErrRecFieldUnresolved, 220 field, 221 err, 222 ) 223 } 224 resolvedStruct[field] = resolvedFieldExpr 225 } 226 return Expr{ 227 Lit: &LitExpr{Struct: resolvedStruct}, 228 }, nil 229 } 230 } 231 232 def, scopeWhereDefFound, err := r.getDef(expr.Inst.Ref, frame, scope) 233 if err != nil { 234 return Expr{}, err 235 } 236 237 if err := r.validator.ValidateDef(def); err != nil { 238 return Expr{}, errors.Join(ErrInvalidDef, err) 239 } 240 241 if len(def.Params) != len(expr.Inst.Args) { // args must not be > than params to avoid bad case with constraint 242 return Expr{}, fmt.Errorf( 243 "%w for '%v': want %d, got %d", 244 ErrInstArgsCount, 245 expr.Inst.Ref, 246 len(def.Params), 247 len(expr.Inst.Args), 248 ) 249 } 250 251 newTrace := Trace{ 252 prev: trace, 253 cur: expr.Inst.Ref, // FIXME t1 254 } 255 256 shouldReturn, err := r.terminator.ShouldTerminate(newTrace, scope) 257 if err != nil { 258 return Expr{}, fmt.Errorf("%w: %v", ErrTerminator, err) 259 } else if shouldReturn { 260 return expr, nil 261 } 262 263 newFrame := make(map[string]Def, len(def.Params)) 264 resolvedArgs := make([]Expr, 0, len(expr.Inst.Args)) 265 for i, param := range def.Params { // resolve args and constrs and check their compatibility 266 resolvedArg, err := r.resolveExpr(expr.Inst.Args[i], scope, frame, &newTrace) 267 if err != nil { 268 return Expr{}, fmt.Errorf("%w: %v", ErrUnresolvedArg, err) 269 } 270 271 newFrame[param.Name] = Def{BodyExpr: &resolvedArg} // no params for generics 272 resolvedArgs = append(resolvedArgs, resolvedArg) 273 274 // we pass newFrame because constr can refer to type parameters 275 resolvedConstr, err := r.resolveExpr( 276 param.Constr, 277 scope, 278 newFrame, 279 &newTrace, 280 ) 281 if err != nil { 282 return Expr{}, fmt.Errorf("%w: %v", ErrConstr, err) 283 } 284 285 params := TerminatorParams{ 286 Scope: scope, 287 SubtypeTrace: newTrace, 288 SupertypeTrace: newTrace, 289 } 290 291 if err := r.checker.Check(resolvedArg, resolvedConstr, params); err != nil { 292 return Expr{}, fmt.Errorf(" %w: %v", ErrIncompatArg, err) 293 } 294 } 295 296 if def.BodyExpr == nil { 297 return Expr{ 298 Inst: &InstExpr{ 299 Ref: expr.Inst.Ref, 300 Args: resolvedArgs, 301 }, 302 }, nil 303 } 304 305 return r.resolveExpr(*def.BodyExpr, scopeWhereDefFound, newFrame, &newTrace) 306 } 307 308 func (Resolver) getDef( 309 ref core.EntityRef, 310 frame map[string]Def, 311 scope Scope, 312 ) (Def, Scope, error) { 313 strRef := ref.String() 314 def, exist := frame[strRef] 315 if exist { 316 return def, scope, nil 317 } 318 319 def, scope, err := scope.GetType(ref) 320 if err != nil { 321 return Def{}, nil, fmt.Errorf("%w: %v", ErrScope, err) 322 } 323 324 return def, scope, nil 325 } 326 327 func MustNewResolver(validator exprValidator, checker subtypeChecker, terminator recursionTerminator) Resolver { 328 if validator == nil || checker == nil || terminator == nil { 329 panic("all arguments must be not nil") 330 } 331 return Resolver{validator, checker, terminator} 332 }