github.com/openimsdk/tools@v0.0.49/utils/mageutil/gen_protocol.go (about)

     1  package mageutil
     2  
     3  import (
     4  	"archive/zip"
     5  	"bufio"
     6  	"fmt"
     7  	"github.com/magefile/mage/sh"
     8  	"io"
     9  	"net/http"
    10  	"os"
    11  	"os/exec"
    12  	"path/filepath"
    13  	"runtime"
    14  	"strings"
    15  )
    16  
    17  func ensureToolsInstalled() error {
    18  	tools := map[string]string{
    19  		"protoc-gen-go": "https://github.com/golang/protobuf/tree/master/protoc-gen-go@latest",
    20  	}
    21  
    22  	// Setting GOBIN based on OS, Windows needs a different default path
    23  	var targetDir string
    24  	if runtime.GOOS == "windows" {
    25  		targetDir = filepath.Join(os.Getenv("USERPROFILE"), "go", "bin")
    26  	} else {
    27  		targetDir = "/usr/local/bin"
    28  	}
    29  
    30  	os.Setenv("GOBIN", targetDir)
    31  
    32  	for tool, path := range tools {
    33  		if _, err := exec.LookPath(filepath.Join(targetDir, tool)); err != nil {
    34  			fmt.Printf("Installing %s to %s...\n", tool, targetDir)
    35  			if err := sh.Run("go", "install", path); err != nil {
    36  				return fmt.Errorf("failed to install %s: %s", tool, err)
    37  			}
    38  		} else {
    39  			fmt.Printf("%s is already installed in %s.\n", tool, targetDir)
    40  		}
    41  	}
    42  
    43  	if _, err := exec.LookPath(filepath.Join(targetDir, "protoc")); err == nil {
    44  		fmt.Println("protoc is already installed.")
    45  		return nil
    46  	}
    47  
    48  	fmt.Println("Installing protoc...")
    49  	return installProtoc(targetDir)
    50  }
    51  
    52  func installProtoc(installDir string) error {
    53  	version := "26.1"
    54  	baseURL := "https://github.com/protocolbuffers/protobuf/releases/download/v" + version
    55  	archMap := map[string]string{
    56  		"amd64": "x86_64",
    57  		"386":   "x86",
    58  		"arm64": "aarch64",
    59  	}
    60  	protocFile := "protoc-%s-%s.zip"
    61  
    62  	osArch := runtime.GOOS + "-" + getProtocArch(archMap, runtime.GOARCH)
    63  	if runtime.GOOS == "windows" {
    64  		osArch = "win64" // assuming 64-bit, for 32-bit use "win32"
    65  	}
    66  	fileName := fmt.Sprintf(protocFile, version, osArch)
    67  	url := baseURL + "/" + fileName
    68  
    69  	fmt.Println("URL:", url)
    70  
    71  	resp, err := http.Get(url)
    72  	if err != nil {
    73  		return err
    74  	}
    75  	defer resp.Body.Close()
    76  
    77  	// Create a temporary file
    78  	tmpFile, err := os.CreateTemp("", "protoc-*.zip")
    79  	if err != nil {
    80  		return err
    81  	}
    82  	defer tmpFile.Close()
    83  
    84  	_, err = io.Copy(tmpFile, resp.Body)
    85  	if err != nil {
    86  		return err
    87  	}
    88  	fmt.Println("tmp ", tmpFile.Name(), "install  ", installDir)
    89  	return unzip(tmpFile.Name(), installDir)
    90  }
    91  
    92  func unzip(src, dest string) error {
    93  	r, err := zip.OpenReader(src)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer r.Close()
    98  
    99  	for _, f := range r.File {
   100  		fpath := filepath.Join(dest, f.Name)
   101  		if f.FileInfo().IsDir() {
   102  			os.MkdirAll(fpath, os.ModePerm)
   103  			continue
   104  		}
   105  
   106  		outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
   107  		if err != nil {
   108  			return err
   109  		}
   110  
   111  		rc, err := f.Open()
   112  		if err != nil {
   113  			outFile.Close()
   114  			return err
   115  		}
   116  
   117  		_, err = io.Copy(outFile, rc)
   118  		outFile.Close()
   119  		rc.Close()
   120  		if err != nil {
   121  			return err
   122  		}
   123  	}
   124  	return nil
   125  }
   126  
   127  func getProtocArch(archMap map[string]string, goArch string) string {
   128  	if arch, ok := archMap[goArch]; ok {
   129  		return arch
   130  	}
   131  	return goArch
   132  }
   133  
   134  func Protocol() error {
   135  	if err := ensureToolsInstalled(); err != nil {
   136  		fmt.Println("error ", err.Error())
   137  		os.Exit(1)
   138  	}
   139  
   140  	moduleName, err := getModuleNameFromGoMod()
   141  	if err != nil {
   142  		fmt.Println("error fetching module name from go.mod: ", err.Error())
   143  		os.Exit(1)
   144  	}
   145  
   146  	protoPath := "./pkg/protocol"
   147  	dirs, err := os.ReadDir(protoPath)
   148  	if err != nil {
   149  		fmt.Println("error ", err.Error())
   150  		os.Exit(1)
   151  	}
   152  
   153  	for _, dir := range dirs {
   154  		if dir.IsDir() {
   155  			if err := compileProtoFiles(protoPath, dir.Name(), moduleName); err != nil {
   156  				fmt.Println("error ", err.Error())
   157  				os.Exit(1)
   158  			}
   159  		}
   160  	}
   161  	return nil
   162  }
   163  func compileProtoFiles(basePath, dirName, moduleName string) error {
   164  	protoFile := filepath.Join(basePath, dirName, dirName+".proto")
   165  	outputDir := filepath.Join(basePath, dirName)
   166  	module := moduleName + "/pkg/protocol/" + dirName
   167  	args := []string{
   168  		"--go_out=plugins=grpc:" + outputDir,
   169  		"--go_opt=module=" + module,
   170  		protoFile,
   171  	}
   172  	fmt.Printf("Compiling %s...\n", protoFile)
   173  	if err := sh.Run("protoc", args...); err != nil {
   174  		return fmt.Errorf("failed to compile %s: %s", protoFile, err)
   175  	}
   176  	return fixOmitemptyInDirectory(outputDir)
   177  }
   178  
   179  func fixOmitemptyInDirectory(dir string) error {
   180  	files, err := filepath.Glob(filepath.Join(dir, "*.pb.go"))
   181  	if err != nil {
   182  		return fmt.Errorf("failed to list .pb.go files in %s: %s", dir, err)
   183  	}
   184  	fmt.Printf("Fixing omitempty in dir  %s...\n", dir)
   185  	for _, file := range files {
   186  		fmt.Printf("Fixing omitempty in %s...\n", file)
   187  		if err := RemoveOmitemptyFromFile(file); err != nil {
   188  			return fmt.Errorf("failed to replace omitempty in %s: %s", file, err)
   189  		}
   190  	}
   191  	return nil
   192  }
   193  
   194  func RemoveOmitemptyFromFile(filePath string) error {
   195  	file, err := os.Open(filePath)
   196  	if err != nil {
   197  		return fmt.Errorf("error opening file: %s", err)
   198  	}
   199  	defer file.Close()
   200  
   201  	var lines []string
   202  	scanner := bufio.NewScanner(file)
   203  	for scanner.Scan() {
   204  		line := scanner.Text()
   205  		line = strings.ReplaceAll(line, ",omitempty", "")
   206  		lines = append(lines, line)
   207  	}
   208  	if err := scanner.Err(); err != nil {
   209  		return fmt.Errorf("error reading file: %s", err)
   210  	}
   211  
   212  	return writeLines(lines, filePath)
   213  }
   214  
   215  // writeLines writes the lines to the given file.
   216  func writeLines(lines []string, path string) error {
   217  	file, err := os.Create(path)
   218  	if err != nil {
   219  		return fmt.Errorf("error creating file: %s", err)
   220  	}
   221  	defer file.Close()
   222  
   223  	w := bufio.NewWriter(file)
   224  	for _, line := range lines {
   225  		if _, err := fmt.Fprintln(w, line); err != nil {
   226  			return fmt.Errorf("error writing to file: %s", err)
   227  		}
   228  	}
   229  	return w.Flush()
   230  }
   231  
   232  // getModuleNameFromGoMod extracts the module name from go.mod file.
   233  func getModuleNameFromGoMod() (string, error) {
   234  	file, err := os.Open("go.mod")
   235  	if err != nil {
   236  		return "", fmt.Errorf("failed to open go.mod: %v", err)
   237  	}
   238  	defer file.Close()
   239  
   240  	scanner := bufio.NewScanner(file)
   241  	for scanner.Scan() {
   242  		line := scanner.Text()
   243  		if strings.HasPrefix(line, "module ") {
   244  			// Assuming line looks like "module github.com/user/repo"
   245  			return strings.TrimSpace(strings.TrimPrefix(line, "module")), nil
   246  		}
   247  	}
   248  
   249  	if err := scanner.Err(); err != nil {
   250  		return "", fmt.Errorf("error reading go.mod: %v", err)
   251  	}
   252  
   253  	return "", fmt.Errorf("module directive not found in go.mod")
   254  }