github.com/KarpelesLab/contexter@v1.0.2/contexter_test.go (about)

     1  package contexter_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"log"
     8  	"runtime"
     9  	"testing"
    10  
    11  	"github.com/KarpelesLab/contexter"
    12  )
    13  
// getTest is a test helper that receives ctx as an argument and returns
// whatever context contexter recovers from the call stack (via getTest2).
// TestContext expects the returned value to be the very same ctx.
//
// It is marked noinline so it stays a distinct frame on the stack, presumably
// so that contexter can find ctx among this frame's arguments. The s argument
// is not used here; presumably it exists to put an extra non-context argument
// on the frame that the lookup has to cope with (TODO confirm).
//
//go:noinline
func getTest(ctx context.Context, s string) context.Context {
	// NOTE it is important that the value is being used at least once
	// (presumably so ctx is actually materialized in this frame instead of
	// being treated as dead; confirm against contexter's stack scanning).
	ctx.Value(nil)
	return getTest2()
}
    20  
// getTest2 returns the context.Context that contexter.Context locates on the
// current call stack (presumably nil when none is found; TODO confirm).
// It exists so that contexter is not called directly from getTest. Unlike
// getTest it is not marked //go:noinline, so the compiler is free to inline it.
func getTest2() context.Context {
	return contexter.Context()
}
    24  
// getTestAny mirrors getTest, but recovers the context through contexter.Find
// (called in getTestAny2) instead of contexter.Context. TestContext expects
// the returned value to be the very same ctx.
//
// It is marked noinline so it stays a distinct frame on the stack, presumably
// so that contexter can find ctx among this frame's arguments. The s argument
// is not used here; presumably it exists to put an extra non-context argument
// on the frame that the lookup has to cope with (TODO confirm).
//
//go:noinline
func getTestAny(ctx context.Context, s string) context.Context {
	// NOTE it is important that the value is being used at least once
	// (presumably so ctx is actually materialized in this frame instead of
	// being treated as dead; confirm against contexter's stack scanning).
	ctx.Value(nil)
	return getTestAny2()
}
    31  
    32  func getTestAny2() context.Context {
    33  	var ctx context.Context
    34  	if !contexter.Find(&ctx) {
    35  		log.Printf("NOT FOUND")
    36  	}
    37  	return ctx
    38  }
    39  
    40  func TestContext(t *testing.T) {
    41  	ctx := context.Background()
    42  
    43  	log.Printf("ctx = %p", ctx)
    44  
    45  	ctx2 := getTest(ctx, "hello world")
    46  	log.Printf("ctx2 = %p", ctx2)
    47  
    48  	if ctx != ctx2 {
    49  		t.Errorf("invalid value returned: %p", ctx2)
    50  	}
    51  
    52  	ctx3 := getTestAny(ctx, "hello world")
    53  	log.Printf("ctx3 = %p", ctx3)
    54  
    55  	if ctx != ctx3 {
    56  		t.Errorf("invalid value returned in any: %p", ctx3)
    57  	}
    58  }
    59  
// TestObj is a field-less JSON fixture. Its encoded form is produced by its
// MarshalJSON method from a value held in a context recovered off the stack.
type TestObj struct{}
    61  
    62  func (t *TestObj) MarshalJSON() ([]byte, error) {
    63  	ctx := contexter.Context()
    64  	if ctx == nil {
    65  		return nil, errors.New("could not fetch context")
    66  	}
    67  
    68  	res := map[string]interface{}{"foo": ctx.Value("test")}
    69  	return json.Marshal(res)
    70  }
    71  
// encodeJson marshals obj with json.Marshal and returns its result unchanged.
// ctx is never handed to json.Marshal. It is only an argument of this
// (noinline) frame, so that a json.Marshaler running inside json.Marshal,
// such as TestObj.MarshalJSON, can recover it with contexter.
//
//go:noinline
func encodeJson(ctx context.Context, obj interface{}) ([]byte, error) {
	res, err := json.Marshal(obj)
	// ctx is otherwise unused; KeepAlive keeps it reachable until json.Marshal
	// has returned (presumably so it is still observable on this frame while
	// MarshalJSON searches the stack).
	runtime.KeepAlive(ctx)
	return res, err
}
    78  
    79  func TestJson(t *testing.T) {
    80  	ctx := context.WithValue(context.Background(), "test", "bar")
    81  	obj := &TestObj{}
    82  
    83  	val, err := encodeJson(ctx, obj)
    84  
    85  	if err != nil {
    86  		t.Errorf("json test failed: %s", err)
    87  		return
    88  	}
    89  
    90  	if string(val) != `{"foo":"bar"}` {
    91  		t.Errorf("json output failed, should be {\"foo\":\"bar\"} but got %s", val)
    92  	}
    93  }