github.com/xg0n/routine@v0.0.0-20240119033701-c364deb94aee/thread_test.go (about)

     1  package routine
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"runtime"
     7  	"runtime/pprof"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestCurrentThread(t *testing.T) {
    16  	assert.NotNil(t, currentThread(true))
    17  	assert.Same(t, currentThread(true), currentThread(true))
    18  }
    19  
    20  func TestPProf(t *testing.T) {
    21  	const concurrency = 10
    22  	const loopTimes = 10
    23  	tls := NewThreadLocal[any]()
    24  	tls.Set("你好")
    25  	wg := &sync.WaitGroup{}
    26  	wg.Add(concurrency)
    27  	for i := 0; i < concurrency; i++ {
    28  		tmp := i
    29  		go func() {
    30  			for j := 0; j < loopTimes; j++ {
    31  				time.Sleep(100 * time.Millisecond)
    32  				tls.Set(tmp)
    33  				assert.Equal(t, tmp, tls.Get())
    34  				pprof.Do(context.Background(), pprof.Labels("key", "value"), func(ctx context.Context) {
    35  					assert.Nil(t, currentThread(false))
    36  					assert.Nil(t, tls.Get())
    37  					tls.Set("hi")
    38  					//
    39  					label, find := pprof.Label(ctx, "key")
    40  					assert.True(t, find)
    41  					assert.Equal(t, "value", label)
    42  					//
    43  					assert.Equal(t, "hi", tls.Get())
    44  					//
    45  					label2, find2 := pprof.Label(ctx, "key")
    46  					assert.True(t, find2)
    47  					assert.Equal(t, "value", label2)
    48  				})
    49  				assert.Nil(t, tls.Get())
    50  			}
    51  			wg.Done()
    52  		}()
    53  	}
    54  	assert.Nil(t, pprof.StartCPUProfile(&bytes.Buffer{}))
    55  	wg.Wait()
    56  	pprof.StopCPUProfile()
    57  	assert.Equal(t, "你好", tls.Get())
    58  }
    59  
    60  func TestThreadGC(t *testing.T) {
    61  	const allocSize = 10_000_000
    62  	tls := NewThreadLocal[[]byte]()
    63  	tls2 := NewInheritableThreadLocal[[]byte]()
    64  	allocWait := &sync.WaitGroup{}
    65  	allocWait.Add(1)
    66  	gatherWait := &sync.WaitGroup{}
    67  	gatherWait.Add(1)
    68  	gcWait := &sync.WaitGroup{}
    69  	gcWait.Add(1)
    70  	//=========Init
    71  	heapInit, numInit := getMemStats()
    72  	printMemStats("Init", heapInit, numInit)
    73  	//
    74  	task := GoWait(func(token CancelToken) {
    75  		tls.Set(make([]byte, allocSize))
    76  		tls2.Set(make([]byte, allocSize))
    77  		go func() {
    78  			gcWait.Wait()
    79  		}()
    80  		task2 := GoWaitResult(func(token CancelToken) int {
    81  			return 1
    82  		})
    83  		assert.Equal(t, 1, task2.Get())
    84  		allocWait.Done()  //alloc ok, release main thread
    85  		gatherWait.Wait() //wait gather heap info
    86  	})
    87  	//=========Alloc
    88  	allocWait.Wait() //wait alloc done
    89  	heapAlloc, numAlloc := getMemStats()
    90  	printMemStats("Alloc", heapAlloc, numAlloc)
    91  	assert.Greater(t, heapAlloc, heapInit+allocSize*2*0.9)
    92  	assert.Greater(t, numAlloc, numInit)
    93  	//=========GC
    94  	gatherWait.Done() //gather ok, release sub thread
    95  	task.Get()        //wait sub thread finish
    96  	time.Sleep(500 * time.Millisecond)
    97  	heapGC, numGC := getMemStats()
    98  	printMemStats("AfterGC", heapGC, numGC)
    99  	gcWait.Done()
   100  	//=========Summary
   101  	heapRelease := heapAlloc - heapGC
   102  	numRelease := numAlloc - numGC
   103  	printMemStats("Summary", heapRelease, numRelease)
   104  	assert.Greater(t, int(heapRelease), int(allocSize*2*0.9))
   105  	assert.Equal(t, 1, numRelease)
   106  }
   107  
   108  func getMemStats() (uint64, int) {
   109  	stats := runtime.MemStats{}
   110  	runtime.GC()
   111  	runtime.ReadMemStats(&stats)
   112  	return stats.HeapAlloc, runtime.NumGoroutine()
   113  }
   114  
   115  func printMemStats(section string, heapAlloc uint64, numGoroutine int) {
   116  	//fmt.Printf("%v\n", section)
   117  	//fmt.Printf("HeapAlloc    = %v\n", heapAlloc)
   118  	//fmt.Printf("NumGoroutine = %v\n", numGoroutine)
   119  	//fmt.Printf("===\n")
   120  }