github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/inter/imocks/util.go (about)

     1  // Copyright (c) 2021-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package imock
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"strings"
    12  
    13  	"github.com/brutella/hc/util"
    14  	"github.com/choria-io/go-choria/config"
    15  	"github.com/choria-io/go-choria/inter"
    16  	"github.com/choria-io/go-choria/protocol"
    17  	"github.com/golang/mock/gomock"
    18  	"github.com/sirupsen/logrus"
    19  )
    20  
    21  type fwMockOpts struct {
    22  	callerID    string
    23  	logDiscard  bool
    24  	cfg         *config.Config
    25  	ddlResolver inter.DDLResolver
    26  	ddls        [][3]string
    27  	reqProto    protocol.ProtocolVersion
    28  }
    29  type fwMockOption func(*fwMockOpts)
    30  
    31  func WithRequestProtocol(p protocol.ProtocolVersion) fwMockOption {
    32  	return func(o *fwMockOpts) {
    33  		o.reqProto = p
    34  	}
    35  }
    36  
    37  func WithCallerID(c ...string) fwMockOption {
    38  	return func(o *fwMockOpts) {
    39  		if len(c) == 0 {
    40  			o.callerID = "choria=rip.mcollective"
    41  		} else {
    42  			o.callerID = c[0]
    43  		}
    44  	}
    45  }
    46  
    47  func LogDiscard() fwMockOption {
    48  	return func(o *fwMockOpts) {
    49  		o.logDiscard = true
    50  	}
    51  }
    52  
    53  func WithConfig(c *config.Config) fwMockOption {
    54  	return func(o *fwMockOpts) { o.cfg = c }
    55  }
    56  
    57  func WithConfigFile(f string) fwMockOption {
    58  	return func(o *fwMockOpts) {
    59  		cfg, err := config.NewConfig(f)
    60  		if err != nil {
    61  			panic(err)
    62  		}
    63  		o.cfg = cfg
    64  	}
    65  }
    66  
    67  func WithDDLResolver(r inter.DDLResolver) fwMockOption {
    68  	return func(o *fwMockOpts) {
    69  		o.ddlResolver = r
    70  	}
    71  }
    72  
    73  func WithDDLFiles(kind string, plugin string, path string) fwMockOption {
    74  	return func(o *fwMockOpts) {
    75  		o.ddls = append(o.ddls, [3]string{kind, plugin, path})
    76  	}
    77  }
    78  
    79  func NewFrameworkForTests(ctrl *gomock.Controller, logWriter io.Writer, opts ...fwMockOption) (*MockFramework, *config.Config) {
    80  	mopts := &fwMockOpts{
    81  		cfg:      config.NewConfigForTests(),
    82  		reqProto: protocol.RequestV1,
    83  	}
    84  	for _, o := range opts {
    85  		o(mopts)
    86  	}
    87  
    88  	logger := logrus.New()
    89  	if mopts.logDiscard {
    90  		logger.SetOutput(io.Discard)
    91  	} else {
    92  		logger.SetOutput(logWriter)
    93  	}
    94  
    95  	fw := NewMockFramework(ctrl)
    96  	fw.EXPECT().Configuration().Return(mopts.cfg).AnyTimes()
    97  	fw.EXPECT().Logger(gomock.AssignableToTypeOf("")).Return(logrus.NewEntry(logger)).AnyTimes()
    98  	fw.EXPECT().NewRequestID().Return(util.RandomHexString(), nil).AnyTimes()
    99  	fw.EXPECT().HasCollective(gomock.AssignableToTypeOf("")).DoAndReturn(func(c string) bool {
   100  		for _, collective := range fw.Configuration().Collectives {
   101  			if c == collective {
   102  				return true
   103  			}
   104  		}
   105  		return false
   106  	}).AnyTimes()
   107  
   108  	if mopts.callerID != "" {
   109  		fw.EXPECT().CallerID().Return(mopts.callerID).AnyTimes()
   110  		fw.EXPECT().Certname().DoAndReturn(func() string {
   111  			if fw.Configuration().OverrideCertname != "" {
   112  				return fw.Configuration().OverrideCertname
   113  			}
   114  
   115  			parts := strings.SplitN(mopts.callerID, "=", 2)
   116  			return parts[1]
   117  		}).AnyTimes()
   118  	}
   119  
   120  	if mopts.ddlResolver == nil {
   121  		resolver := NewMockDDLResolver(ctrl)
   122  		mopts.ddlResolver = resolver
   123  		for _, ddl := range mopts.ddls {
   124  			f, err := os.ReadFile(ddl[2])
   125  			if err != nil {
   126  				panic(fmt.Sprintf("ddl file %s: %s", ddl[2], err))
   127  			}
   128  			resolver.EXPECT().DDLBytes(gomock.Any(), gomock.Eq("agent"), gomock.Eq("package"), gomock.Any()).Return(f, nil).AnyTimes()
   129  		}
   130  	}
   131  
   132  	fw.EXPECT().DDLResolvers().Return([]inter.DDLResolver{mopts.ddlResolver}, nil).AnyTimes()
   133  	fw.EXPECT().RequestProtocol().Return(mopts.reqProto).AnyTimes()
   134  
   135  	return fw, fw.Configuration()
   136  }