github.com/google/osv-scalibr@v0.4.1/enricher/reachability/java/reachable.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package java
    16  
    17  import (
    18  	"archive/zip"
    19  	"errors"
    20  	"fmt"
    21  	"io/fs"
    22  	"maps"
    23  	"os"
    24  	"path"
    25  	"regexp"
    26  	"slices"
    27  	"strings"
    28  
    29  	"github.com/google/osv-scalibr/log"
    30  )
    31  
// ReachabilityResult contains the result of a reachability enumeration.
type ReachabilityResult struct {
	// Classes holds the names collected in the enumeration's "seen" set, i.e.
	// the classes the traversal reached.
	Classes []string
	// UsesDynamicCodeLoading holds the names of analyzed classes whose constant
	// pool contains a method reference recognized as dynamic code loading.
	UsesDynamicCodeLoading []string
	// UsesDependencyInjection holds the names of analyzed classes whose constant
	// pool references a class recognized as a dependency injection framework.
	UsesDependencyInjection []string
}
    38  
    39  // DynamicCodeStrategy is a strategy for handling dynamic code loading.
    40  type DynamicCodeStrategy int
    41  
    42  const (
    43  	// DontHandleDynamicCode doesn't do any kind of special handling.
    44  	DontHandleDynamicCode DynamicCodeStrategy = 0
    45  	// AssumeAllDirectDepsReachable assumes that the entirety of all direct dependencies (i.e. all their
    46  	// classes) are fully reachable.
    47  	AssumeAllDirectDepsReachable = 1 << 0
    48  	// AssumeAllClassesReachable assumes that every single class belonging to the current dependency are
    49  	// fully reachable.
    50  	AssumeAllClassesReachable = 1 << 1
    51  )
    52  
    53  const (
    54  	// BootInfClasses contains spring boot specific classes.
    55  	BootInfClasses = "BOOT-INF/classes"
    56  	// MetaInfVersions is a directory that contains multi-release JARs.
    57  	MetaInfVersions = "META-INF/versions"
    58  )
    59  
// errClassNotFound is returned by findClassInJAR when the requested class has
// no entry in the JAR being searched. findClass treats it as non-fatal and
// moves on to the next classpath entry.
var errClassNotFound = errors.New("class not found in jar")
    61  
// ReachabilityEnumerator enumerates the reachable classes from a set of root
// classes.
type ReachabilityEnumerator struct {
	// ClassPaths lists the locations searched for classes, in order. Entries
	// ending in ".jar" are opened as JAR archives; all other entries are treated
	// as directories. Both are resolved relative to the root passed to the
	// enumeration methods.
	ClassPaths []string
	// PackageFinder resolves a class to its packages and lists the classes of a
	// package. It is consulted only when dynamic code handling is enabled.
	PackageFinder MavenPackageFinder
	// CodeLoadingStrategy is applied to classes that reference dynamic code
	// loading methods.
	CodeLoadingStrategy DynamicCodeStrategy
	// DependencyInjectionStrategy is applied to classes that reference
	// dependency injection frameworks.
	DependencyInjectionStrategy DynamicCodeStrategy

	// loadedJARs caches opened JAR readers, keyed by the JAR path, because
	// repeatedly opening zip files is slow.
	loadedJARs map[string]*zip.Reader
}
    72  
    73  // NewReachabilityEnumerator creates a new ReachabilityEnumerator.
    74  func NewReachabilityEnumerator(
    75  	classPaths []string, packageFinder MavenPackageFinder,
    76  	codeLoadingStrategy DynamicCodeStrategy, dependencyInjectionStrategy DynamicCodeStrategy) *ReachabilityEnumerator {
    77  	return &ReachabilityEnumerator{
    78  		ClassPaths:                  classPaths,
    79  		PackageFinder:               packageFinder,
    80  		CodeLoadingStrategy:         codeLoadingStrategy,
    81  		DependencyInjectionStrategy: dependencyInjectionStrategy,
    82  		loadedJARs:                  map[string]*zip.Reader{},
    83  	}
    84  }
    85  
    86  // EnumerateReachabilityFromClasses enumerates the reachable classes from a set of root
    87  // classes.
    88  func (r *ReachabilityEnumerator) EnumerateReachabilityFromClasses(jarRoot *os.Root, mainClasses []string, optionalRootClasses []string) (*ReachabilityResult, error) {
    89  	roots := make([]*ClassFile, 0, len(mainClasses)+len(optionalRootClasses))
    90  	for _, mainClass := range mainClasses {
    91  		cf, err := r.findClass(jarRoot, r.ClassPaths, mainClass)
    92  		if err != nil {
    93  			return nil, fmt.Errorf("failed to find main class %s: %w", mainClass, err)
    94  		}
    95  		roots = append(roots, cf)
    96  	}
    97  
    98  	// optionalRootClasses include those from META-INF/services. They might not exist in the Jar.
    99  	for _, serviceClass := range optionalRootClasses {
   100  		cf, err := r.findClass(jarRoot, r.ClassPaths, serviceClass)
   101  		if err != nil {
   102  			continue
   103  		}
   104  		roots = append(roots, cf)
   105  	}
   106  
   107  	return r.EnumerateReachability(jarRoot, roots)
   108  }
   109  
   110  // EnumerateReachability enumerates the reachable classes from a set of root
   111  // classes.
   112  func (r *ReachabilityEnumerator) EnumerateReachability(jarRoot *os.Root, roots []*ClassFile) (*ReachabilityResult, error) {
   113  	seen := map[string]struct{}{}
   114  	codeLoading := map[string]struct{}{}
   115  	depInjection := map[string]struct{}{}
   116  	for _, root := range roots {
   117  		if err := r.enumerateReachability(jarRoot, root, seen, codeLoading, depInjection); err != nil {
   118  			return nil, err
   119  		}
   120  	}
   121  
   122  	return &ReachabilityResult{
   123  		Classes:                 slices.Collect(maps.Keys(seen)),
   124  		UsesDynamicCodeLoading:  slices.Collect(maps.Keys(codeLoading)),
   125  		UsesDependencyInjection: slices.Collect(maps.Keys(depInjection)),
   126  	}, nil
   127  }
   128  
   129  // findClassInJAR finds the relevant parsed .class file from a .jar.
   130  func (r *ReachabilityEnumerator) findClassInJAR(jarRoot *os.Root, jarPath string, className string) (*ClassFile, error) {
   131  	if _, ok := r.loadedJARs[jarPath]; !ok {
   132  		// Repeatedly opening zip files is very slow, so cache the opened JARs.
   133  		f, err := jarRoot.Open(jarPath)
   134  		if err != nil {
   135  			return nil, err
   136  		}
   137  		stat, err := f.Stat()
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  
   142  		zipr, err := zip.NewReader(f, stat.Size())
   143  		if err != nil {
   144  			return nil, err
   145  		}
   146  		r.loadedJARs[jarPath] = zipr
   147  	}
   148  
   149  	zipr := r.loadedJARs[jarPath]
   150  	class, err := zipr.Open(className + ".class")
   151  	if err != nil {
   152  		if errors.Is(err, fs.ErrNotExist) {
   153  			// class not found in this .jar. not an error.
   154  			return nil, errClassNotFound
   155  		}
   156  
   157  		return nil, err
   158  	}
   159  
   160  	return ParseClass(class)
   161  }
   162  
   163  // findClass finds the relevant parsed .class file from a list of classpaths.
   164  func (r *ReachabilityEnumerator) findClass(jarRoot *os.Root, classPaths []string, className string) (*ClassFile, error) {
   165  	// TODO(#787): Support META-INF/versions (multi release JARs) if necessary.
   166  
   167  	// Remove generics from the class name.
   168  	genericRE := regexp.MustCompile(`<.*>`)
   169  	className = genericRE.ReplaceAllString(className, "")
   170  
   171  	// Handle inner class names. The class filename will have a "$" in place of the ".".
   172  	className = strings.ReplaceAll(className, ".", "$")
   173  
   174  	for _, classPath := range classPaths {
   175  		if strings.HasSuffix(classPath, ".jar") {
   176  			cf, err := r.findClassInJAR(jarRoot, classPath, className)
   177  			if err != nil && !errors.Is(err, errClassNotFound) {
   178  				return nil, err
   179  			}
   180  
   181  			if cf != nil {
   182  				log.Debug("found class in nested .jar", "class", className, "path", classPath)
   183  
   184  				return cf, nil
   185  			}
   186  
   187  			continue
   188  		}
   189  
   190  		// Look inside the class directory.
   191  		classFilepath := path.Join(classPath, className)
   192  
   193  		if !strings.HasSuffix(classFilepath, ".class") {
   194  			classFilepath += ".class"
   195  		}
   196  
   197  		if _, err := jarRoot.Stat(classFilepath); errors.Is(err, fs.ErrNotExist) {
   198  			// Class not found in this directory. Move onto the next classpath.
   199  			continue
   200  		}
   201  
   202  		classFile, err := jarRoot.Open(classFilepath)
   203  		if err != nil {
   204  			return nil, err
   205  		}
   206  		cf, err := ParseClass(classFile)
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  		log.Debug("found class in directory", "class", className, "path", classPath)
   211  
   212  		return cf, nil
   213  	}
   214  
   215  	return nil, errors.New("class not found")
   216  }
   217  
   218  // isDynamicCodeLoading returns whether a method and its descriptor represents a
   219  // call to a dynamic code loading method.
   220  func isDynamicCodeLoading(method string, descriptor string) bool {
   221  	// https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/ClassLoader.html#loadClass(java.lang.String)
   222  	if strings.Contains(method, "loadClass") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") {
   223  		return true
   224  	}
   225  
   226  	// https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/Class.html#forName(java.lang.String)
   227  	if strings.Contains(method, "forName") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") {
   228  		return true
   229  	}
   230  
   231  	return false
   232  }
   233  
   234  // isDependencyInjection returns whether a class provides dependency injection functionality.
   235  func isDependencyInjection(class string) bool {
   236  	if strings.HasPrefix(class, "javax/inject") {
   237  		return true
   238  	}
   239  
   240  	if strings.HasPrefix(class, "org/springframework") {
   241  		return true
   242  	}
   243  
   244  	if strings.HasPrefix(class, "com/google/inject") {
   245  		return true
   246  	}
   247  
   248  	if strings.HasPrefix(class, "dagger/") {
   249  		return true
   250  	}
   251  
   252  	return false
   253  }
   254  
   255  // handleDynamicCode handles the enumeration of class reachability when there is
   256  // dynamic code loading, taking into account a user specified strategy.
   257  func (r *ReachabilityEnumerator) handleDynamicCode(jarRoot *os.Root, q *UniqueQueue[string, *ClassFile], class string, strategy DynamicCodeStrategy) error {
   258  	if strategy == DontHandleDynamicCode {
   259  		return nil
   260  	}
   261  
   262  	pkgs, err := r.PackageFinder.Find(class)
   263  	if err != nil {
   264  		return err
   265  	}
   266  
   267  	// Assume all classes that belong to the package are reachable.
   268  	// TODO(#787): Assume all classes that belong to the direct dependencies of the package
   269  	// are reachable.
   270  	if strategy&AssumeAllClassesReachable > 0 {
   271  		for _, pkg := range pkgs {
   272  			classes, err := r.PackageFinder.Classes(pkg)
   273  			if err != nil {
   274  				return err
   275  			}
   276  
   277  			for _, class := range classes {
   278  				if q.Seen(class) {
   279  					continue
   280  				}
   281  				cf, err := r.findClass(jarRoot, r.ClassPaths, class)
   282  				if err == nil {
   283  					log.Debug("assuming all package classes are reachable", "class", class, "pkg", pkg)
   284  					q.Push(class, cf)
   285  				} else {
   286  					log.Debug("failed to find class", "class", class, "from", pkg, "err", err)
   287  				}
   288  			}
   289  		}
   290  	}
   291  
   292  	return nil
   293  }
   294  
   295  func (r *ReachabilityEnumerator) enumerateReachability(
   296  	jarRoot *os.Root, cf *ClassFile, seen map[string]struct{}, codeLoading map[string]struct{}, depInjection map[string]struct{}) error {
   297  	thisClass, err := cf.ConstantPoolClass(int(cf.ThisClass))
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	q := NewQueue[string, *ClassFile](seen)
   303  	q.Push(thisClass, cf)
   304  
   305  	for !q.Empty() {
   306  		thisClass, cf := q.Pop()
   307  		log.Debug("Analyzing", "class", thisClass)
   308  
   309  		// Find uses of dynamic code loading.
   310  		for i, cp := range cf.ConstantPool {
   311  			if cp.Type() == ConstantKindMethodref {
   312  				_, method, descriptor, err := cf.ConstantPoolMethodref(i)
   313  				if err != nil {
   314  					return err
   315  				}
   316  
   317  				if isDynamicCodeLoading(method, descriptor) {
   318  					log.Debug("found dynamic class loading", "thisClass", thisClass, "method", method, "descriptor", descriptor)
   319  					if _, ok := codeLoading[thisClass]; !ok {
   320  						codeLoading[thisClass] = struct{}{}
   321  						err := r.handleDynamicCode(jarRoot, q, thisClass, r.CodeLoadingStrategy)
   322  						if err != nil {
   323  							log.Debug("failed to handle dynamic code", "thisClass", thisClass, "err", err)
   324  						}
   325  					}
   326  				}
   327  			} else if cp.Type() == ConstantKindClass {
   328  				class, err := cf.ConstantPoolClass(i)
   329  				if err != nil {
   330  					return err
   331  				}
   332  
   333  				if isDependencyInjection(class) {
   334  					log.Debug("found dependency injection", "thisClass", thisClass, "injector", class)
   335  					if _, ok := depInjection[thisClass]; !ok {
   336  						depInjection[thisClass] = struct{}{}
   337  						err := r.handleDynamicCode(jarRoot, q, thisClass, r.DependencyInjectionStrategy)
   338  						if err != nil {
   339  							log.Debug("failed to handle dynamic code", "thisClass", thisClass, "err", err)
   340  						}
   341  					}
   342  				}
   343  			}
   344  		}
   345  
   346  		// Enumerate class references.
   347  		for i, cp := range cf.ConstantPool {
   348  			if int(cf.ThisClass) == i {
   349  				// Don't consider this class itself.
   350  				continue
   351  			}
   352  
   353  			class := ""
   354  			if cp.Type() == ConstantKindClass {
   355  				class, err = cf.ConstantPoolClass(i)
   356  				if err != nil {
   357  					return err
   358  				}
   359  			} else if cp.Type() == ConstantKindUtf8 {
   360  				// Also check strings for references to classes.
   361  				val, err := cf.ConstantPoolUtf8(i)
   362  				if err != nil {
   363  					continue
   364  				}
   365  
   366  				// Found a string with the `Lpath/to/class;` format. This is
   367  				// likely a reference to a class. Annotations appear this way.
   368  				if val != "" && val[0] == 'L' && val[len(val)-1] == ';' {
   369  					class = val[1 : len(val)-1]
   370  				}
   371  			}
   372  
   373  			if class == "" {
   374  				continue
   375  			}
   376  
   377  			// Handle arrays.
   378  			if len(class) > 0 && class[0] == '[' {
   379  				// "[" can appear multiple times (nested arrays).
   380  				class = strings.TrimLeft(class, "[")
   381  
   382  				// Array of class type. Extract the class name.
   383  				if len(class) > 0 && class[0] == 'L' {
   384  					class = strings.TrimSuffix(class[1:], ";")
   385  				} else if slices.Contains(BinaryBaseTypes, class) {
   386  					// Base type (e.g. integer): just ignore this.
   387  					continue
   388  				} else {
   389  					// We don't know what the type is.
   390  					return fmt.Errorf("unknown class type %s", class)
   391  				}
   392  			}
   393  
   394  			if IsStdLib(class) {
   395  				continue
   396  			}
   397  
   398  			log.Debug("found", "dependency", class)
   399  			if q.Seen(class) {
   400  				continue
   401  			}
   402  
   403  			depcf, err := r.findClass(jarRoot, r.ClassPaths, class)
   404  			if err != nil {
   405  				// Dependencies can be optional, so this is not a fatal error.
   406  				log.Debug("failed to find class", "class", class, "from", thisClass, "cp idx", i, "error", err)
   407  				continue
   408  			}
   409  
   410  			q.Push(class, depcf)
   411  		}
   412  	}
   413  
   414  	return nil
   415  }