gitee.com/quant1x/engine@v1.8.4/tracker/backtesting.go (about)

     1  package tracker
     2  
     3  import (
     4  	"fmt"
     5  	"gitee.com/quant1x/engine/cache"
     6  	"gitee.com/quant1x/engine/config"
     7  	"gitee.com/quant1x/engine/factors"
     8  	"gitee.com/quant1x/engine/market"
     9  	"gitee.com/quant1x/engine/models"
    10  	"gitee.com/quant1x/engine/storages"
    11  	"gitee.com/quant1x/exchange"
    12  	"gitee.com/quant1x/gox/api"
    13  	"gitee.com/quant1x/gox/progressbar"
    14  	"gitee.com/quant1x/gox/tags"
    15  	"gitee.com/quant1x/num"
    16  	"gitee.com/quant1x/pandas"
    17  	"gitee.com/quant1x/pkg/tablewriter"
    18  	"os"
    19  	"sort"
    20  )
    21  
// GoodCase is one row of the per-day back-testing summary.
//
// The `dataframe` tags are the column names of the summary table; BackTesting
// reads two of them back by name ("浮动收益率%" and "胜率率%"), so the tag text
// must stay in sync with those lookups.
type GoodCase struct {
	// Date is the trading date the row describes.
	// Num is the number of samples kept for that date (after the top-N cut).
	// Yields is the mean next-day premium rate of those samples, in percent.
	Date   string  `dataframe:"日期"`
	Num    int     `dataframe:"数量"`
	Yields float64 `dataframe:"浮动收益率%"`
	//NextYields float64 `dataframe:"隔日收益率%"`
	// GtP1..GtP5 are the percentages of samples whose next-day premium rate is
	// > 0, >= 1, >= 2, >= 3 and >= 5 respectively.
	GtP1 float64 `dataframe:"胜率率%"`
	GtP2 float64 `dataframe:"溢价超1%"`
	GtP3 float64 `dataframe:"溢价超2%"`
	GtP4 float64 `dataframe:"溢价超3%"`
	GtP5 float64 `dataframe:"溢价超5%"`
}
    34  
// SampleFeature holds the per-security values that BackTesting derives from a
// quote snapshot before turning them into a statistics record.
type SampleFeature struct {
	// SecurityCode and Name identify the security; Name falls back to
	// "unknown" in BackTesting when no F10 record is found.
	SecurityCode      string
	Name              string
	// OpenChangeRate is the net change rate from LastClose to Open.
	OpenChangeRate    float64
	// OpenTurnZ is copied from the snapshot's OpenTurnZ field.
	OpenTurnZ         float64
	// LastClose, Open and Price are copied from the snapshot.
	LastClose         float64
	Open              float64
	Price             float64
	// UpRate is the net change rate from LastClose to Price.
	UpRate            float64
	// OpenPremiumRate and NextPremiumRate are the same-day and next-day
	// premium rates; their reference price depends on the order flag of the
	// strategy (see BackTesting).
	OpenPremiumRate   float64
	NextPremiumRate   float64
	OpenQuantityRatio float64 // quantity (volume) ratio at the open
	// Beta and Alpha are copied from the snapshot, where BackTesting stores
	// them already multiplied by 100.
	Beta              float64
	Alpha             float64
}
    51  
    52  // BackTesting 回测
    53  func BackTesting(strategyNo uint64, countDays, countTopN int) {
    54  	currentlyDay := exchange.GetCurrentlyDay()
    55  	dates := exchange.TradingDateRange(exchange.MARKET_CH_FIRST_LISTTIME, currentlyDay)
    56  	scope := api.RangeFinite(-countDays)
    57  	s, e, err := scope.Limits(len(dates))
    58  	if err != nil {
    59  		fmt.Println(err)
    60  		return
    61  	}
    62  	model, err := models.CheckoutStrategy(strategyNo)
    63  	if err != nil {
    64  		fmt.Println(err)
    65  		return
    66  	}
    67  	//TODO: 这里应该要取策略的规则参数
    68  	tradeRule := config.GetStrategyParameterByCode(strategyNo)
    69  	if tradeRule == nil {
    70  		return
    71  	}
    72  	backTestingParameter := config.GetDataConfig().BackTesting
    73  	allResult := []models.Statistics{}
    74  	gcs := []GoodCase{}
    75  	dates = dates[s : e+1]
    76  	codes := market.GetCodeList()
    77  	mapStock := map[string][]factors.SecurityFeature{}
    78  	for i, date := range dates {
    79  		// 切换策略数据的缓存日期
    80  		factors.SwitchDate(date)
    81  		var marketPrices []float64
    82  		stockSnapshots := []factors.QuoteSnapshot{}
    83  		total := len(codes)
    84  		pos := 0
    85  		bar := progressbar.NewBar(1, "执行["+date+"涨幅扫描]", total)
    86  		for _, securityCode := range codes {
    87  			bar.Add(1)
    88  			if !exchange.AssertStockBySecurityCode(securityCode) && securityCode != backTestingParameter.TargetIndex {
    89  				continue
    90  			}
    91  			//features := factors.CheckoutWideTableByDate(securityCode, date)
    92  			features, ok := mapStock[securityCode]
    93  			if !ok {
    94  				filename := cache.WideFilename(securityCode)
    95  				err := api.CsvToSlices(filename, &features)
    96  				if err != nil {
    97  					continue
    98  				}
    99  				mapStock[securityCode] = features
   100  			}
   101  			if securityCode == backTestingParameter.TargetIndex && len(marketPrices) == 0 {
   102  				for _, m := range features {
   103  					marketPrices = append(marketPrices, m.ChangeRate)
   104  				}
   105  			}
   106  			length := len(features)
   107  			pos = length - countDays + i
   108  			if pos < 0 {
   109  				continue
   110  			}
   111  			markets := marketPrices[:pos+1]
   112  			prices := make([]float64, pos+1)
   113  			for si, sv := range features[:pos+1] {
   114  				prices[si] = sv.ChangeRate
   115  			}
   116  			feature := features[pos]
   117  			snapshot := models.FeatureToSnapshot(feature, securityCode)
   118  			// 下一个交易日开盘价
   119  			diffDays := 1
   120  			if pos+diffDays < length {
   121  				nextFeature := features[pos+diffDays]
   122  				snapshot.NextOpen = nextFeature.Open
   123  				snapshot.NextClose = nextFeature.Close
   124  				snapshot.NextHigh = nextFeature.High
   125  				snapshot.NextLow = nextFeature.Low
   126  			}
   127  			snapshot.Beta, snapshot.Alpha = exchange.EvaluateYields(prices, markets, config.TraderConfig().DailyRiskFreeRate(date))
   128  			snapshot.Beta *= 100
   129  			snapshot.Alpha *= 100
   130  			stockSnapshots = append(stockSnapshots, snapshot)
   131  		}
   132  		bar.Wait()
   133  		if len(stockSnapshots) == 0 {
   134  			continue
   135  		}
   136  
   137  		// 过滤不符合条件的个股
   138  		stockSnapshots = api.Filter(stockSnapshots, func(snapshot factors.QuoteSnapshot) bool {
   139  			return model.Filter(tradeRule.Rules, snapshot) == nil
   140  		})
   141  		// 排序
   142  		sortedStatus := model.Sort(stockSnapshots)
   143  		if sortedStatus == models.SortDefault || sortedStatus == models.SortNotExecuted {
   144  			sort.Slice(stockSnapshots, func(i, j int) bool {
   145  				a := stockSnapshots[i]
   146  				b := stockSnapshots[j]
   147  				if a.OpenTurnZ > b.OpenTurnZ {
   148  					return true
   149  				}
   150  				return a.OpenTurnZ == b.OpenTurnZ && a.OpeningChangeRate > b.OpeningChangeRate
   151  			})
   152  		}
   153  
   154  		samples := []SampleFeature{}
   155  		for _, snapshot := range stockSnapshots {
   156  			securityCode := snapshot.SecurityCode
   157  			// 获取证券名称
   158  			securityName := "unknown"
   159  			f10 := factors.GetL5F10(securityCode)
   160  			if f10 != nil {
   161  				securityName = f10.SecurityName
   162  			}
   163  
   164  			sample := SampleFeature{
   165  				Name:              securityName,
   166  				SecurityCode:      securityCode,
   167  				OpenQuantityRatio: snapshot.OpenQuantityRatio,
   168  				OpenTurnZ:         snapshot.OpenTurnZ,
   169  				OpenChangeRate:    num.NetChangeRate(snapshot.LastClose, snapshot.Open),
   170  				LastClose:         snapshot.LastClose,
   171  				Open:              snapshot.Open,
   172  				Price:             snapshot.Price,
   173  				UpRate:            num.NetChangeRate(snapshot.LastClose, snapshot.Price),
   174  				OpenPremiumRate:   num.NetChangeRate(snapshot.Open, snapshot.Price),
   175  				NextPremiumRate:   num.NetChangeRate(snapshot.Open, snapshot.NextOpen),
   176  			}
   177  			switch tradeRule.Flag {
   178  			case models.OrderFlagHead:
   179  				sample.OpenPremiumRate = num.NetChangeRate(snapshot.Open, snapshot.Price)
   180  				sample.NextPremiumRate = num.NetChangeRate(snapshot.Open, snapshot.NextOpen)
   181  			case models.OrderFlagTail:
   182  				sample.OpenPremiumRate = num.NetChangeRate(snapshot.Price, snapshot.Price)
   183  				sample.NextPremiumRate = num.NetChangeRate(snapshot.Price, snapshot.NextClose)
   184  				if snapshot.Price < snapshot.NextClose && snapshot.Price*(1+backTestingParameter.NextPremiumRate+0.005) < snapshot.NextHigh {
   185  					sample.NextPremiumRate = num.NetChangeRate(snapshot.Price, snapshot.Price*(1+backTestingParameter.NextPremiumRate))
   186  				}
   187  			case models.OrderFlagTick:
   188  				sample.OpenPremiumRate = num.NetChangeRate(snapshot.Price, snapshot.Price)
   189  				sample.NextPremiumRate = num.NetChangeRate(snapshot.Price, snapshot.NextClose)
   190  			}
   191  			sample.Beta = snapshot.Beta
   192  			sample.Alpha = snapshot.Alpha
   193  			samples = append(samples, sample)
   194  		}
   195  
   196  		// 单日回测结果
   197  		// 检查有效记录最大数
   198  		topN := countTopN
   199  		if topN > len(samples) {
   200  			topN = len(samples)
   201  		}
   202  
   203  		tbl := tablewriter.NewWriter(os.Stdout)
   204  		tbl.SetHeader(tags.GetHeadersByTags(models.Statistics{}))
   205  		samples = samples[:topN]
   206  		var results []models.Statistics
   207  		for _, v := range samples {
   208  			zs := models.Statistics{
   209  				Date:            date,                // 日期
   210  				Code:            v.SecurityCode,      // 证券代码
   211  				Name:            v.Name,              // 证券名称
   212  				OpenRaise:       v.OpenChangeRate,    // 开盘涨幅
   213  				TurnZ:           v.OpenTurnZ,         // 开盘换手率z
   214  				QuantityRatio:   v.OpenQuantityRatio, // 开盘量比
   215  				LastClose:       v.LastClose,         // 昨日收盘
   216  				Open:            v.Open,              // 开盘价
   217  				Price:           v.Price,             // 现价
   218  				UpRate:          v.UpRate,            // 涨跌幅
   219  				OpenPremiumRate: v.OpenPremiumRate,   // 集合竞价买入, 溢价率
   220  				NextPremiumRate: v.NextPremiumRate,   // 隔日溢价率
   221  				Beta:            v.Beta,
   222  				Alpha:           v.Alpha,
   223  			}
   224  			switch tradeRule.Flag {
   225  			case models.OrderFlagHead:
   226  				zs.UpdateTime = zs.Date + " 09:27:10.000"
   227  			case models.OrderFlagTail:
   228  				zs.UpdateTime = zs.Date + " 14:56:10.000"
   229  			case models.OrderFlagTick:
   230  				zs.UpdateTime = zs.Date + " 14:56:10.000"
   231  			}
   232  
   233  			results = append(results, zs)
   234  		}
   235  		gtP1 := 0 // 存在溢价
   236  		gtP2 := 0 // 超过1%
   237  		gtP3 := 0 // 超过2%
   238  		gtP4 := 0 // 超过3%
   239  		gtP5 := 0 // 超过5%
   240  		yields := float64(0.00)
   241  		for _, v := range results {
   242  			rate := v.NextPremiumRate
   243  			if rate > 0 {
   244  				gtP1 += 1
   245  			}
   246  			if rate >= 1.00 {
   247  				gtP2 += 1
   248  			}
   249  			if rate >= 2.00 {
   250  				gtP3 += 1
   251  			}
   252  			if rate >= 3.00 {
   253  				gtP4 += 1
   254  			}
   255  			if rate >= 5.00 {
   256  				gtP5 += 1
   257  			}
   258  			yields += rate
   259  			tbl.Append(tags.GetValuesByTags(v))
   260  		}
   261  		yields /= float64(len(results))
   262  		fmt.Println() // 输出一个换行
   263  		tbl.Render()
   264  		count := len(samples)
   265  		gc := GoodCase{
   266  			Date:   date,
   267  			Num:    count,
   268  			Yields: yields,
   269  			GtP1:   100 * float64(gtP1) / float64(count),
   270  			GtP2:   100 * float64(gtP2) / float64(count),
   271  			GtP3:   100 * float64(gtP3) / float64(count),
   272  			GtP4:   100 * float64(gtP4) / float64(count),
   273  			GtP5:   100 * float64(gtP5) / float64(count),
   274  		}
   275  		if num.IsNaN(gc.Yields) {
   276  			gc.Yields = 0
   277  		}
   278  		if num.IsNaN(gc.GtP1) {
   279  			gc.GtP1 = 0
   280  		}
   281  		gcs = append(gcs, gc)
   282  		fmt.Println(date + ", 胜率统计:")
   283  		fmt.Printf("\t==> 胜    率: %d/%d, %.2f%%, 收益率: %.2f%%\n", gtP1, count, 100*float64(gtP1)/float64(count), yields)
   284  		fmt.Printf("\t==> 溢价超1%%: %d/%d, %.2f%%\n", gtP2, count, 100*float64(gtP2)/float64(count))
   285  		fmt.Printf("\t==> 溢价超2%%: %d/%d, %.2f%%\n", gtP3, count, 100*float64(gtP3)/float64(count))
   286  		fmt.Printf("\t==> 溢价超3%%: %d/%d, %.2f%%\n", gtP4, count, 100*float64(gtP4)/float64(count))
   287  		fmt.Printf("\t==> 溢价超5%%: %d/%d, %.2f%%\n", gtP5, count, 100*float64(gtP5)/float64(count))
   288  		fmt.Println()
   289  		allResult = append(allResult, results...)
   290  		//storages.OutputStatistics("tracker", topN, date, results)
   291  	}
   292  
   293  	// 合计输出
   294  	fmt.Printf("%s - %s 合计:\n", dates[0], dates[len(dates)-1])
   295  	today := cache.Today()
   296  	dfTotal := pandas.LoadStructs(gcs)
   297  	if dfTotal.Nrow() > 0 {
   298  		winningRate := dfTotal.Col("浮动收益率%").FillNa(0, true).Mean()
   299  		winningAverage := dfTotal.Col("胜率率%").FillNa(0, true).Mean()
   300  		fmt.Printf("\t==> 平均 浮动溢价率:%.4f%%, 平均 胜率率: %.4f%%\n", winningRate, winningAverage)
   301  		filename := fmt.Sprintf("%s/total-%s-%s-%d.csv", storages.GetResultCachePath(), tradeRule.QmtStrategyName(), today, countTopN)
   302  		_ = dfTotal.WriteCSV(filename)
   303  	}
   304  	dfRecords := pandas.LoadStructs(allResult)
   305  	if dfRecords.Nrow() > 0 {
   306  		fudu := dfRecords.Col("open_premium_rate").FillNa(0, true).Mean()
   307  		geri := dfRecords.Col("next_premium_rate").FillNa(0, true).Mean()
   308  		fmt.Printf("\t==> 平均 浮动溢价率:%.4f%%, 平均 隔日溢价率: %.4f%%\n", fudu, geri)
   309  		colNames := tags.GetHeadersByTags(allResult[0])
   310  		_ = dfRecords.SetNames(colNames...)
   311  		filename := fmt.Sprintf("%s/backtesting-%s-%s-%d.csv", storages.GetResultCachePath(), tradeRule.QmtStrategyName(), today, countTopN)
   312  		_ = dfRecords.WriteCSV(filename)
   313  	}
   314  	//fmt.Println("\n")
   315  }