github.com/google/osv-scalibr@v0.4.1/guidedremediation/internal/manifest/python/pipfile.go (about) 1 // Copyright 2025 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package python 16 17 import ( 18 "fmt" 19 "io" 20 "path/filepath" 21 "strings" 22 23 "deps.dev/util/resolve" 24 "deps.dev/util/resolve/dep" 25 "github.com/BurntSushi/toml" 26 scalibrfs "github.com/google/osv-scalibr/fs" 27 "github.com/google/osv-scalibr/guidedremediation/internal/manifest" 28 "github.com/google/osv-scalibr/guidedremediation/result" 29 "github.com/google/osv-scalibr/guidedremediation/strategy" 30 "github.com/google/osv-scalibr/log" 31 ) 32 33 // Pipfile is a struct that represents the contents of a Pipfile. 34 type Pipfile struct { 35 Packages map[string]any `toml:"packages"` 36 DevPackages map[string]any `toml:"dev-packages"` 37 } 38 39 type pipfileReadWriter struct{} 40 41 // GetPipfileReadWriter returns a ReadWriter for Pipfile manifest files. 42 func GetPipfileReadWriter() (manifest.ReadWriter, error) { 43 return pipfileReadWriter{}, nil 44 } 45 46 // System returns the ecosystem of this ReadWriter. 47 func (r pipfileReadWriter) System() resolve.System { 48 return resolve.PyPI 49 } 50 51 // SupportedStrategies returns the remediation strategies supported for this manifest. 52 func (r pipfileReadWriter) SupportedStrategies() []strategy.Strategy { 53 return []strategy.Strategy{strategy.StrategyRelax} 54 } 55 56 // Read parses the manifest from the given file, preserving the order of dependencies. 57 func (r pipfileReadWriter) Read(path string, fsys scalibrfs.FS) (manifest.Manifest, error) { 58 path = filepath.ToSlash(path) 59 f, err := fsys.Open(path) 60 if err != nil { 61 return nil, err 62 } 63 defer f.Close() 64 65 var pipfile Pipfile 66 md, err := toml.NewDecoder(f).Decode(&pipfile) 67 if err != nil { 68 return nil, fmt.Errorf("failed to unmarshal Pipfile: %w", err) 69 } 70 71 var packageKeys []string 72 var devPackageKeys []string 73 for _, key := range md.Keys() { 74 // A key for a dependency will have 2 parts: `[section, name]` 75 if len(key) == 2 { 76 switch key[0] { 77 case "packages": 78 packageKeys = append(packageKeys, key[1]) 79 case "dev-packages": 80 devPackageKeys = append(devPackageKeys, key[1]) 81 } 82 } 83 } 84 85 allReqs := []resolve.RequirementVersion{} 86 groups := make(map[manifest.RequirementKey][]string) 87 88 // Packages 89 pkgReqs := parsePipfileDependencies(pipfile.Packages, packageKeys, false) 90 allReqs = append(allReqs, pkgReqs...) 91 92 // Dev packages 93 devPkgReqs := parsePipfileDependencies(pipfile.DevPackages, devPackageKeys, true) 94 allReqs = append(allReqs, devPkgReqs...) 95 for _, r := range devPkgReqs { 96 key := manifest.RequirementKey(r.PackageKey) 97 groups[key] = append(groups[key], "dev") 98 } 99 100 return &pythonManifest{ 101 filePath: path, 102 root: resolve.Version{ 103 VersionKey: resolve.VersionKey{ 104 PackageKey: resolve.PackageKey{ 105 System: resolve.PyPI, 106 Name: "rootproject", // Pipfile doesn't have a project name 107 }, 108 VersionType: resolve.Concrete, 109 Version: "1.0.0", 110 }, 111 }, 112 requirements: allReqs, 113 groups: groups, 114 }, nil 115 } 116 117 // parsePipfileDependencies converts a map of dependencies from a Pipfile's [packages] or 118 // [dev-packages] section into a slice of resolve.RequirementVersion, respecting the original key order. 119 func parsePipfileDependencies(deps map[string]any, keys []string, dev bool) []resolve.RequirementVersion { 120 var reqs []resolve.RequirementVersion 121 if deps == nil { 122 return reqs 123 } 124 125 var dt dep.Type 126 if dev { 127 dt.AddAttr(dep.Dev, "") 128 } 129 for _, name := range keys { 130 details, ok := deps[name] 131 if !ok { 132 continue // Should not happen if keys are from metadata 133 } 134 if constraint, ok := extractVersionConstraint(name, details); ok { 135 reqs = append(reqs, resolve.RequirementVersion{ 136 VersionKey: resolve.VersionKey{ 137 PackageKey: resolve.PackageKey{ 138 System: resolve.PyPI, 139 Name: name, 140 }, 141 Version: constraint, 142 VersionType: resolve.Requirement, 143 }, 144 Type: dt, 145 }) 146 } 147 } 148 return reqs 149 } 150 151 // extractVersionConstraint parses a single dependency entry from a Pipfile. 152 // It returns the version constraint string and a boolean indicating if parsing was successful. 153 // It skips over non-version dependencies like git or path references, returning false in those cases. 154 func extractVersionConstraint(name string, details any) (string, bool) { 155 switch v := details.(type) { 156 case string: 157 return v, true 158 case map[string]any: 159 if vs, ok := v["version"].(string); ok { 160 return vs, true 161 } else if _, ok := v["git"]; ok { 162 log.Infof("Skipping git dependency in Pipfile for package %q", name) 163 return "", false 164 } else if _, ok := v["path"]; ok { 165 log.Infof("Skipping path dependency in Pipfile for package %q", name) 166 return "", false 167 } 168 default: 169 log.Warnf("unsupported dependency format in Pipfile for package %q", name) 170 return "", false 171 } 172 173 return "", false 174 } 175 176 // Write writes the manifest after applying the patches to outputPath. 177 func (r pipfileReadWriter) Write(original manifest.Manifest, fsys scalibrfs.FS, patches []result.Patch, outputPath string) error { 178 return write(fsys, original.FilePath(), outputPath, patches, updatePipfile) 179 } 180 181 // updatePipfile takes an io.Reader representing the Pipfile 182 // and a map of package names to their new version constraints, returns the 183 // file with the updated requirements as a string. 184 func updatePipfile(reader io.Reader, requirements []TokenizedRequirements) (string, error) { 185 data, err := io.ReadAll(reader) 186 if err != nil { 187 return "", fmt.Errorf("error reading requirements: %w", err) 188 } 189 content := string(data) 190 191 var pipfile Pipfile 192 if _, err := toml.Decode(content, &pipfile); err != nil { 193 return "", fmt.Errorf("failed to unmarshal Pipfile: %w", err) 194 } 195 196 names := make(map[string]bool, len(requirements)) 197 for _, req := range requirements { 198 names[req.Name] = true 199 } 200 201 var sb strings.Builder 202 for _, line := range strings.SplitAfter(content, "\n") { 203 name, ok := dependencyToUpdate(line, names) 204 if !ok { 205 // This line is not a dependency requirement. 206 sb.WriteString(line) 207 continue 208 } 209 210 detail, ok := pipfile.Packages[name] 211 if !ok { 212 detail, ok = pipfile.DevPackages[name] 213 } 214 if !ok { 215 // Not a dependency found in packages or dev-packages. 216 sb.WriteString(line) 217 continue 218 } 219 220 oldVersion, ok := extractVersionConstraint(name, detail) 221 if !ok { 222 // We cannot parse this dependency requirement. 223 sb.WriteString(line) 224 continue 225 } 226 newReq, ok := findTokenizedRequirement(requirements, name, tokenizeRequirement(oldVersion)) 227 if !ok { 228 // We cannot find the new requirement. 229 sb.WriteString(line) 230 continue 231 } 232 233 newLine := strings.Replace(line, "\""+oldVersion+"\"", "\""+formatConstraints(newReq, false)+"\"", 1) 234 sb.WriteString(newLine) 235 } 236 237 return sb.String(), nil 238 } 239 240 // dependencyToUpdate checks if the given line contains a dependency that needs to be updated. 241 // It returns the name of the dependency and true if it needs to be updated, otherwise false. 242 func dependencyToUpdate(line string, names map[string]bool) (string, bool) { 243 trimmedLine := strings.TrimSpace(line) 244 if trimmedLine == "" { 245 return "", false 246 } 247 if strings.HasPrefix(trimmedLine, "[") || strings.HasPrefix(trimmedLine, "#") { 248 return "", false 249 } 250 parts := strings.SplitN(trimmedLine, "=", 2) 251 if len(parts) < 2 { 252 return "", false 253 } 254 name := strings.TrimSpace(parts[0]) 255 return name, names[name] 256 }