github.com/openimsdk/tools@v0.0.49/utils/mageutil/gen_protocol.go (about) 1 package mageutil 2 3 import ( 4 "archive/zip" 5 "bufio" 6 "fmt" 7 "github.com/magefile/mage/sh" 8 "io" 9 "net/http" 10 "os" 11 "os/exec" 12 "path/filepath" 13 "runtime" 14 "strings" 15 ) 16 17 func ensureToolsInstalled() error { 18 tools := map[string]string{ 19 "protoc-gen-go": "https://github.com/golang/protobuf/tree/master/protoc-gen-go@latest", 20 } 21 22 // Setting GOBIN based on OS, Windows needs a different default path 23 var targetDir string 24 if runtime.GOOS == "windows" { 25 targetDir = filepath.Join(os.Getenv("USERPROFILE"), "go", "bin") 26 } else { 27 targetDir = "/usr/local/bin" 28 } 29 30 os.Setenv("GOBIN", targetDir) 31 32 for tool, path := range tools { 33 if _, err := exec.LookPath(filepath.Join(targetDir, tool)); err != nil { 34 fmt.Printf("Installing %s to %s...\n", tool, targetDir) 35 if err := sh.Run("go", "install", path); err != nil { 36 return fmt.Errorf("failed to install %s: %s", tool, err) 37 } 38 } else { 39 fmt.Printf("%s is already installed in %s.\n", tool, targetDir) 40 } 41 } 42 43 if _, err := exec.LookPath(filepath.Join(targetDir, "protoc")); err == nil { 44 fmt.Println("protoc is already installed.") 45 return nil 46 } 47 48 fmt.Println("Installing protoc...") 49 return installProtoc(targetDir) 50 } 51 52 func installProtoc(installDir string) error { 53 version := "26.1" 54 baseURL := "https://github.com/protocolbuffers/protobuf/releases/download/v" + version 55 archMap := map[string]string{ 56 "amd64": "x86_64", 57 "386": "x86", 58 "arm64": "aarch64", 59 } 60 protocFile := "protoc-%s-%s.zip" 61 62 osArch := runtime.GOOS + "-" + getProtocArch(archMap, runtime.GOARCH) 63 if runtime.GOOS == "windows" { 64 osArch = "win64" // assuming 64-bit, for 32-bit use "win32" 65 } 66 fileName := fmt.Sprintf(protocFile, version, osArch) 67 url := baseURL + "/" + fileName 68 69 fmt.Println("URL:", url) 70 71 resp, err := http.Get(url) 72 if err != nil { 73 return err 
74 } 75 defer resp.Body.Close() 76 77 // Create a temporary file 78 tmpFile, err := os.CreateTemp("", "protoc-*.zip") 79 if err != nil { 80 return err 81 } 82 defer tmpFile.Close() 83 84 _, err = io.Copy(tmpFile, resp.Body) 85 if err != nil { 86 return err 87 } 88 fmt.Println("tmp ", tmpFile.Name(), "install ", installDir) 89 return unzip(tmpFile.Name(), installDir) 90 } 91 92 func unzip(src, dest string) error { 93 r, err := zip.OpenReader(src) 94 if err != nil { 95 return err 96 } 97 defer r.Close() 98 99 for _, f := range r.File { 100 fpath := filepath.Join(dest, f.Name) 101 if f.FileInfo().IsDir() { 102 os.MkdirAll(fpath, os.ModePerm) 103 continue 104 } 105 106 outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) 107 if err != nil { 108 return err 109 } 110 111 rc, err := f.Open() 112 if err != nil { 113 outFile.Close() 114 return err 115 } 116 117 _, err = io.Copy(outFile, rc) 118 outFile.Close() 119 rc.Close() 120 if err != nil { 121 return err 122 } 123 } 124 return nil 125 } 126 127 func getProtocArch(archMap map[string]string, goArch string) string { 128 if arch, ok := archMap[goArch]; ok { 129 return arch 130 } 131 return goArch 132 } 133 134 func Protocol() error { 135 if err := ensureToolsInstalled(); err != nil { 136 fmt.Println("error ", err.Error()) 137 os.Exit(1) 138 } 139 140 moduleName, err := getModuleNameFromGoMod() 141 if err != nil { 142 fmt.Println("error fetching module name from go.mod: ", err.Error()) 143 os.Exit(1) 144 } 145 146 protoPath := "./pkg/protocol" 147 dirs, err := os.ReadDir(protoPath) 148 if err != nil { 149 fmt.Println("error ", err.Error()) 150 os.Exit(1) 151 } 152 153 for _, dir := range dirs { 154 if dir.IsDir() { 155 if err := compileProtoFiles(protoPath, dir.Name(), moduleName); err != nil { 156 fmt.Println("error ", err.Error()) 157 os.Exit(1) 158 } 159 } 160 } 161 return nil 162 } 163 func compileProtoFiles(basePath, dirName, moduleName string) error { 164 protoFile := 
filepath.Join(basePath, dirName, dirName+".proto") 165 outputDir := filepath.Join(basePath, dirName) 166 module := moduleName + "/pkg/protocol/" + dirName 167 args := []string{ 168 "--go_out=plugins=grpc:" + outputDir, 169 "--go_opt=module=" + module, 170 protoFile, 171 } 172 fmt.Printf("Compiling %s...\n", protoFile) 173 if err := sh.Run("protoc", args...); err != nil { 174 return fmt.Errorf("failed to compile %s: %s", protoFile, err) 175 } 176 return fixOmitemptyInDirectory(outputDir) 177 } 178 179 func fixOmitemptyInDirectory(dir string) error { 180 files, err := filepath.Glob(filepath.Join(dir, "*.pb.go")) 181 if err != nil { 182 return fmt.Errorf("failed to list .pb.go files in %s: %s", dir, err) 183 } 184 fmt.Printf("Fixing omitempty in dir %s...\n", dir) 185 for _, file := range files { 186 fmt.Printf("Fixing omitempty in %s...\n", file) 187 if err := RemoveOmitemptyFromFile(file); err != nil { 188 return fmt.Errorf("failed to replace omitempty in %s: %s", file, err) 189 } 190 } 191 return nil 192 } 193 194 func RemoveOmitemptyFromFile(filePath string) error { 195 file, err := os.Open(filePath) 196 if err != nil { 197 return fmt.Errorf("error opening file: %s", err) 198 } 199 defer file.Close() 200 201 var lines []string 202 scanner := bufio.NewScanner(file) 203 for scanner.Scan() { 204 line := scanner.Text() 205 line = strings.ReplaceAll(line, ",omitempty", "") 206 lines = append(lines, line) 207 } 208 if err := scanner.Err(); err != nil { 209 return fmt.Errorf("error reading file: %s", err) 210 } 211 212 return writeLines(lines, filePath) 213 } 214 215 // writeLines writes the lines to the given file. 
//
// Existing content of path is replaced (the file is created if needed) and
// every line is followed by "\n". Errors from writing, flushing and closing
// the file are all reported.
func writeLines(lines []string, path string) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("error creating file: %w", err)
	}
	// On a write path the Close error matters: it can report a failed write
	// that was deferred by the OS or filesystem.
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", cerr)
		}
	}()

	w := bufio.NewWriter(file)
	for _, line := range lines {
		if _, err := w.WriteString(line); err != nil {
			return fmt.Errorf("error writing to file: %w", err)
		}
		if err := w.WriteByte('\n'); err != nil {
			return fmt.Errorf("error writing to file: %w", err)
		}
	}
	if err := w.Flush(); err != nil {
		return fmt.Errorf("error writing to file: %w", err)
	}
	return nil
}

// getModuleNameFromGoMod extracts the module path from the go.mod file in the
// current working directory.
//
// The module directive is recognized even when it is indented, separated from
// its argument by tabs, quoted ("..." or `...`), or followed by a // comment.
func getModuleNameFromGoMod() (string, error) {
	file, err := os.Open("go.mod")
	if err != nil {
		return "", fmt.Errorf("failed to open go.mod: %w", err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		// Module paths cannot contain "//", so anything after it is a comment.
		if i := strings.Index(line, "//"); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) == 2 && fields[0] == "module" {
			return strings.Trim(fields[1], "\"`"), nil
		}
	}

	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("error reading go.mod: %w", err)
	}

	return "", fmt.Errorf("module directive not found in go.mod")
}