github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/ctx_testkit.go (about)

     1  package kit
     2  
     3  import (
     4  	"mime/multipart"
     5  	"sync"
     6  
     7  	"github.com/clubpay/ronykit/kit/utils"
     8  )
     9  
    10  // TestContext is useful for writing end-to-end tests for your Contract handlers.
    11  type TestContext struct {
    12  	ls         localStore
    13  	handlers   HandlerFuncChain
    14  	inMsg      Message
    15  	inHdr      EnvelopeHdr
    16  	clientIP   string
    17  	expectFunc func(...*Envelope) error
    18  }
    19  
    20  func NewTestContext() *TestContext {
    21  	return &TestContext{
    22  		ls: localStore{
    23  			kv: map[string]any{},
    24  		},
    25  	}
    26  }
    27  
    28  func (testCtx *TestContext) SetHandler(h ...HandlerFunc) *TestContext {
    29  	testCtx.handlers = h
    30  
    31  	return testCtx
    32  }
    33  
    34  func (testCtx *TestContext) SetClientIP(ip string) *TestContext {
    35  	testCtx.clientIP = ip
    36  
    37  	return testCtx
    38  }
    39  
    40  func (testCtx *TestContext) Input(m Message, hdr EnvelopeHdr) *TestContext {
    41  	testCtx.inMsg = m
    42  	testCtx.inHdr = hdr
    43  
    44  	return testCtx
    45  }
    46  
    47  func (testCtx *TestContext) Receiver(f func(out ...*Envelope) error) *TestContext {
    48  	testCtx.expectFunc = f
    49  
    50  	return testCtx
    51  }
    52  
    53  func (testCtx *TestContext) Run(stream bool) error {
    54  	ctx := newContext(&testCtx.ls)
    55  	conn := newTestConn()
    56  	conn.clientIP = testCtx.clientIP
    57  	conn.stream = stream
    58  	ctx.conn = conn
    59  	ctx.in = newEnvelope(ctx, conn, false)
    60  	ctx.in.
    61  		SetMsg(testCtx.inMsg).
    62  		SetHdrMap(testCtx.inHdr)
    63  	ctx.handlers = append(ctx.handlers, testCtx.handlers...)
    64  	ctx.Next()
    65  
    66  	return testCtx.expectFunc(conn.out...)
    67  }
    68  
    69  func (testCtx *TestContext) RunREST() error {
    70  	ctx := newContext(&testCtx.ls)
    71  	conn := newTestRESTConn()
    72  	conn.clientIP = testCtx.clientIP
    73  	conn.stream = false
    74  	ctx.conn = conn
    75  	ctx.in = newEnvelope(ctx, conn, false)
    76  	ctx.in.
    77  		SetMsg(testCtx.inMsg).
    78  		SetHdrMap(testCtx.inHdr)
    79  	ctx.handlers = append(ctx.handlers, testCtx.handlers...)
    80  	ctx.Next()
    81  
    82  	return testCtx.expectFunc(conn.out...)
    83  }
    84  
    85  type testConn struct {
    86  	sync.Mutex
    87  
    88  	id       uint64
    89  	clientIP string
    90  	stream   bool
    91  	kv       map[string]string
    92  	out      []*Envelope
    93  }
    94  
    95  var _ Conn = (*testConn)(nil)
    96  
    97  func newTestConn() *testConn {
    98  	return &testConn{
    99  		id: utils.RandomUint64(0),
   100  	}
   101  }
   102  
   103  func (t *testConn) ConnID() uint64 {
   104  	return t.id
   105  }
   106  
   107  func (t *testConn) ClientIP() string {
   108  	return t.clientIP
   109  }
   110  
   111  func (t *testConn) Write(_ []byte) (int, error) {
   112  	return 0, nil
   113  }
   114  
   115  func (t *testConn) WriteEnvelope(e *Envelope) error {
   116  	e.dontReuse()
   117  	t.Lock()
   118  	t.out = append(t.out, e)
   119  	t.Unlock()
   120  
   121  	return nil
   122  }
   123  
   124  func (t *testConn) Stream() bool {
   125  	return t.stream
   126  }
   127  
   128  func (t *testConn) Walk(f func(key string, val string) bool) {
   129  	t.Lock()
   130  	defer t.Unlock()
   131  
   132  	for k, v := range t.kv {
   133  		if !f(k, v) {
   134  			return
   135  		}
   136  	}
   137  }
   138  
   139  func (t *testConn) Get(key string) string {
   140  	t.Lock()
   141  	defer t.Unlock()
   142  
   143  	return t.kv[key]
   144  }
   145  
   146  func (t *testConn) Set(key string, val string) {
   147  	t.Lock()
   148  	t.kv[key] = val
   149  	t.Unlock()
   150  }
   151  
   152  func (t *testConn) Keys() []string {
   153  	keys := make([]string, 0, len(t.kv))
   154  	for k := range t.kv {
   155  		keys = append(keys, k)
   156  	}
   157  
   158  	return keys
   159  }
   160  
   161  type testRESTConn struct {
   162  	testConn
   163  
   164  	method     string
   165  	path       string
   166  	host       string
   167  	requestURI string
   168  	statusCode int
   169  }
   170  
   171  var _ RESTConn = (*testRESTConn)(nil)
   172  
   173  func newTestRESTConn() *testRESTConn {
   174  	return &testRESTConn{
   175  		testConn: testConn{
   176  			id: utils.RandomUint64(0),
   177  		},
   178  	}
   179  }
   180  
   181  func (t *testRESTConn) WalkQueryParams(f func(key string, val string) bool) {}
   182  
   183  func (t *testRESTConn) GetHost() string {
   184  	return t.host
   185  }
   186  
   187  func (t *testRESTConn) GetRequestURI() string {
   188  	return t.requestURI
   189  }
   190  
   191  func (t *testRESTConn) GetMethod() string {
   192  	return t.method
   193  }
   194  
   195  func (t *testRESTConn) GetPath() string {
   196  	return t.path
   197  }
   198  
   199  func (t *testRESTConn) Form() (*multipart.Form, error) {
   200  	panic("not implemented")
   201  }
   202  
   203  func (t *testRESTConn) SetStatusCode(code int) {
   204  	t.statusCode = code
   205  }
   206  
   207  func (t *testRESTConn) Redirect(_ int, _ string) {}