github.com/cybriq/giocore@v0.0.7-0.20210703034601-cfb9cb5f3900/gpu/internal/convertshaders/main.go (about)

     1  // SPDX-License-Identifier: Unlicense OR MIT
     2  
     3  package main
     4  
     5  import (
     6  	"bytes"
     7  	"errors"
     8  	"flag"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"os"
    13  	"os/exec"
    14  	"path/filepath"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"text/template"
    20  
    21  	"github.com/cybriq/giocore/gpu/internal/driver"
    22  )
    23  
    24  func main() {
    25  	packageName := flag.String("package", "", "specify Go package name")
    26  	workdir := flag.String("work", "", "temporary working directory (default TEMP)")
    27  	shadersDir := flag.String("dir", "shaders", "shaders directory")
    28  	directCompute := flag.Bool("directcompute", false, "enable compiling DirectCompute shaders")
    29  
    30  	flag.Parse()
    31  
    32  	var work WorkDir
    33  	cleanup := func() {}
    34  	if *workdir == "" {
    35  		tempdir, err := ioutil.TempDir("", "shader-convert")
    36  		if err != nil {
    37  			fmt.Fprintf(os.Stderr, "failed to create tempdir: %v\n", err)
    38  			os.Exit(1)
    39  		}
    40  		cleanup = func() { os.RemoveAll(tempdir) }
    41  		defer cleanup()
    42  
    43  		work = WorkDir(tempdir)
    44  	} else {
    45  		if abs, err := filepath.Abs(*workdir); err == nil {
    46  			*workdir = abs
    47  		}
    48  		work = WorkDir(*workdir)
    49  	}
    50  
    51  	var out bytes.Buffer
    52  	conv := NewConverter(work, *packageName, *shadersDir, *directCompute)
    53  	if err := conv.Run(&out); err != nil {
    54  		fmt.Fprintf(os.Stderr, "%v\n", err)
    55  		cleanup()
    56  		os.Exit(1)
    57  	}
    58  
    59  	if err := ioutil.WriteFile("shaders.go", out.Bytes(), 0644); err != nil {
    60  		fmt.Fprintf(os.Stderr, "failed to create shaders: %v\n", err)
    61  		cleanup()
    62  		os.Exit(1)
    63  	}
    64  
    65  	cmd := exec.Command("gofmt", "-s", "-w", "shaders.go")
    66  	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
    67  	if err := cmd.Run(); err != nil {
    68  		fmt.Fprintf(os.Stderr, "formatting shaders.go failed: %v\n", err)
    69  		cleanup()
    70  		os.Exit(1)
    71  	}
    72  }
    73  
    74  type Converter struct {
    75  	workDir       WorkDir
    76  	shadersDir    string
    77  	directCompute bool
    78  
    79  	packageName string
    80  
    81  	glslvalidator *GLSLValidator
    82  	spirv         *SPIRVCross
    83  	fxc           *FXC
    84  }
    85  
    86  func NewConverter(workDir WorkDir, packageName, shadersDir string, directCompute bool) *Converter {
    87  	if abs, err := filepath.Abs(shadersDir); err == nil {
    88  		shadersDir = abs
    89  	}
    90  
    91  	conv := &Converter{}
    92  	conv.workDir = workDir
    93  	conv.shadersDir = shadersDir
    94  	conv.directCompute = directCompute
    95  
    96  	conv.packageName = packageName
    97  
    98  	conv.glslvalidator = NewGLSLValidator()
    99  	conv.spirv = NewSPIRVCross()
   100  	conv.fxc = NewFXC()
   101  
   102  	verifyBinaryPath(&conv.glslvalidator.Bin)
   103  	verifyBinaryPath(&conv.spirv.Bin)
   104  	// We cannot check fxc since it may depend on wine.
   105  
   106  	conv.glslvalidator.WorkDir = workDir.Dir("glslvalidator")
   107  	conv.fxc.WorkDir = workDir.Dir("fxc")
   108  	conv.spirv.WorkDir = workDir.Dir("spirv")
   109  
   110  	return conv
   111  }
   112  
   113  func verifyBinaryPath(bin *string) {
   114  	new, err := exec.LookPath(*bin)
   115  	if err != nil {
   116  		fmt.Fprintf(os.Stderr, "unable to find %q: %v\n", *bin, err)
   117  	} else {
   118  		*bin = new
   119  	}
   120  }
   121  
   122  func (conv *Converter) Run(out io.Writer) error {
   123  	shaders, err := filepath.Glob(filepath.Join(conv.shadersDir, "*"))
   124  	if len(shaders) == 0 || err != nil {
   125  		return fmt.Errorf("failed to list shaders in %q: %w", conv.shadersDir, err)
   126  	}
   127  
   128  	sort.Strings(shaders)
   129  
   130  	var workers Workers
   131  
   132  	type ShaderResult struct {
   133  		Path    string
   134  		Shaders []driver.ShaderSources
   135  		Error   error
   136  	}
   137  	shaderResults := make([]ShaderResult, len(shaders))
   138  
   139  	for i, shaderPath := range shaders {
   140  		i, shaderPath := i, shaderPath
   141  
   142  		switch filepath.Ext(shaderPath) {
   143  		case ".vert", ".frag":
   144  			workers.Go(func() {
   145  				shaders, err := conv.Shader(shaderPath)
   146  				shaderResults[i] = ShaderResult{
   147  					Path:    shaderPath,
   148  					Shaders: shaders,
   149  					Error:   err,
   150  				}
   151  			})
   152  		case ".comp":
   153  			workers.Go(func() {
   154  				shaders, err := conv.ComputeShader(shaderPath)
   155  				shaderResults[i] = ShaderResult{
   156  					Path:    shaderPath,
   157  					Shaders: shaders,
   158  					Error:   err,
   159  				}
   160  			})
   161  		default:
   162  			continue
   163  		}
   164  	}
   165  
   166  	workers.Wait()
   167  
   168  	var allErrors string
   169  	for _, r := range shaderResults {
   170  		if r.Error != nil {
   171  			if len(allErrors) > 0 {
   172  				allErrors += "\n\n"
   173  			}
   174  			allErrors += "--- " + r.Path + " --- \n\n" + r.Error.Error() + "\n"
   175  		}
   176  	}
   177  	if len(allErrors) > 0 {
   178  		return errors.New(allErrors)
   179  	}
   180  
   181  	fmt.Fprintf(out, "// Code generated by build.go. DO NOT EDIT.\n\n")
   182  	fmt.Fprintf(out, "package %s\n\n", conv.packageName)
   183  	fmt.Fprintf(out, "import %q\n\n", "github.com/cybriq/giocore/gpu/internal/driver")
   184  
   185  	fmt.Fprintf(out, "var (\n")
   186  
   187  	for _, r := range shaderResults {
   188  		if len(r.Shaders) == 0 {
   189  			continue
   190  		}
   191  
   192  		name := filepath.Base(r.Path)
   193  		name = strings.ReplaceAll(name, ".", "_")
   194  		fmt.Fprintf(out, "\tshader_%s = ", name)
   195  
   196  		multiVariant := len(r.Shaders) > 1
   197  		if multiVariant {
   198  			fmt.Fprintf(out, "[...]driver.ShaderSources{\n")
   199  		}
   200  
   201  		for _, src := range r.Shaders {
   202  			fmt.Fprintf(out, "driver.ShaderSources{\n")
   203  			fmt.Fprintf(out, "Name: %#v,\n", src.Name)
   204  			if len(src.Inputs) > 0 {
   205  				fmt.Fprintf(out, "Inputs: %#v,\n", src.Inputs)
   206  			}
   207  			if u := src.Uniforms; len(u.Blocks) > 0 {
   208  				fmt.Fprintf(out, "Uniforms: driver.UniformsReflection{\n")
   209  				fmt.Fprintf(out, "Blocks: %#v,\n", u.Blocks)
   210  				fmt.Fprintf(out, "Locations: %#v,\n", u.Locations)
   211  				fmt.Fprintf(out, "Size: %d,\n", u.Size)
   212  				fmt.Fprintf(out, "},\n")
   213  			}
   214  			if len(src.Textures) > 0 {
   215  				fmt.Fprintf(out, "Textures: %#v,\n", src.Textures)
   216  			}
   217  			if len(src.GLSL100ES) > 0 {
   218  				fmt.Fprintf(out, "GLSL100ES: `%s`,\n", src.GLSL100ES)
   219  			}
   220  			if len(src.GLSL300ES) > 0 {
   221  				fmt.Fprintf(out, "GLSL300ES: `%s`,\n", src.GLSL300ES)
   222  			}
   223  			if len(src.GLSL310ES) > 0 {
   224  				fmt.Fprintf(out, "GLSL310ES: `%s`,\n", src.GLSL310ES)
   225  			}
   226  			if len(src.GLSL130) > 0 {
   227  				fmt.Fprintf(out, "GLSL130: `%s`,\n", src.GLSL130)
   228  			}
   229  			if len(src.GLSL150) > 0 {
   230  				fmt.Fprintf(out, "GLSL150: `%s`,\n", src.GLSL150)
   231  			}
   232  			if len(src.HLSL) > 0 {
   233  				fmt.Fprintf(out, "HLSL: %q,\n", src.HLSL)
   234  			}
   235  			fmt.Fprintf(out, "}")
   236  			if multiVariant {
   237  				fmt.Fprintf(out, ",")
   238  			}
   239  			fmt.Fprintf(out, "\n")
   240  		}
   241  		if multiVariant {
   242  			fmt.Fprintf(out, "}\n")
   243  		}
   244  	}
   245  	fmt.Fprintf(out, ")\n")
   246  
   247  	return nil
   248  }
   249  
   250  func (conv *Converter) Shader(shaderPath string) ([]driver.ShaderSources, error) {
   251  	type Variant struct {
   252  		FetchColorExpr string
   253  		Header         string
   254  	}
   255  	variantArgs := [...]Variant{
   256  		{
   257  			FetchColorExpr: `_color.color`,
   258  			Header:         `layout(binding=0) uniform Color { vec4 color; } _color;`,
   259  		},
   260  		{
   261  			FetchColorExpr: `mix(_gradient.color1, _gradient.color2, clamp(vUV.x, 0.0, 1.0))`,
   262  			Header:         `layout(binding=0) uniform Gradient { vec4 color1; vec4 color2; } _gradient;`,
   263  		},
   264  		{
   265  			FetchColorExpr: `texture(tex, vUV)`,
   266  			Header:         `layout(binding=0) uniform sampler2D tex;`,
   267  		},
   268  	}
   269  
   270  	shaderTemplate, err := template.ParseFiles(shaderPath)
   271  	if err != nil {
   272  		return nil, fmt.Errorf("failed to parse template %q: %w", shaderPath, err)
   273  	}
   274  
   275  	var variants []driver.ShaderSources
   276  	for i, variantArg := range variantArgs {
   277  		variantName := strconv.Itoa(i)
   278  		var buf bytes.Buffer
   279  		err := shaderTemplate.Execute(&buf, variantArg)
   280  		if err != nil {
   281  			return nil, fmt.Errorf("failed to execute template %q with %#v: %w", shaderPath, variantArg, err)
   282  		}
   283  
   284  		var sources driver.ShaderSources
   285  		sources.Name = filepath.Base(shaderPath)
   286  
   287  		// Ignore error; some shaders are not meant to run in GLSL 1.00.
   288  		sources.GLSL100ES, _, _ = conv.ShaderVariant(shaderPath, variantName, buf.Bytes(), "es", "100")
   289  
   290  		var metadata Metadata
   291  		sources.GLSL300ES, metadata, err = conv.ShaderVariant(shaderPath, variantName, buf.Bytes(), "es", "300")
   292  		if err != nil {
   293  			return nil, fmt.Errorf("failed to convert GLSL300ES:\n%w", err)
   294  		}
   295  
   296  		sources.GLSL130, _, err = conv.ShaderVariant(shaderPath, variantName, buf.Bytes(), "glsl", "130")
   297  		if err != nil {
   298  			return nil, fmt.Errorf("failed to convert GLSL130:\n%w", err)
   299  		}
   300  
   301  		hlsl, _, err := conv.ShaderVariant(shaderPath, variantName, buf.Bytes(), "hlsl", "40")
   302  		if err != nil {
   303  			return nil, fmt.Errorf("failed to convert HLSL:\n%w", err)
   304  		}
   305  		sources.HLSL, err = conv.fxc.Compile(shaderPath, variantName, []byte(hlsl), "main", "4_0_level_9_1")
   306  		if err != nil {
   307  			// Attempt shader model 4.0. Only the gpu/headless
   308  			// test shaders use features not supported by level
   309  			// 9.1.
   310  			sources.HLSL, err = conv.fxc.Compile(shaderPath, variantName, []byte(hlsl), "main", "4_0")
   311  			if err != nil {
   312  				return nil, fmt.Errorf("failed to compile HLSL: %w", err)
   313  			}
   314  		}
   315  
   316  		sources.GLSL150, _, err = conv.ShaderVariant(shaderPath, variantName, buf.Bytes(), "glsl", "150")
   317  		if err != nil {
   318  			return nil, fmt.Errorf("failed to convert GLSL150:\n%w", err)
   319  		}
   320  
   321  		sources.Uniforms = metadata.Uniforms
   322  		sources.Inputs = metadata.Inputs
   323  		sources.Textures = metadata.Textures
   324  
   325  		variants = append(variants, sources)
   326  	}
   327  
   328  	// If the shader don't use the variant arguments, output only a single version.
   329  	if variants[0].GLSL100ES == variants[1].GLSL100ES {
   330  		variants = variants[:1]
   331  	}
   332  
   333  	return variants, nil
   334  }
   335  
   336  func (conv *Converter) ShaderVariant(shaderPath, variant string, src []byte, lang, profile string) (string, Metadata, error) {
   337  	spirv, err := conv.glslvalidator.Convert(shaderPath, variant, lang == "hlsl", src)
   338  	if err != nil {
   339  		return "", Metadata{}, fmt.Errorf("failed to generate SPIR-V for %q: %w", shaderPath, err)
   340  	}
   341  
   342  	dst, err := conv.spirv.Convert(shaderPath, variant, spirv, lang, profile)
   343  	if err != nil {
   344  		return "", Metadata{}, fmt.Errorf("failed to convert shader %q: %w", shaderPath, err)
   345  	}
   346  
   347  	meta, err := conv.spirv.Metadata(shaderPath, variant, spirv)
   348  	if err != nil {
   349  		return "", Metadata{}, fmt.Errorf("failed to extract metadata for shader %q: %w", shaderPath, err)
   350  	}
   351  
   352  	return dst, meta, nil
   353  }
   354  
   355  func (conv *Converter) ComputeShader(shaderPath string) ([]driver.ShaderSources, error) {
   356  	shader, err := ioutil.ReadFile(shaderPath)
   357  	if err != nil {
   358  		return nil, fmt.Errorf("failed to load shader %q: %w", shaderPath, err)
   359  	}
   360  
   361  	spirv, err := conv.glslvalidator.Convert(shaderPath, "", false, shader)
   362  	if err != nil {
   363  		return nil, fmt.Errorf("failed to convert compute shader %q: %w", shaderPath, err)
   364  	}
   365  
   366  	var sources driver.ShaderSources
   367  	sources.Name = filepath.Base(shaderPath)
   368  
   369  	sources.GLSL310ES, err = conv.spirv.Convert(shaderPath, "", spirv, "es", "310")
   370  	if err != nil {
   371  		return nil, fmt.Errorf("failed to convert es compute shader %q: %w", shaderPath, err)
   372  	}
   373  	sources.GLSL310ES = unixLineEnding(sources.GLSL310ES)
   374  
   375  	hlslSource, err := conv.spirv.Convert(shaderPath, "", spirv, "hlsl", "50")
   376  	if err != nil {
   377  		return nil, fmt.Errorf("failed to convert hlsl compute shader %q: %w", shaderPath, err)
   378  	}
   379  
   380  	dxil, err := conv.fxc.Compile(shaderPath, "0", []byte(hlslSource), "main", "5_0")
   381  	if err != nil {
   382  		return nil, fmt.Errorf("failed to compile hlsl compute shader %q: %w", shaderPath, err)
   383  	}
   384  	if conv.directCompute {
   385  		sources.HLSL = dxil
   386  	}
   387  
   388  	return []driver.ShaderSources{sources}, nil
   389  }
   390  
   391  // Workers implements wait group with synchronous logging.
   392  type Workers struct {
   393  	running sync.WaitGroup
   394  }
   395  
   396  func (lg *Workers) Go(fn func()) {
   397  	lg.running.Add(1)
   398  	go func() {
   399  		defer lg.running.Done()
   400  		fn()
   401  	}()
   402  }
   403  
   404  func (lg *Workers) Wait() {
   405  	lg.running.Wait()
   406  }
   407  
   408  func unixLineEnding(s string) string {
   409  	return strings.ReplaceAll(s, "\r\n", "\n")
   410  }