github.com/tencent/goom@v1.0.1/mocker.go (about) 1 // Package mocker 定义了 mock 的外层用户使用 API 定义, 2 // 包括函数、方法、接口、未导出函数(或方法的)的 Mocker 的实现。 3 // 当前文件定义了函数、方法、未导出函数(或方法)的 Mocker 的行为。 4 package mocker 5 6 import ( 7 "fmt" 8 "reflect" 9 "runtime" 10 "strings" 11 12 "github.com/tencent/goom/erro" 13 "github.com/tencent/goom/internal/iface" 14 "github.com/tencent/goom/internal/logger" 15 "github.com/tencent/goom/internal/patch" 16 "github.com/tencent/goom/internal/proxy" 17 "github.com/tencent/goom/internal/unexports" 18 ) 19 20 // Mocker mock 接口, 所有类型(函数、方法、未导出函数、接口等)的 Mocker 的抽象 21 type Mocker interface { 22 // Apply 代理方法实现 23 // 注意: Apply 会覆盖之前设定的 When 条件和 Return 24 // 注意: 不支持在多个协程中并发地 Apply 不同的 imp 函数 25 Apply(callback interface{}) 26 // Cancel 取消代理 27 Cancel() 28 // Canceled 是否已经被取消 29 Canceled() bool 30 // String mock 的名称或描述, 方便调试和问题排查 31 String() string 32 } 33 34 // ExportedMocker 导出函数 mock 接口 35 type ExportedMocker interface { 36 Mocker 37 // When 指定条件匹配 38 When(specArg ...interface{}) *When 39 // Return 执行返回值 40 Return(value ...interface{}) *When 41 // Returns 依次按顺序返回值, 如果是多参可使用[]interface{} 42 Returns(values ...interface{}) *When 43 // Origin 指定 Mock 之后的原函数, origin 签名和 mock 的函数一致 44 Origin(originFunc interface{}) ExportedMocker 45 } 46 47 // UnExportedMocker 未导出函数 mock 接口 48 type UnExportedMocker interface { 49 Mocker 50 // As 将未导出函数(或方法)转换为导出函数(或方法) 51 // As 调用之后,请使用 Return 或 When API 的方式来指定 mock 返回。 52 As(aFunc interface{}) ExportedMocker 53 // Origin 指定 Mock 之后的原函数, origin 签名和 mock 的函数一致 54 Origin(originFunc interface{}) UnExportedMocker 55 } 56 57 // baseMocker mocker 基础类型 58 type baseMocker struct { 59 pkgName string 60 origin interface{} 61 guard MockGuard 62 funcDef interface{} 63 imp interface{} 64 65 when *When 66 // canceled 是否被取消 67 canceled bool 68 } 69 70 // newBaseMocker 新增基础类型 mocker 71 func newBaseMocker(pkgName string) *baseMocker { 72 return &baseMocker{ 73 pkgName: pkgName, 74 } 75 } 76 77 // applyByName 根据函数名称应用 mock 78 func (m *baseMocker) applyByName(funcName string, callback interface{}) { 79 guard, err := proxy.FuncName(funcName, callback, m.origin) 80 if err != nil { 81 panic(fmt.Sprintf("proxy func name error: %v", err)) 82 } 83 84 m.guard = newPatchMockGuard(guard) 85 m.guard.Apply() 86 m.imp = callback 87 } 88 89 // applyByFunc 根据函数应用 mock 90 func (m *baseMocker) applyByFunc(funcDef interface{}, callback interface{}) { 91 guard, err := proxy.Func(funcDef, callback, m.origin) 92 if err != nil { 93 panic(fmt.Sprintf("proxy func definition error: %v", err)) 94 } 95 96 m.guard = newPatchMockGuard(guard) 97 m.guard.Apply() 98 m.imp = callback 99 m.funcDef = funcDef 100 } 101 102 // applyByMethod 根据函数名应用 mock 103 func (m *baseMocker) applyByMethod(structDef interface{}, method string, callback interface{}) { 104 guard, err := proxy.Method(reflect.TypeOf(structDef), method, callback, m.origin) 105 if err != nil { 106 panic(fmt.Sprintf("proxy method error: %v", err)) 107 } 108 109 m.guard = newPatchMockGuard(guard) 110 m.guard.Apply() 111 m.imp = callback 112 m.funcDef = reflect.ValueOf(structDef).MethodByName(method).Interface() 113 } 114 115 // applyByIFaceMethod 根据接口方法应用 mock 116 func (m *baseMocker) applyByIFaceMethod(ctx *iface.IContext, iFace interface{}, method string, callback interface{}, 117 implV iface.PFunc) { 118 119 impV := reflect.TypeOf(callback) 120 if impV.In(0) != reflect.TypeOf(&IContext{}) { 121 panic(erro.NewIllegalParamTypeError("<first arg>", impV.In(0).Name(), "*IContext")) 122 } 123 124 err := proxy.Interface(iFace, ctx, method, callback, implV) 125 if err != nil { 126 panic(erro.NewTraceableErrorf("interface mock apply error", err)) 127 } 128 129 m.guard = newIFaceMockGuard(ctx) 130 m.guard.Apply() 131 m.imp = callback 132 } 133 134 // whens 指定的返回值 135 func (m *baseMocker) whens(when *When) error { 136 m.imp = reflect.MakeFunc(when.funcTyp, m.callback).Interface() 137 m.when = when 138 return nil 139 } 140 141 // callback 通用的 MakeFunc callback 142 func (m *baseMocker) callback(args []reflect.Value) (results []reflect.Value) { 143 if m.canceled && m.funcDef != nil { 144 return reflect.ValueOf(m.funcDef).Call(args) 145 } 146 if m.when != nil { 147 results = m.when.invoke(args) 148 if results != nil { 149 return results 150 } 151 } 152 panic("there is no suitable condition matched, or set default return with: mocker.Return(...)") 153 } 154 155 // Cancel 取消 Mock 156 func (m *baseMocker) Cancel() { 157 if m.guard != nil { 158 m.guard.Cancel() 159 } 160 m.when = nil 161 m.origin = nil 162 m.canceled = true 163 } 164 165 // Canceled 是否被取消 166 func (m *baseMocker) Canceled() bool { 167 return m.canceled 168 } 169 170 // MethodMocker 对结构体函数或方法进行 mock 171 // 能支持到私有函数、私有类型的方法的 Mock 172 type MethodMocker struct { 173 *baseMocker 174 structDef interface{} 175 method string 176 methodIns interface{} 177 } 178 179 // NewMethodMocker 创建 MethodMocker 180 // pkgName 包路径 181 // structDef 结构体变量定义, 不能为 nil 182 func NewMethodMocker(pkgName string, structDef interface{}) *MethodMocker { 183 return &MethodMocker{ 184 baseMocker: newBaseMocker(pkgName), 185 structDef: structDef, 186 } 187 } 188 189 // String mock 的名称或描述, 方便调试和问题排查 190 func (m *MethodMocker) String() string { 191 t := reflect.TypeOf(m.structDef) 192 if t.Kind() == reflect.Ptr { 193 t = t.Elem() 194 } 195 return fmt.Sprintf("%s.(%s).%s", t.PkgPath(), t.Name(), m.method) 196 } 197 198 // Method 设置结构体的方法名 199 func (m *MethodMocker) Method(name string) ExportedMocker { 200 if name == "" { 201 panic("method is empty") 202 } 203 m.method = name 204 205 sTyp := reflect.TypeOf(m.structDef) 206 method, ok := sTyp.MethodByName(m.method) 207 if !ok { 208 panic("method " + m.method + " not found on " + sTyp.String()) 209 } 210 m.methodIns = method.Func.Interface() 211 return m 212 } 213 214 // ExportMethod 导出私有方法 215 func (m *MethodMocker) ExportMethod(name string) UnExportedMocker { 216 if name == "" { 217 panic("method is empty") 218 } 219 220 // 转换结构体名 221 structName := typeName(m.structDef) 222 if strings.Contains(structName, "*") { 223 structName = fmt.Sprintf("(%s)", structName) 224 } 225 226 packageName := packageName(m.structDef) 227 m.baseMocker.pkgName = packageName 228 229 return (&UnexportedMethodMocker{ 230 baseMocker: m.baseMocker, 231 structName: structName, 232 methodName: name, 233 }).Method(name) 234 } 235 236 // Apply 指定 mock 执行的回调函数 237 // mock 回调函数, 需要和 mock 模板函数的签名保持一致 238 // 方法的参数签名写法比如: func(s *Struct, arg1, arg2 type), 其中第一个参数必须是接收体类型 239 func (m *MethodMocker) Apply(callback interface{}) { 240 m.doApply(callback) 241 } 242 243 func (m *MethodMocker) doApply(imp interface{}) { 244 if m.method == "" { 245 panic("method is empty") 246 } 247 imp, _ = interceptDebugInfo(imp, nil, m) 248 m.applyByMethod(m.structDef, m.method, imp) 249 logger.Consolefc(logger.DebugLevel, "mocker [%s] apply.", logger.Caller(6), m.String()) 250 } 251 252 // When 指定条件匹配 253 func (m *MethodMocker) When(specArg ...interface{}) *When { 254 if m.method == "" { 255 panic("method is empty") 256 } 257 if m.when != nil { 258 return m.when.When(specArg...) 259 } 260 261 sTyp := reflect.TypeOf(m.structDef) 262 methodIns, ok := sTyp.MethodByName(m.method) 263 if !ok { 264 panic("method " + m.method + " not found on " + sTyp.String()) 265 } 266 267 var ( 268 when *When 269 err error 270 ) 271 if when, err = CreateWhen(m, methodIns.Func.Interface(), specArg, nil, true); err != nil { 272 panic(err) 273 } 274 if err := m.whens(when); err != nil { 275 panic(err) 276 } 277 278 m.doApply(m.imp) 279 return when 280 } 281 282 // Return 指定返回值 283 func (m *MethodMocker) Return(value ...interface{}) *When { 284 if m.method == "" { 285 panic("method is empty") 286 } 287 if m.when != nil { 288 return m.when.Return(value...) 289 } 290 291 var ( 292 when *When 293 err error 294 ) 295 if when, err = CreateWhen(m, m.methodIns, nil, value, true); err != nil { 296 panic(err) 297 } 298 if err := m.whens(when); err != nil { 299 panic(err) 300 } 301 m.doApply(m.imp) 302 return when 303 } 304 305 // Returns 依次按顺序返回值 306 func (m *MethodMocker) Returns(values ...interface{}) *When { 307 if m.method == "" { 308 panic("method is empty") 309 } 310 if m.when != nil { 311 return m.when.Returns(values...) 312 } 313 314 var ( 315 when *When 316 err error 317 ) 318 if when, err = CreateWhen(m, m.methodIns, nil, nil, true); err != nil { 319 panic(err) 320 } 321 if err := m.whens(when); err != nil { 322 panic(err) 323 } 324 m.when.Returns(values...) 325 m.doApply(m.imp) 326 return when 327 } 328 329 // Origin 指定调用的原函数 330 func (m *MethodMocker) Origin(originFunc interface{}) ExportedMocker { 331 m.origin = originFunc 332 return m 333 } 334 335 // UnexportedMethodMocker 对结构体函数或方法进行 mock 336 // 能支持到未导出类型、未导出类型的方法的 Mock 337 type UnexportedMethodMocker struct { 338 *baseMocker 339 structName string 340 methodName string 341 } 342 343 // NewUnexportedMethodMocker 创建未导出方法 Mocker 344 // pkgName 包路径 345 // structName 结构体名称 346 func NewUnexportedMethodMocker(pkgName string, structName string) *UnexportedMethodMocker { 347 return &UnexportedMethodMocker{ 348 baseMocker: newBaseMocker(pkgName), 349 structName: structName, 350 } 351 } 352 353 // String mock 的名称或描述, 方便调试和问题排查 354 func (m *UnexportedMethodMocker) String() string { 355 return fmt.Sprintf("%s.%s.%s", m.pkgName, m.structName, m.methodName) 356 } 357 358 // objName 获取对象名 359 func (m *UnexportedMethodMocker) objName() string { 360 return fmt.Sprintf("%s.%s.%s", m.pkgName, m.structName, m.methodName) 361 } 362 363 // Method 设置结构体的方法名 364 func (m *UnexportedMethodMocker) Method(name string) UnExportedMocker { 365 m.methodName = name 366 return m 367 } 368 369 // Apply 指定 mock 执行的回调函数 370 // mock 回调函数, 需要和 mock 模板函数的签名保持一致 371 // 方法的参数签名写法比如: func(s *Struct, arg1, arg2 type), 其中第一个参数必须是接收体类型 372 func (m *UnexportedMethodMocker) Apply(callback interface{}) { 373 name := m.objName() 374 if name == "" { 375 panic("method name is empty") 376 } 377 378 if !strings.Contains(name, "*") { 379 _, _ = unexports.FindFuncByName(name) 380 } 381 382 callback, _ = interceptDebugInfo(callback, nil, m) 383 m.applyByName(name, callback) 384 logger.Consolefc(logger.DebugLevel, "mocker [%s] apply.", logger.Caller(5), m.String()) 385 } 386 387 // Origin 调用原函数 388 func (m *UnexportedMethodMocker) Origin(originFunc interface{}) UnExportedMocker { 389 m.origin = originFunc 390 return m 391 } 392 393 // As 将未导出函数(或方法)转换为导出函数(或方法) 394 func (m *UnexportedMethodMocker) As(aFunc interface{}) ExportedMocker { 395 name := m.objName() 396 if name == "" { 397 panic("method name is empty") 398 } 399 400 var ( 401 err error 402 originFuncPtr uintptr 403 ) 404 originFuncPtr, err = unexports.FindFuncByName(name) 405 if err != nil { 406 panic(err) 407 } 408 newFunc := unexports.NewFuncWithCodePtr(reflect.TypeOf(aFunc), originFuncPtr) 409 return &DefMocker{ 410 baseMocker: m.baseMocker, 411 funcDef: newFunc.Interface(), 412 } 413 } 414 415 // UnexportedFuncMocker 对函数或方法进行 mock 416 // 能支持到私有函数、私有类型的方法的 Mock 417 type UnexportedFuncMocker struct { 418 *baseMocker 419 funcName string 420 } 421 422 // NewUnexportedFuncMocker 创建未导出函数 Mocker 423 // pkgName 包路径 424 // funcName 函数名称 425 func NewUnexportedFuncMocker(pkgName, funcName string) *UnexportedFuncMocker { 426 return &UnexportedFuncMocker{ 427 baseMocker: newBaseMocker(pkgName), 428 funcName: funcName, 429 } 430 } 431 432 // String mock 的名称或描述, 方便调试和问题排查 433 func (m *UnexportedFuncMocker) String() string { 434 return fmt.Sprintf("%s.%s", m.pkgName, m.funcName) 435 } 436 437 // objName 获取对象名 438 func (m *UnexportedFuncMocker) objName() string { 439 return fmt.Sprintf("%s.%s", m.pkgName, m.funcName) 440 } 441 442 // Apply 指定 mock 执行的回调函数 443 // mock 回调函数, 需要和 mock 模板函数的签名保持一致 444 // 方法的参数签名写法比如: func(s *Struct, arg1, arg2 type), 其中第一个参数必须是接收体类型 445 func (m *UnexportedFuncMocker) Apply(callback interface{}) { 446 callback, _ = interceptDebugInfo(callback, nil, m) 447 m.applyByName(m.objName(), callback) 448 logger.Consolefc(logger.DebugLevel, "mocker [%s] apply.", logger.Caller(5), m.String()) 449 } 450 451 // Origin 调用原函数 452 func (m *UnexportedFuncMocker) Origin(originFunc interface{}) UnExportedMocker { 453 m.origin = originFunc 454 return m 455 } 456 457 // As 将未导出函数(或方法)转换为导出函数(或方法) 458 func (m *UnexportedFuncMocker) As(aFunc interface{}) ExportedMocker { 459 originFuncPtr, err := unexports.FindFuncByName(m.objName()) 460 if err != nil { 461 panic(err) 462 } 463 464 newFunc := unexports.NewFuncWithCodePtr(reflect.TypeOf(aFunc), originFuncPtr) 465 return &DefMocker{ 466 baseMocker: m.baseMocker, 467 funcDef: newFunc.Interface(), 468 } 469 } 470 471 // DefMocker 对函数或方法进行 mock,使用函数定义筛选 472 type DefMocker struct { 473 *baseMocker 474 funcDef interface{} 475 } 476 477 // String mock 的名称或描述 478 func (m *DefMocker) String() string { 479 return runtime.FuncForPC(reflect.ValueOf(m.funcDef).Pointer()).Name() 480 } 481 482 // NewDefMocker 创建 DefMocker 483 // pkgName 包路径 484 // funcDef 函数变量定义 485 func NewDefMocker(pkgName string, funcDef interface{}) *DefMocker { 486 return &DefMocker{ 487 baseMocker: newBaseMocker(pkgName), 488 funcDef: funcDef, 489 } 490 } 491 492 // Apply 代理方法实现 493 func (m *DefMocker) Apply(callback interface{}) { 494 m.doApply(callback) 495 } 496 497 func (m *DefMocker) doApply(imp interface{}) { 498 if m.funcDef == nil { 499 panic("funcDef is empty") 500 } 501 502 funcName := functionName(m.funcDef) 503 imp, _ = interceptDebugInfo(imp, nil, m) 504 if patch.IsGenericsFunc(funcName) { 505 // for generic variants func 506 m.applyByFunc(m.funcDef, imp) 507 } else if strings.HasSuffix(funcName, "-fm") { 508 // TODO 理清-fm的用意 509 m.applyByName(strings.TrimSuffix(funcName, "-fm"), imp) 510 } else { 511 m.applyByFunc(m.funcDef, imp) 512 } 513 logger.Consolefc(logger.DebugLevel, "mocker [%s] apply.", logger.Caller(6), m.String()) 514 } 515 516 // When 指定条件匹配 517 func (m *DefMocker) When(specArg ...interface{}) *When { 518 if m.when != nil { 519 return m.when.When(specArg...) 520 } 521 var ( 522 when *When 523 err error 524 ) 525 if when, err = CreateWhen(m, m.funcDef, specArg, nil, false); err != nil { 526 panic(err) 527 } 528 if err := m.whens(when); err != nil { 529 panic(err) 530 } 531 m.doApply(m.imp) 532 return when 533 } 534 535 // Return 代理方法返回 536 func (m *DefMocker) Return(value ...interface{}) *When { 537 if m.when != nil { 538 return m.when.Return(value...) 539 } 540 var ( 541 when *When 542 err error 543 ) 544 if when, err = CreateWhen(m, m.funcDef, nil, value, false); err != nil { 545 panic(err) 546 } 547 if err := m.whens(when); err != nil { 548 panic(err) 549 } 550 m.doApply(m.imp) 551 return when 552 } 553 554 // Returns 依次按顺序返回值, 如果是多参可使用[]interface{} 555 func (m *DefMocker) Returns(values ...interface{}) *When { 556 if m.when != nil { 557 return m.when.Returns(values...) 558 } 559 var ( 560 when *When 561 err error 562 ) 563 if when, err = CreateWhen(m, m.funcDef, nil, nil, false); err != nil { 564 panic(err) 565 } 566 if err := m.whens(when); err != nil { 567 panic(err) 568 } 569 m.when.Returns(values...) 570 m.doApply(m.imp) 571 return when 572 } 573 574 // Origin 调用原函数 575 // origin 需要和原函数的参数列表保持一致 576 func (m *DefMocker) Origin(originFunc interface{}) ExportedMocker { 577 m.origin = originFunc 578 return m 579 }