go-hep.org/x/hep@v0.38.1/fads/tautagging.go (about) 1 // Copyright ©2017 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fads 6 7 import ( 8 "math" 9 "math/rand/v2" 10 "reflect" 11 "sync" 12 13 "go-hep.org/x/hep/fmom" 14 "go-hep.org/x/hep/fwk" 15 "gonum.org/v1/gonum/stat/distuv" 16 ) 17 18 type tauclassifier struct { 19 PtMin float64 20 EtaMax float64 21 } 22 23 func (tag tauclassifier) Category(tau *Candidate, particles []Candidate) int { 24 25 pdg := tau.Pid 26 if pdg < 0 { 27 pdg = -pdg 28 } 29 if pdg != 15 { 30 return -1 31 } 32 33 if tau.Mom.Pt() <= tag.PtMin || math.Abs(tau.Mom.Eta()) > tag.EtaMax { 34 return -1 35 } 36 37 if tau.D1 < 0 { 38 return -1 39 } 40 41 for i := tau.D1; i <= tau.D2; i++ { 42 daughter := &particles[i] 43 pdg := daughter.Pid 44 if pdg < 0 { 45 pdg = -pdg 46 } 47 switch pdg { 48 case 11, 13, 15, 24: 49 return -1 50 } 51 } 52 return 0 53 } 54 55 type TauTagging struct { 56 fwk.TaskBase 57 58 particles string 59 partons string 60 jets string 61 output string 62 63 dR float64 64 65 tag tauclassifier 66 eff map[int]func(pt, eta float64) float64 67 68 seed uint64 69 src *rand.Rand 70 71 flatmu sync.Mutex 72 flat distuv.Uniform 73 } 74 75 func (tsk *TauTagging) Configure(ctx fwk.Context) error { 76 var err error 77 78 err = tsk.DeclInPort(tsk.partons, reflect.TypeOf([]Candidate{})) 79 if err != nil { 80 return err 81 } 82 83 err = tsk.DeclInPort(tsk.jets, reflect.TypeOf([]Candidate{})) 84 if err != nil { 85 return err 86 } 87 88 err = tsk.DeclOutPort(tsk.output, reflect.TypeOf([]Candidate{})) 89 if err != nil { 90 return err 91 } 92 93 tsk.src = rand.New(rand.NewPCG(tsk.seed, tsk.seed)) 94 tsk.flat = distuv.Uniform{Min: 0, Max: 1, Src: tsk.src} 95 return err 96 } 97 98 func (tsk *TauTagging) StartTask(ctx fwk.Context) error { 99 var err error 100 101 return err 102 } 103 104 func (tsk *TauTagging) StopTask(ctx fwk.Context) error { 105 var err error 106 107 return err 108 } 109 110 func (tsk *TauTagging) Process(ctx fwk.Context) error { 111 var err error 112 113 store := ctx.Store() 114 msg := ctx.Msg() 115 116 v, err := store.Get(tsk.particles) 117 if err != nil { 118 return err 119 } 120 121 particles := v.([]Candidate) 122 123 v, err = store.Get(tsk.partons) 124 if err != nil { 125 return err 126 } 127 128 allpartons := v.([]Candidate) 129 130 v, err = store.Get(tsk.jets) 131 if err != nil { 132 return err 133 } 134 jets := v.([]Candidate) 135 136 output := make([]Candidate, 0, len(jets)) 137 defer func() { 138 err = store.Put(tsk.output, output) 139 }() 140 141 msg.Debugf("particles: %d\n", len(particles)) 142 msg.Debugf("partons: %d\n", len(allpartons)) 143 msg.Debugf("jets: %d\n", len(jets)) 144 145 taus := make([]Candidate, 0, len(allpartons)) 146 for i := range allpartons { 147 cand := &allpartons[i] 148 if tsk.tag.Category(cand, particles) < 0 { 149 continue 150 } 151 taus = append(taus, *cand) 152 } 153 154 for i := range jets { 155 jet := jets[i].Clone() 156 pdg := 0 157 eta := jet.Mom.Eta() 158 pt := jet.Mom.Pt() 159 160 charge := int32(-1) 161 tsk.flatmu.Lock() 162 if tsk.flat.Rand() > 0.5 { 163 charge = 1 164 } 165 tsk.flatmu.Unlock() 166 167 for j := range taus { 168 mc := &taus[j] 169 if mc.D1 < 0 { 170 continue 171 } 172 173 var p4 fmom.PxPyPzE 174 for ii := mc.D1; ii < mc.D2; ii++ { 175 daughter := &particles[ii] 176 pdg := daughter.Pid 177 if pdg == -16 || pdg == 16 { 178 continue 179 } 180 fmom.IAdd(&p4, &daughter.Mom) 181 } 182 183 if fmom.DeltaR(&jet.Mom, &p4) < tsk.dR { 184 pdg = 15 185 charge = mc.CandCharge 186 } 187 } 188 189 eff, ok := tsk.eff[pdg] 190 if !ok { 191 eff = tsk.eff[0] 192 } 193 194 // apply efficiency 195 tag := uint32(0) 196 tsk.flatmu.Lock() 197 if tsk.flat.Rand() <= eff(pt, eta) { 198 tag = 1 199 } 200 tsk.flatmu.Unlock() 201 jet.TauTag = tag 202 jet.CandCharge = charge 203 204 output = append(output, *jet) 205 } 206 207 msg.Debugf("output: %d\n", len(output)) 208 return err 209 } 210 211 func newTauTagging(typ, name string, mgr fwk.App) (fwk.Component, error) { 212 var err error 213 214 tsk := &TauTagging{ 215 TaskBase: fwk.NewTask(typ, name, mgr), 216 particles: "InputParticles", 217 partons: "InputPartons", 218 jets: "InputJets", 219 output: "OutputJets", 220 221 dR: 0.5, 222 tag: tauclassifier{ 223 PtMin: 1.0, 224 EtaMax: 2.5, 225 }, 226 eff: map[int]func(pt, eta float64) float64{ 227 0: func(pt, eta float64) float64 { return 0 }, 228 }, 229 230 seed: 1234, 231 } 232 233 err = tsk.DeclProp("Particles", &tsk.particles) 234 if err != nil { 235 return nil, err 236 } 237 238 err = tsk.DeclProp("Partons", &tsk.partons) 239 if err != nil { 240 return nil, err 241 } 242 243 err = tsk.DeclProp("Jets", &tsk.jets) 244 if err != nil { 245 return nil, err 246 } 247 248 err = tsk.DeclProp("Output", &tsk.output) 249 if err != nil { 250 return nil, err 251 } 252 253 err = tsk.DeclProp("DeltaR", &tsk.dR) 254 if err != nil { 255 return nil, err 256 } 257 258 err = tsk.DeclProp("TauPtMin", &tsk.tag.PtMin) 259 if err != nil { 260 return nil, err 261 } 262 263 err = tsk.DeclProp("TauEtaMax", &tsk.tag.EtaMax) 264 if err != nil { 265 return nil, err 266 } 267 268 err = tsk.DeclProp("Eff", &tsk.eff) 269 if err != nil { 270 return nil, err 271 } 272 273 err = tsk.DeclProp("Seed", &tsk.seed) 274 if err != nil { 275 return nil, err 276 } 277 278 return tsk, err 279 } 280 281 func init() { 282 fwk.Register(reflect.TypeOf(TauTagging{}), newTauTagging) 283 }