github.com/google/osv-scalibr@v0.4.1/enricher/reachability/java/reachable.go (about) 1 // Copyright 2025 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package java 16 17 import ( 18 "archive/zip" 19 "errors" 20 "fmt" 21 "io/fs" 22 "maps" 23 "os" 24 "path" 25 "regexp" 26 "slices" 27 "strings" 28 29 "github.com/google/osv-scalibr/log" 30 ) 31 32 // ReachabilityResult contains the result of a reachability enumeration. 33 type ReachabilityResult struct { 34 Classes []string 35 UsesDynamicCodeLoading []string 36 UsesDependencyInjection []string 37 } 38 39 // DynamicCodeStrategy is a strategy for handling dynamic code loading. 40 type DynamicCodeStrategy int 41 42 const ( 43 // DontHandleDynamicCode doesn't do any kind of special handling. 44 DontHandleDynamicCode DynamicCodeStrategy = 0 45 // AssumeAllDirectDepsReachable assumes that the entirety of all direct dependencies (i.e. all their 46 // classes) are fully reachable. 47 AssumeAllDirectDepsReachable = 1 << 0 48 // AssumeAllClassesReachable assumes that every single class belonging to the current dependency are 49 // fully reachable. 50 AssumeAllClassesReachable = 1 << 1 51 ) 52 53 const ( 54 // BootInfClasses contains spring boot specific classes. 55 BootInfClasses = "BOOT-INF/classes" 56 // MetaInfVersions is a directory that contains multi-release JARs. 57 MetaInfVersions = "META-INF/versions" 58 ) 59 60 var errClassNotFound = errors.New("class not found in jar") 61 62 // ReachabilityEnumerator enumerates the reachable classes from a set of root 63 // classes. 64 type ReachabilityEnumerator struct { 65 ClassPaths []string 66 PackageFinder MavenPackageFinder 67 CodeLoadingStrategy DynamicCodeStrategy 68 DependencyInjectionStrategy DynamicCodeStrategy 69 70 loadedJARs map[string]*zip.Reader 71 } 72 73 // NewReachabilityEnumerator creates a new ReachabilityEnumerator. 74 func NewReachabilityEnumerator( 75 classPaths []string, packageFinder MavenPackageFinder, 76 codeLoadingStrategy DynamicCodeStrategy, dependencyInjectionStrategy DynamicCodeStrategy) *ReachabilityEnumerator { 77 return &ReachabilityEnumerator{ 78 ClassPaths: classPaths, 79 PackageFinder: packageFinder, 80 CodeLoadingStrategy: codeLoadingStrategy, 81 DependencyInjectionStrategy: dependencyInjectionStrategy, 82 loadedJARs: map[string]*zip.Reader{}, 83 } 84 } 85 86 // EnumerateReachabilityFromClasses enumerates the reachable classes from a set of root 87 // classes. 88 func (r *ReachabilityEnumerator) EnumerateReachabilityFromClasses(jarRoot *os.Root, mainClasses []string, optionalRootClasses []string) (*ReachabilityResult, error) { 89 roots := make([]*ClassFile, 0, len(mainClasses)+len(optionalRootClasses)) 90 for _, mainClass := range mainClasses { 91 cf, err := r.findClass(jarRoot, r.ClassPaths, mainClass) 92 if err != nil { 93 return nil, fmt.Errorf("failed to find main class %s: %w", mainClass, err) 94 } 95 roots = append(roots, cf) 96 } 97 98 // optionalRootClasses include those from META-INF/services. They might not exist in the Jar. 99 for _, serviceClass := range optionalRootClasses { 100 cf, err := r.findClass(jarRoot, r.ClassPaths, serviceClass) 101 if err != nil { 102 continue 103 } 104 roots = append(roots, cf) 105 } 106 107 return r.EnumerateReachability(jarRoot, roots) 108 } 109 110 // EnumerateReachability enumerates the reachable classes from a set of root 111 // classes. 112 func (r *ReachabilityEnumerator) EnumerateReachability(jarRoot *os.Root, roots []*ClassFile) (*ReachabilityResult, error) { 113 seen := map[string]struct{}{} 114 codeLoading := map[string]struct{}{} 115 depInjection := map[string]struct{}{} 116 for _, root := range roots { 117 if err := r.enumerateReachability(jarRoot, root, seen, codeLoading, depInjection); err != nil { 118 return nil, err 119 } 120 } 121 122 return &ReachabilityResult{ 123 Classes: slices.Collect(maps.Keys(seen)), 124 UsesDynamicCodeLoading: slices.Collect(maps.Keys(codeLoading)), 125 UsesDependencyInjection: slices.Collect(maps.Keys(depInjection)), 126 }, nil 127 } 128 129 // findClassInJAR finds the relevant parsed .class file from a .jar. 130 func (r *ReachabilityEnumerator) findClassInJAR(jarRoot *os.Root, jarPath string, className string) (*ClassFile, error) { 131 if _, ok := r.loadedJARs[jarPath]; !ok { 132 // Repeatedly opening zip files is very slow, so cache the opened JARs. 133 f, err := jarRoot.Open(jarPath) 134 if err != nil { 135 return nil, err 136 } 137 stat, err := f.Stat() 138 if err != nil { 139 return nil, err 140 } 141 142 zipr, err := zip.NewReader(f, stat.Size()) 143 if err != nil { 144 return nil, err 145 } 146 r.loadedJARs[jarPath] = zipr 147 } 148 149 zipr := r.loadedJARs[jarPath] 150 class, err := zipr.Open(className + ".class") 151 if err != nil { 152 if errors.Is(err, fs.ErrNotExist) { 153 // class not found in this .jar. not an error. 154 return nil, errClassNotFound 155 } 156 157 return nil, err 158 } 159 160 return ParseClass(class) 161 } 162 163 // findClass finds the relevant parsed .class file from a list of classpaths. 164 func (r *ReachabilityEnumerator) findClass(jarRoot *os.Root, classPaths []string, className string) (*ClassFile, error) { 165 // TODO(#787): Support META-INF/versions (multi release JARs) if necessary. 166 167 // Remove generics from the class name. 168 genericRE := regexp.MustCompile(`<.*>`) 169 className = genericRE.ReplaceAllString(className, "") 170 171 // Handle inner class names. The class filename will have a "$" in place of the ".". 172 className = strings.ReplaceAll(className, ".", "$") 173 174 for _, classPath := range classPaths { 175 if strings.HasSuffix(classPath, ".jar") { 176 cf, err := r.findClassInJAR(jarRoot, classPath, className) 177 if err != nil && !errors.Is(err, errClassNotFound) { 178 return nil, err 179 } 180 181 if cf != nil { 182 log.Debug("found class in nested .jar", "class", className, "path", classPath) 183 184 return cf, nil 185 } 186 187 continue 188 } 189 190 // Look inside the class directory. 191 classFilepath := path.Join(classPath, className) 192 193 if !strings.HasSuffix(classFilepath, ".class") { 194 classFilepath += ".class" 195 } 196 197 if _, err := jarRoot.Stat(classFilepath); errors.Is(err, fs.ErrNotExist) { 198 // Class not found in this directory. Move onto the next classpath. 199 continue 200 } 201 202 classFile, err := jarRoot.Open(classFilepath) 203 if err != nil { 204 return nil, err 205 } 206 cf, err := ParseClass(classFile) 207 if err != nil { 208 return nil, err 209 } 210 log.Debug("found class in directory", "class", className, "path", classPath) 211 212 return cf, nil 213 } 214 215 return nil, errors.New("class not found") 216 } 217 218 // isDynamicCodeLoading returns whether a method and its descriptor represents a 219 // call to a dynamic code loading method. 220 func isDynamicCodeLoading(method string, descriptor string) bool { 221 // https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/ClassLoader.html#loadClass(java.lang.String) 222 if strings.Contains(method, "loadClass") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") { 223 return true 224 } 225 226 // https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/Class.html#forName(java.lang.String) 227 if strings.Contains(method, "forName") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") { 228 return true 229 } 230 231 return false 232 } 233 234 // isDependencyInjection returns whether a class provides dependency injection functionality. 235 func isDependencyInjection(class string) bool { 236 if strings.HasPrefix(class, "javax/inject") { 237 return true 238 } 239 240 if strings.HasPrefix(class, "org/springframework") { 241 return true 242 } 243 244 if strings.HasPrefix(class, "com/google/inject") { 245 return true 246 } 247 248 if strings.HasPrefix(class, "dagger/") { 249 return true 250 } 251 252 return false 253 } 254 255 // handleDynamicCode handles the enumeration of class reachability when there is 256 // dynamic code loading, taking into account a user specified strategy. 257 func (r *ReachabilityEnumerator) handleDynamicCode(jarRoot *os.Root, q *UniqueQueue[string, *ClassFile], class string, strategy DynamicCodeStrategy) error { 258 if strategy == DontHandleDynamicCode { 259 return nil 260 } 261 262 pkgs, err := r.PackageFinder.Find(class) 263 if err != nil { 264 return err 265 } 266 267 // Assume all classes that belong to the package are reachable. 268 // TODO(#787): Assume all classes that belong to the direct dependencies of the package 269 // are reachable. 270 if strategy&AssumeAllClassesReachable > 0 { 271 for _, pkg := range pkgs { 272 classes, err := r.PackageFinder.Classes(pkg) 273 if err != nil { 274 return err 275 } 276 277 for _, class := range classes { 278 if q.Seen(class) { 279 continue 280 } 281 cf, err := r.findClass(jarRoot, r.ClassPaths, class) 282 if err == nil { 283 log.Debug("assuming all package classes are reachable", "class", class, "pkg", pkg) 284 q.Push(class, cf) 285 } else { 286 log.Debug("failed to find class", "class", class, "from", pkg, "err", err) 287 } 288 } 289 } 290 } 291 292 return nil 293 } 294 295 func (r *ReachabilityEnumerator) enumerateReachability( 296 jarRoot *os.Root, cf *ClassFile, seen map[string]struct{}, codeLoading map[string]struct{}, depInjection map[string]struct{}) error { 297 thisClass, err := cf.ConstantPoolClass(int(cf.ThisClass)) 298 if err != nil { 299 return err 300 } 301 302 q := NewQueue[string, *ClassFile](seen) 303 q.Push(thisClass, cf) 304 305 for !q.Empty() { 306 thisClass, cf := q.Pop() 307 log.Debug("Analyzing", "class", thisClass) 308 309 // Find uses of dynamic code loading. 310 for i, cp := range cf.ConstantPool { 311 if cp.Type() == ConstantKindMethodref { 312 _, method, descriptor, err := cf.ConstantPoolMethodref(i) 313 if err != nil { 314 return err 315 } 316 317 if isDynamicCodeLoading(method, descriptor) { 318 log.Debug("found dynamic class loading", "thisClass", thisClass, "method", method, "descriptor", descriptor) 319 if _, ok := codeLoading[thisClass]; !ok { 320 codeLoading[thisClass] = struct{}{} 321 err := r.handleDynamicCode(jarRoot, q, thisClass, r.CodeLoadingStrategy) 322 if err != nil { 323 log.Debug("failed to handle dynamic code", "thisClass", thisClass, "err", err) 324 } 325 } 326 } 327 } else if cp.Type() == ConstantKindClass { 328 class, err := cf.ConstantPoolClass(i) 329 if err != nil { 330 return err 331 } 332 333 if isDependencyInjection(class) { 334 log.Debug("found dependency injection", "thisClass", thisClass, "injector", class) 335 if _, ok := depInjection[thisClass]; !ok { 336 depInjection[thisClass] = struct{}{} 337 err := r.handleDynamicCode(jarRoot, q, thisClass, r.DependencyInjectionStrategy) 338 if err != nil { 339 log.Debug("failed to handle dynamic code", "thisClass", thisClass, "err", err) 340 } 341 } 342 } 343 } 344 } 345 346 // Enumerate class references. 347 for i, cp := range cf.ConstantPool { 348 if int(cf.ThisClass) == i { 349 // Don't consider this class itself. 350 continue 351 } 352 353 class := "" 354 if cp.Type() == ConstantKindClass { 355 class, err = cf.ConstantPoolClass(i) 356 if err != nil { 357 return err 358 } 359 } else if cp.Type() == ConstantKindUtf8 { 360 // Also check strings for references to classes. 361 val, err := cf.ConstantPoolUtf8(i) 362 if err != nil { 363 continue 364 } 365 366 // Found a string with the `Lpath/to/class;` format. This is 367 // likely a reference to a class. Annotations appear this way. 368 if val != "" && val[0] == 'L' && val[len(val)-1] == ';' { 369 class = val[1 : len(val)-1] 370 } 371 } 372 373 if class == "" { 374 continue 375 } 376 377 // Handle arrays. 378 if len(class) > 0 && class[0] == '[' { 379 // "[" can appear multiple times (nested arrays). 380 class = strings.TrimLeft(class, "[") 381 382 // Array of class type. Extract the class name. 383 if len(class) > 0 && class[0] == 'L' { 384 class = strings.TrimSuffix(class[1:], ";") 385 } else if slices.Contains(BinaryBaseTypes, class) { 386 // Base type (e.g. integer): just ignore this. 387 continue 388 } else { 389 // We don't know what the type is. 390 return fmt.Errorf("unknown class type %s", class) 391 } 392 } 393 394 if IsStdLib(class) { 395 continue 396 } 397 398 log.Debug("found", "dependency", class) 399 if q.Seen(class) { 400 continue 401 } 402 403 depcf, err := r.findClass(jarRoot, r.ClassPaths, class) 404 if err != nil { 405 // Dependencies can be optional, so this is not a fatal error. 406 log.Debug("failed to find class", "class", class, "from", thisClass, "cp idx", i, "error", err) 407 continue 408 } 409 410 q.Push(class, depcf) 411 } 412 } 413 414 return nil 415 }