go-hep.org/x/hep@v0.38.1/fads/btagging.go (about) 1 // Copyright ©2017 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fads 6 7 import ( 8 "math" 9 "math/rand/v2" 10 "reflect" 11 "sync" 12 13 "go-hep.org/x/hep/fmom" 14 "go-hep.org/x/hep/fwk" 15 "gonum.org/v1/gonum/stat/distuv" 16 ) 17 18 type btagclassifier struct { 19 PtMin float64 20 EtaMax float64 21 } 22 23 func (btag btagclassifier) Category(parton *Candidate) int { 24 if parton.Mom.Pt() <= btag.PtMin || math.Abs(parton.Mom.Eta()) > btag.EtaMax { 25 return -1 26 } 27 pdg := parton.Pid 28 if pdg < 0 { 29 pdg = -pdg 30 } 31 32 if pdg != 21 && pdg > 5 { 33 return -1 34 } 35 36 return 0 37 } 38 39 type BTagging struct { 40 fwk.TaskBase 41 42 partons string 43 jets string 44 output string 45 46 dR float64 47 bit uint 48 49 btag btagclassifier 50 eff map[int]func(pt, eta float64) float64 51 52 seed uint64 53 src *rand.Rand 54 55 flat distuv.Uniform 56 flatmu sync.Mutex 57 } 58 59 func (tsk *BTagging) Configure(ctx fwk.Context) error { 60 var err error 61 62 err = tsk.DeclInPort(tsk.partons, reflect.TypeOf([]Candidate{})) 63 if err != nil { 64 return err 65 } 66 67 err = tsk.DeclInPort(tsk.jets, reflect.TypeOf([]Candidate{})) 68 if err != nil { 69 return err 70 } 71 72 err = tsk.DeclOutPort(tsk.output, reflect.TypeOf([]Candidate{})) 73 if err != nil { 74 return err 75 } 76 77 tsk.src = rand.New(rand.NewPCG(tsk.seed, tsk.seed)) 78 tsk.flat = distuv.Uniform{Min: 0, Max: 1, Src: tsk.src} 79 return err 80 } 81 82 func (tsk *BTagging) StartTask(ctx fwk.Context) error { 83 var err error 84 85 return err 86 } 87 88 func (tsk *BTagging) StopTask(ctx fwk.Context) error { 89 var err error 90 91 return err 92 } 93 94 func (tsk *BTagging) Process(ctx fwk.Context) error { 95 var err error 96 97 store := ctx.Store() 98 msg := ctx.Msg() 99 100 v, err := store.Get(tsk.partons) 101 if err != nil { 102 return err 103 } 104 105 allpartons := v.([]Candidate) 106 107 v, err = store.Get(tsk.jets) 108 if err != nil { 109 return err 110 } 111 jets := v.([]Candidate) 112 113 output := make([]Candidate, 0, len(jets)) 114 defer func() { 115 err = store.Put(tsk.output, jets) 116 }() 117 118 msg.Debugf("partons: %d\n", len(allpartons)) 119 msg.Debugf("jets: %d\n", len(jets)) 120 121 partons := make([]Candidate, 0, len(allpartons)) 122 for i := range allpartons { 123 cand := &allpartons[i] 124 if tsk.btag.Category(cand) < 0 { 125 continue 126 } 127 partons = append(partons, *cand) 128 } 129 130 for i := range jets { 131 jet := jets[i].Clone() 132 pdgmax := -1 133 eta := jet.Mom.Eta() 134 pt := jet.Mom.Pt() 135 136 for j := range partons { 137 p := &partons[j] 138 pdg := int(p.Pid) 139 if pdg < 0 { 140 pdg = -pdg 141 } 142 if pdg == 21 { 143 pdg = 0 144 } 145 if fmom.DeltaR(&jet.Mom, &p.Mom) < tsk.dR { 146 if pdgmax < pdg { 147 pdgmax = pdg 148 } 149 } 150 } 151 152 switch pdgmax { 153 case 0: 154 pdgmax = 21 155 case -1: 156 pdgmax = 0 157 } 158 159 eff, ok := tsk.eff[pdgmax] 160 if !ok { 161 eff = tsk.eff[0] 162 } 163 164 // apply efficiency 165 tag := uint32(0) 166 tsk.flatmu.Lock() 167 if tsk.flat.Rand() <= eff(pt, eta) { 168 tag = 1 169 } 170 tsk.flatmu.Unlock() 171 jet.BTag |= tag << tsk.bit 172 173 output = append(output, *jet) 174 } 175 176 msg.Debugf("output: %d\n", len(output)) 177 return err 178 } 179 180 func newBTagging(typ, name string, mgr fwk.App) (fwk.Component, error) { 181 var err error 182 183 tsk := &BTagging{ 184 TaskBase: fwk.NewTask(typ, name, mgr), 185 partons: "InputPartons", 186 jets: "InputJets", 187 output: "OutputJets", 188 189 bit: 0, 190 dR: 0.5, 191 btag: btagclassifier{ 192 PtMin: 1.0, 193 EtaMax: 2.5, 194 }, 195 eff: map[int]func(pt, eta float64) float64{ 196 0: func(pt, eta float64) float64 { return 0 }, 197 }, 198 199 seed: 1234, 200 } 201 202 err = tsk.DeclProp("Partons", &tsk.partons) 203 if err != nil { 204 return nil, err 205 } 206 207 err = tsk.DeclProp("Jets", &tsk.jets) 208 if err != nil { 209 return nil, err 210 } 211 212 err = tsk.DeclProp("Output", &tsk.output) 213 if err != nil { 214 return nil, err 215 } 216 217 err = tsk.DeclProp("BitNumber", &tsk.bit) 218 if err != nil { 219 return nil, err 220 } 221 222 err = tsk.DeclProp("DeltaR", &tsk.dR) 223 if err != nil { 224 return nil, err 225 } 226 227 err = tsk.DeclProp("PartonPtMin", &tsk.btag.PtMin) 228 if err != nil { 229 return nil, err 230 } 231 232 err = tsk.DeclProp("PartonEtaMax", &tsk.btag.EtaMax) 233 if err != nil { 234 return nil, err 235 } 236 237 err = tsk.DeclProp("Eff", &tsk.eff) 238 if err != nil { 239 return nil, err 240 } 241 242 err = tsk.DeclProp("Seed", &tsk.seed) 243 if err != nil { 244 return nil, err 245 } 246 247 return tsk, err 248 } 249 250 func init() { 251 fwk.Register(reflect.TypeOf(BTagging{}), newBTagging) 252 }