gitee.com/quant1x/engine@v1.8.4/factors/cache1d.go (about)

     1  package factors
     2  
     3  import (
     4  	"context"
     5  	"gitee.com/quant1x/engine/cache"
     6  	"gitee.com/quant1x/engine/market"
     7  	"gitee.com/quant1x/exchange"
     8  	"gitee.com/quant1x/gox/api"
     9  	"gitee.com/quant1x/gox/concurrent"
    10  	"gitee.com/quant1x/gox/coroutine"
    11  	"gitee.com/quant1x/gox/logger"
    12  	"gitee.com/quant1x/gox/tags"
    13  	"gitee.com/quant1x/gox/util/treemap"
    14  	"gitee.com/quant1x/pkg/tablewriter"
    15  	"os"
    16  	"strings"
    17  	"sync"
    18  )
    19  
// Cache1D caches the combined feature record of every security code for a
// single trading date, persisted as a CSV file.
//
//	One record per security code per day.
type Cache1D[T Feature] struct {
	once        coroutine.PeriodicOnce            // guards loading; Checkout resets it to reload after a date switch
	m           sync.RWMutex                      // locked by Checkout while it records replaceDate
	factory     func(date, securityCode string) T // builds a fresh record for a (date, security code) pair
	cacheKey    string                            // cache key; also used to derive the cache file path
	Date        string                            // date of the data loaded into mapCache
	filename    string                            // path of the CSV cache file for Date
	mapCache    *concurrent.TreeMap[string, T]    // security code -> record
	replaceDate string                            // target date recorded by Checkout and loaded by ReplaceCache
	allCodes    []string                          // security codes from market.GetCodeList
	tShadow     T                                 // prototype T answering the metadata calls (Owner, Kind, Key, Name, Usage, Factory)
}
    35  
    36  // NewCache1D 创建一个新的C1D对象
    37  //
    38  //	key支持多级相对路径, 比如a/b, 创建的路径是~/.quant1x/a/b.yyyy-mm-dd
    39  func NewCache1D[T Feature](key string, factory func(date, securityCode string) T) *Cache1D[T] {
    40  	d1 := &Cache1D[T]{
    41  		cacheKey:    key,
    42  		Date:        "",
    43  		factory:     factory,
    44  		mapCache:    concurrent.NewTreeMap[string, T](),
    45  		replaceDate: "",
    46  		allCodes:    []string{},
    47  	}
    48  	d1.Date = cache.DefaultCanReadDate()
    49  	d1.allCodes = market.GetCodeList()
    50  	//d1.Checkout(d1.Date)
    51  	d1.filename = getCache1DFilepath(d1.cacheKey, d1.Date)
    52  	d1.tShadow = d1.factory(d1.Date, defaultSecurityCode)
    53  	RegisterFeatureRotationAdapter(key, d1)
    54  	return d1
    55  }
    56  
    57  func (this *Cache1D[T]) Factory(date, securityCode string) Feature {
    58  	return this.tShadow.Factory(date, securityCode)
    59  }
    60  
    61  func (this *Cache1D[T]) Init(ctx context.Context, date, securityCode string) error {
    62  	_ = ctx
    63  	_ = date
    64  	_ = securityCode
    65  	return nil
    66  }
    67  
    68  func (this *Cache1D[T]) Owner() string {
    69  	return this.tShadow.Owner()
    70  }
    71  
    72  func (this *Cache1D[T]) Kind() cache.Kind {
    73  	return this.tShadow.Kind()
    74  }
    75  
    76  func (this *Cache1D[T]) Key() string {
    77  	return this.tShadow.Key()
    78  }
    79  
    80  func (this *Cache1D[T]) Name() string {
    81  	return this.tShadow.Name()
    82  }
    83  
    84  func (this *Cache1D[T]) Usage() string {
    85  	return this.tShadow.Usage()
    86  }
    87  
    88  // Length 获取长度
    89  func (this *Cache1D[T]) Length() int {
    90  	return len(this.allCodes)
    91  }
    92  
    93  // loadCache 加载指定日期的数据
    94  func (this *Cache1D[T]) loadCache(date string) {
    95  	// 重置个股列表
    96  	this.allCodes = market.GetCodeList()
    97  	this.Date = exchange.FixTradeDate(date)
    98  	this.filename = getCache1DFilepath(this.cacheKey, this.Date)
    99  	logger.Warnf("%s: date=%s, filename=%s", this.cacheKey, this.Date, this.filename)
   100  	var list []T
   101  	err := api.CsvToSlices(this.filename, &list)
   102  	if err != nil || len(list) == 0 {
   103  		logger.Errorf("%s 没有有效数据, error=%+v", this.filename, err)
   104  		return
   105  	}
   106  	for _, v := range list {
   107  		code := v.GetSecurityCode()
   108  		this.mapCache.Put(code, v)
   109  	}
   110  }
   111  
   112  // 加载默认数据, 日期为当前交易中的日期
   113  func (this *Cache1D[T]) loadDefault() {
   114  	this.loadCache(this.Date)
   115  }
   116  
   117  // ReplaceCache 替换当前缓存数据
   118  func (this *Cache1D[T]) ReplaceCache() {
   119  	this.mapCache.Clear()
   120  	this.loadCache(this.replaceDate)
   121  }
   122  
   123  func (this *Cache1D[T]) Checkout(date ...string) {
   124  	if len(date) > 0 {
   125  		this.m.Lock()
   126  		destDate := exchange.FixTradeDate(date[0])
   127  		if this.Date != destDate {
   128  			this.replaceDate = destDate
   129  		}
   130  		this.m.Unlock()
   131  	}
   132  	if len(this.replaceDate) == 0 || this.Date == this.replaceDate {
   133  		this.once.Do(this.loadDefault)
   134  	} else {
   135  		// 重置once锁计数器为0
   136  		this.once.Reset()
   137  		this.once.Do(this.ReplaceCache)
   138  	}
   139  }
   140  
   141  func checkoutTable(v any) (headers []string, records [][]string) {
   142  	headers = []string{"字段", "数值"}
   143  	fields := tags.GetHeadersByTags(v)
   144  	values := tags.GetValuesByTags(v)
   145  	num := len(fields)
   146  	if num > len(values) {
   147  		num = len(values)
   148  	}
   149  	for i := 0; i < num; i++ {
   150  		records = append(records, []string{fields[i], strings.TrimSpace(values[i])})
   151  	}
   152  	return
   153  }
   154  
   155  func (this *Cache1D[T]) Print(code string, date ...string) {
   156  	securityCode := exchange.CorrectSecurityCode(code)
   157  	tradeDate := cache.DefaultCanReadDate()
   158  	if len(date) > 0 {
   159  		tradeDate = exchange.FixTradeDate(date[0])
   160  	}
   161  	value := this.Get(securityCode, tradeDate)
   162  	if value != nil {
   163  		headers, records := checkoutTable(*value)
   164  		table := tablewriter.NewWriter(os.Stdout)
   165  		table.SetHeader(headers)
   166  		table.SetColumnAlignment([]int{tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_LEFT})
   167  		table.AppendBulk(records)
   168  		table.Render()
   169  	}
   170  }
   171  
   172  func (this *Cache1D[T]) Check(cacheDate, featureDate string) {
   173  	_ = cacheDate
   174  	_ = featureDate
   175  	//TODO implement me
   176  	panic("implement me")
   177  }
   178  
   179  // Get 获取指定证券代码的数据
   180  func (this *Cache1D[T]) Get(securityCode string, date ...string) *T {
   181  	this.Checkout(date...)
   182  	this.once.Do(this.loadDefault)
   183  	t, ok := this.mapCache.Get(securityCode)
   184  	if ok {
   185  		return &t
   186  	}
   187  	return nil
   188  }
   189  
   190  // Set 更新map中指定证券代码的数据
   191  func (this *Cache1D[T]) Set(securityCode string, newValue T, date ...string) {
   192  	this.Checkout(date...)
   193  	this.once.Do(this.loadDefault)
   194  	this.mapCache.Put(securityCode, newValue)
   195  }
   196  
   197  func (this *Cache1D[T]) Filter(f func(v T) bool) []T {
   198  	var list []T
   199  	if f == nil {
   200  		return nil
   201  	}
   202  	for _, securityCode := range this.allCodes {
   203  		v, found := this.mapCache.Get(securityCode)
   204  		if found {
   205  			if ok := f(v); ok {
   206  				list = append(list, v)
   207  			}
   208  		}
   209  	}
   210  	return list
   211  }
   212  
   213  // Apply 数据合并
   214  //
   215  //	泛型T需要保持一个string类型的Date字段
   216  func (this *Cache1D[T]) Apply(merge func(code string, local *T) (updated bool)) {
   217  	list := make([]T, 0, len(this.allCodes))
   218  	for _, securityCode := range this.allCodes {
   219  		v, found := this.mapCache.Get(securityCode)
   220  		if !found && this.factory != nil {
   221  			v = this.factory(this.Date, securityCode)
   222  		}
   223  		if merge != nil {
   224  			ok := merge(securityCode, &v)
   225  			if ok {
   226  				this.mapCache.Put(securityCode, v)
   227  			}
   228  		}
   229  		list = append(list, v)
   230  	}
   231  	if len(list) > 0 {
   232  		err := api.SlicesToCsv(this.filename, list)
   233  		if err != nil {
   234  			logger.Errorf("刷新%s异常:%+v", this.filename, err)
   235  		}
   236  	}
   237  }
   238  
   239  func (this *Cache1D[T]) Merge(p *treemap.Map) {
   240  	list := make([]T, 0, len(this.allCodes))
   241  	for _, securityCode := range this.allCodes {
   242  		v, found := this.mapCache.Get(securityCode)
   243  		if !found && this.factory != nil {
   244  			v = this.factory(this.Date, securityCode)
   245  		}
   246  		if p != nil {
   247  			tmp, ok := p.Get(securityCode)
   248  			if ok {
   249  				_ = api.CopyWithOption(v, tmp, api.Option{})
   250  				if ok {
   251  					this.mapCache.Put(securityCode, v)
   252  				}
   253  			}
   254  		}
   255  		list = append(list, v)
   256  	}
   257  	if len(list) > 0 {
   258  		err := api.SlicesToCsv(this.filename, list)
   259  		if err != nil {
   260  			logger.Errorf("刷新%s异常:%+v", this.filename, err)
   261  		}
   262  	}
   263  }