go-hep.org/x/hep@v0.38.1/fads/btagging.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fads
     6  
     7  import (
     8  	"math"
     9  	"math/rand/v2"
    10  	"reflect"
    11  	"sync"
    12  
    13  	"go-hep.org/x/hep/fmom"
    14  	"go-hep.org/x/hep/fwk"
    15  	"gonum.org/v1/gonum/stat/distuv"
    16  )
    17  
    18  type btagclassifier struct {
    19  	PtMin  float64
    20  	EtaMax float64
    21  }
    22  
    23  func (btag btagclassifier) Category(parton *Candidate) int {
    24  	if parton.Mom.Pt() <= btag.PtMin || math.Abs(parton.Mom.Eta()) > btag.EtaMax {
    25  		return -1
    26  	}
    27  	pdg := parton.Pid
    28  	if pdg < 0 {
    29  		pdg = -pdg
    30  	}
    31  
    32  	if pdg != 21 && pdg > 5 {
    33  		return -1
    34  	}
    35  
    36  	return 0
    37  }
    38  
    39  type BTagging struct {
    40  	fwk.TaskBase
    41  
    42  	partons string
    43  	jets    string
    44  	output  string
    45  
    46  	dR  float64
    47  	bit uint
    48  
    49  	btag btagclassifier
    50  	eff  map[int]func(pt, eta float64) float64
    51  
    52  	seed uint64
    53  	src  *rand.Rand
    54  
    55  	flat   distuv.Uniform
    56  	flatmu sync.Mutex
    57  }
    58  
    59  func (tsk *BTagging) Configure(ctx fwk.Context) error {
    60  	var err error
    61  
    62  	err = tsk.DeclInPort(tsk.partons, reflect.TypeOf([]Candidate{}))
    63  	if err != nil {
    64  		return err
    65  	}
    66  
    67  	err = tsk.DeclInPort(tsk.jets, reflect.TypeOf([]Candidate{}))
    68  	if err != nil {
    69  		return err
    70  	}
    71  
    72  	err = tsk.DeclOutPort(tsk.output, reflect.TypeOf([]Candidate{}))
    73  	if err != nil {
    74  		return err
    75  	}
    76  
    77  	tsk.src = rand.New(rand.NewPCG(tsk.seed, tsk.seed))
    78  	tsk.flat = distuv.Uniform{Min: 0, Max: 1, Src: tsk.src}
    79  	return err
    80  }
    81  
    82  func (tsk *BTagging) StartTask(ctx fwk.Context) error {
    83  	var err error
    84  
    85  	return err
    86  }
    87  
    88  func (tsk *BTagging) StopTask(ctx fwk.Context) error {
    89  	var err error
    90  
    91  	return err
    92  }
    93  
    94  func (tsk *BTagging) Process(ctx fwk.Context) error {
    95  	var err error
    96  
    97  	store := ctx.Store()
    98  	msg := ctx.Msg()
    99  
   100  	v, err := store.Get(tsk.partons)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	allpartons := v.([]Candidate)
   106  
   107  	v, err = store.Get(tsk.jets)
   108  	if err != nil {
   109  		return err
   110  	}
   111  	jets := v.([]Candidate)
   112  
   113  	output := make([]Candidate, 0, len(jets))
   114  	defer func() {
   115  		err = store.Put(tsk.output, jets)
   116  	}()
   117  
   118  	msg.Debugf("partons: %d\n", len(allpartons))
   119  	msg.Debugf("jets:    %d\n", len(jets))
   120  
   121  	partons := make([]Candidate, 0, len(allpartons))
   122  	for i := range allpartons {
   123  		cand := &allpartons[i]
   124  		if tsk.btag.Category(cand) < 0 {
   125  			continue
   126  		}
   127  		partons = append(partons, *cand)
   128  	}
   129  
   130  	for i := range jets {
   131  		jet := jets[i].Clone()
   132  		pdgmax := -1
   133  		eta := jet.Mom.Eta()
   134  		pt := jet.Mom.Pt()
   135  
   136  		for j := range partons {
   137  			p := &partons[j]
   138  			pdg := int(p.Pid)
   139  			if pdg < 0 {
   140  				pdg = -pdg
   141  			}
   142  			if pdg == 21 {
   143  				pdg = 0
   144  			}
   145  			if fmom.DeltaR(&jet.Mom, &p.Mom) < tsk.dR {
   146  				if pdgmax < pdg {
   147  					pdgmax = pdg
   148  				}
   149  			}
   150  		}
   151  
   152  		switch pdgmax {
   153  		case 0:
   154  			pdgmax = 21
   155  		case -1:
   156  			pdgmax = 0
   157  		}
   158  
   159  		eff, ok := tsk.eff[pdgmax]
   160  		if !ok {
   161  			eff = tsk.eff[0]
   162  		}
   163  
   164  		// apply efficiency
   165  		tag := uint32(0)
   166  		tsk.flatmu.Lock()
   167  		if tsk.flat.Rand() <= eff(pt, eta) {
   168  			tag = 1
   169  		}
   170  		tsk.flatmu.Unlock()
   171  		jet.BTag |= tag << tsk.bit
   172  
   173  		output = append(output, *jet)
   174  	}
   175  
   176  	msg.Debugf("output:  %d\n", len(output))
   177  	return err
   178  }
   179  
   180  func newBTagging(typ, name string, mgr fwk.App) (fwk.Component, error) {
   181  	var err error
   182  
   183  	tsk := &BTagging{
   184  		TaskBase: fwk.NewTask(typ, name, mgr),
   185  		partons:  "InputPartons",
   186  		jets:     "InputJets",
   187  		output:   "OutputJets",
   188  
   189  		bit: 0,
   190  		dR:  0.5,
   191  		btag: btagclassifier{
   192  			PtMin:  1.0,
   193  			EtaMax: 2.5,
   194  		},
   195  		eff: map[int]func(pt, eta float64) float64{
   196  			0: func(pt, eta float64) float64 { return 0 },
   197  		},
   198  
   199  		seed: 1234,
   200  	}
   201  
   202  	err = tsk.DeclProp("Partons", &tsk.partons)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	err = tsk.DeclProp("Jets", &tsk.jets)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	err = tsk.DeclProp("Output", &tsk.output)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	err = tsk.DeclProp("BitNumber", &tsk.bit)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  
   222  	err = tsk.DeclProp("DeltaR", &tsk.dR)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	err = tsk.DeclProp("PartonPtMin", &tsk.btag.PtMin)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	err = tsk.DeclProp("PartonEtaMax", &tsk.btag.EtaMax)
   233  	if err != nil {
   234  		return nil, err
   235  	}
   236  
   237  	err = tsk.DeclProp("Eff", &tsk.eff)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	err = tsk.DeclProp("Seed", &tsk.seed)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  
   247  	return tsk, err
   248  }
   249  
   250  func init() {
   251  	fwk.Register(reflect.TypeOf(BTagging{}), newBTagging)
   252  }