go-hep.org/x/hep@v0.38.1/fads/tautagging.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fads
     6  
     7  import (
     8  	"math"
     9  	"math/rand/v2"
    10  	"reflect"
    11  	"sync"
    12  
    13  	"go-hep.org/x/hep/fmom"
    14  	"go-hep.org/x/hep/fwk"
    15  	"gonum.org/v1/gonum/stat/distuv"
    16  )
    17  
    18  type tauclassifier struct {
    19  	PtMin  float64
    20  	EtaMax float64
    21  }
    22  
    23  func (tag tauclassifier) Category(tau *Candidate, particles []Candidate) int {
    24  
    25  	pdg := tau.Pid
    26  	if pdg < 0 {
    27  		pdg = -pdg
    28  	}
    29  	if pdg != 15 {
    30  		return -1
    31  	}
    32  
    33  	if tau.Mom.Pt() <= tag.PtMin || math.Abs(tau.Mom.Eta()) > tag.EtaMax {
    34  		return -1
    35  	}
    36  
    37  	if tau.D1 < 0 {
    38  		return -1
    39  	}
    40  
    41  	for i := tau.D1; i <= tau.D2; i++ {
    42  		daughter := &particles[i]
    43  		pdg := daughter.Pid
    44  		if pdg < 0 {
    45  			pdg = -pdg
    46  		}
    47  		switch pdg {
    48  		case 11, 13, 15, 24:
    49  			return -1
    50  		}
    51  	}
    52  	return 0
    53  }
    54  
    55  type TauTagging struct {
    56  	fwk.TaskBase
    57  
    58  	particles string
    59  	partons   string
    60  	jets      string
    61  	output    string
    62  
    63  	dR float64
    64  
    65  	tag tauclassifier
    66  	eff map[int]func(pt, eta float64) float64
    67  
    68  	seed uint64
    69  	src  *rand.Rand
    70  
    71  	flatmu sync.Mutex
    72  	flat   distuv.Uniform
    73  }
    74  
    75  func (tsk *TauTagging) Configure(ctx fwk.Context) error {
    76  	var err error
    77  
    78  	err = tsk.DeclInPort(tsk.partons, reflect.TypeOf([]Candidate{}))
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	err = tsk.DeclInPort(tsk.jets, reflect.TypeOf([]Candidate{}))
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	err = tsk.DeclOutPort(tsk.output, reflect.TypeOf([]Candidate{}))
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	tsk.src = rand.New(rand.NewPCG(tsk.seed, tsk.seed))
    94  	tsk.flat = distuv.Uniform{Min: 0, Max: 1, Src: tsk.src}
    95  	return err
    96  }
    97  
    98  func (tsk *TauTagging) StartTask(ctx fwk.Context) error {
    99  	var err error
   100  
   101  	return err
   102  }
   103  
   104  func (tsk *TauTagging) StopTask(ctx fwk.Context) error {
   105  	var err error
   106  
   107  	return err
   108  }
   109  
   110  func (tsk *TauTagging) Process(ctx fwk.Context) error {
   111  	var err error
   112  
   113  	store := ctx.Store()
   114  	msg := ctx.Msg()
   115  
   116  	v, err := store.Get(tsk.particles)
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	particles := v.([]Candidate)
   122  
   123  	v, err = store.Get(tsk.partons)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	allpartons := v.([]Candidate)
   129  
   130  	v, err = store.Get(tsk.jets)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	jets := v.([]Candidate)
   135  
   136  	output := make([]Candidate, 0, len(jets))
   137  	defer func() {
   138  		err = store.Put(tsk.output, output)
   139  	}()
   140  
   141  	msg.Debugf("particles: %d\n", len(particles))
   142  	msg.Debugf("partons:   %d\n", len(allpartons))
   143  	msg.Debugf("jets:      %d\n", len(jets))
   144  
   145  	taus := make([]Candidate, 0, len(allpartons))
   146  	for i := range allpartons {
   147  		cand := &allpartons[i]
   148  		if tsk.tag.Category(cand, particles) < 0 {
   149  			continue
   150  		}
   151  		taus = append(taus, *cand)
   152  	}
   153  
   154  	for i := range jets {
   155  		jet := jets[i].Clone()
   156  		pdg := 0
   157  		eta := jet.Mom.Eta()
   158  		pt := jet.Mom.Pt()
   159  
   160  		charge := int32(-1)
   161  		tsk.flatmu.Lock()
   162  		if tsk.flat.Rand() > 0.5 {
   163  			charge = 1
   164  		}
   165  		tsk.flatmu.Unlock()
   166  
   167  		for j := range taus {
   168  			mc := &taus[j]
   169  			if mc.D1 < 0 {
   170  				continue
   171  			}
   172  
   173  			var p4 fmom.PxPyPzE
   174  			for ii := mc.D1; ii < mc.D2; ii++ {
   175  				daughter := &particles[ii]
   176  				pdg := daughter.Pid
   177  				if pdg == -16 || pdg == 16 {
   178  					continue
   179  				}
   180  				fmom.IAdd(&p4, &daughter.Mom)
   181  			}
   182  
   183  			if fmom.DeltaR(&jet.Mom, &p4) < tsk.dR {
   184  				pdg = 15
   185  				charge = mc.CandCharge
   186  			}
   187  		}
   188  
   189  		eff, ok := tsk.eff[pdg]
   190  		if !ok {
   191  			eff = tsk.eff[0]
   192  		}
   193  
   194  		// apply efficiency
   195  		tag := uint32(0)
   196  		tsk.flatmu.Lock()
   197  		if tsk.flat.Rand() <= eff(pt, eta) {
   198  			tag = 1
   199  		}
   200  		tsk.flatmu.Unlock()
   201  		jet.TauTag = tag
   202  		jet.CandCharge = charge
   203  
   204  		output = append(output, *jet)
   205  	}
   206  
   207  	msg.Debugf("output:  %d\n", len(output))
   208  	return err
   209  }
   210  
   211  func newTauTagging(typ, name string, mgr fwk.App) (fwk.Component, error) {
   212  	var err error
   213  
   214  	tsk := &TauTagging{
   215  		TaskBase:  fwk.NewTask(typ, name, mgr),
   216  		particles: "InputParticles",
   217  		partons:   "InputPartons",
   218  		jets:      "InputJets",
   219  		output:    "OutputJets",
   220  
   221  		dR: 0.5,
   222  		tag: tauclassifier{
   223  			PtMin:  1.0,
   224  			EtaMax: 2.5,
   225  		},
   226  		eff: map[int]func(pt, eta float64) float64{
   227  			0: func(pt, eta float64) float64 { return 0 },
   228  		},
   229  
   230  		seed: 1234,
   231  	}
   232  
   233  	err = tsk.DeclProp("Particles", &tsk.particles)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	err = tsk.DeclProp("Partons", &tsk.partons)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	err = tsk.DeclProp("Jets", &tsk.jets)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	err = tsk.DeclProp("Output", &tsk.output)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  
   253  	err = tsk.DeclProp("DeltaR", &tsk.dR)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	err = tsk.DeclProp("TauPtMin", &tsk.tag.PtMin)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  
   263  	err = tsk.DeclProp("TauEtaMax", &tsk.tag.EtaMax)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  
   268  	err = tsk.DeclProp("Eff", &tsk.eff)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	err = tsk.DeclProp("Seed", &tsk.seed)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	return tsk, err
   279  }
   280  
   281  func init() {
   282  	fwk.Register(reflect.TypeOf(TauTagging{}), newTauTagging)
   283  }