github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/inter/imocks/util.go (about) 1 // Copyright (c) 2021-2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package imock 6 7 import ( 8 "fmt" 9 "io" 10 "os" 11 "strings" 12 13 "github.com/brutella/hc/util" 14 "github.com/choria-io/go-choria/config" 15 "github.com/choria-io/go-choria/inter" 16 "github.com/choria-io/go-choria/protocol" 17 "github.com/golang/mock/gomock" 18 "github.com/sirupsen/logrus" 19 ) 20 21 type fwMockOpts struct { 22 callerID string 23 logDiscard bool 24 cfg *config.Config 25 ddlResolver inter.DDLResolver 26 ddls [][3]string 27 reqProto protocol.ProtocolVersion 28 } 29 type fwMockOption func(*fwMockOpts) 30 31 func WithRequestProtocol(p protocol.ProtocolVersion) fwMockOption { 32 return func(o *fwMockOpts) { 33 o.reqProto = p 34 } 35 } 36 37 func WithCallerID(c ...string) fwMockOption { 38 return func(o *fwMockOpts) { 39 if len(c) == 0 { 40 o.callerID = "choria=rip.mcollective" 41 } else { 42 o.callerID = c[0] 43 } 44 } 45 } 46 47 func LogDiscard() fwMockOption { 48 return func(o *fwMockOpts) { 49 o.logDiscard = true 50 } 51 } 52 53 func WithConfig(c *config.Config) fwMockOption { 54 return func(o *fwMockOpts) { o.cfg = c } 55 } 56 57 func WithConfigFile(f string) fwMockOption { 58 return func(o *fwMockOpts) { 59 cfg, err := config.NewConfig(f) 60 if err != nil { 61 panic(err) 62 } 63 o.cfg = cfg 64 } 65 } 66 67 func WithDDLResolver(r inter.DDLResolver) fwMockOption { 68 return func(o *fwMockOpts) { 69 o.ddlResolver = r 70 } 71 } 72 73 func WithDDLFiles(kind string, plugin string, path string) fwMockOption { 74 return func(o *fwMockOpts) { 75 o.ddls = append(o.ddls, [3]string{kind, plugin, path}) 76 } 77 } 78 79 func NewFrameworkForTests(ctrl *gomock.Controller, logWriter io.Writer, opts ...fwMockOption) (*MockFramework, *config.Config) { 80 mopts := &fwMockOpts{ 81 cfg: config.NewConfigForTests(), 82 reqProto: protocol.RequestV1, 83 } 84 for _, o := range opts { 85 o(mopts) 86 } 87 88 logger := logrus.New() 89 if mopts.logDiscard { 90 logger.SetOutput(io.Discard) 91 } else { 92 logger.SetOutput(logWriter) 93 } 94 95 fw := NewMockFramework(ctrl) 96 fw.EXPECT().Configuration().Return(mopts.cfg).AnyTimes() 97 fw.EXPECT().Logger(gomock.AssignableToTypeOf("")).Return(logrus.NewEntry(logger)).AnyTimes() 98 fw.EXPECT().NewRequestID().Return(util.RandomHexString(), nil).AnyTimes() 99 fw.EXPECT().HasCollective(gomock.AssignableToTypeOf("")).DoAndReturn(func(c string) bool { 100 for _, collective := range fw.Configuration().Collectives { 101 if c == collective { 102 return true 103 } 104 } 105 return false 106 }).AnyTimes() 107 108 if mopts.callerID != "" { 109 fw.EXPECT().CallerID().Return(mopts.callerID).AnyTimes() 110 fw.EXPECT().Certname().DoAndReturn(func() string { 111 if fw.Configuration().OverrideCertname != "" { 112 return fw.Configuration().OverrideCertname 113 } 114 115 parts := strings.SplitN(mopts.callerID, "=", 2) 116 return parts[1] 117 }).AnyTimes() 118 } 119 120 if mopts.ddlResolver == nil { 121 resolver := NewMockDDLResolver(ctrl) 122 mopts.ddlResolver = resolver 123 for _, ddl := range mopts.ddls { 124 f, err := os.ReadFile(ddl[2]) 125 if err != nil { 126 panic(fmt.Sprintf("ddl file %s: %s", ddl[2], err)) 127 } 128 resolver.EXPECT().DDLBytes(gomock.Any(), gomock.Eq("agent"), gomock.Eq("package"), gomock.Any()).Return(f, nil).AnyTimes() 129 } 130 } 131 132 fw.EXPECT().DDLResolvers().Return([]inter.DDLResolver{mopts.ddlResolver}, nil).AnyTimes() 133 fw.EXPECT().RequestProtocol().Return(mopts.reqProto).AnyTimes() 134 135 return fw, fw.Configuration() 136 }