github.com/cybriq/giocore@v0.0.7-0.20210703034601-cfb9cb5f3900/gpu/internal/convertshaders/spirvcross.go (about)

     1  // SPDX-License-Identifier: Unlicense OR MIT
     2  
     3  package main
     4  
     5  import (
     6  	"encoding/json"
     7  	"fmt"
     8  	"os/exec"
     9  	"path/filepath"
    10  	"sort"
    11  	"strings"
    12  
    13  	"github.com/cybriq/giocore/gpu/internal/driver"
    14  )
    15  
    16  // Metadata contains reflection data about a shader.
    17  type Metadata struct {
    18  	Uniforms driver.UniformsReflection
    19  	Inputs   []driver.InputLocation
    20  	Textures []driver.TextureBinding
    21  }
    22  
    23  // SPIRVCross cross-compiles spirv shaders to es, hlsl and others.
    24  type SPIRVCross struct {
    25  	Bin     string
    26  	WorkDir WorkDir
    27  }
    28  
    29  func NewSPIRVCross() *SPIRVCross { return &SPIRVCross{Bin: "spirv-cross"} }
    30  
    31  // Convert converts compute shader from spirv format to a target format.
    32  func (spirv *SPIRVCross) Convert(path, variant string, shader []byte, target, version string) (string, error) {
    33  	base := spirv.WorkDir.Path(filepath.Base(path), variant)
    34  
    35  	if err := spirv.WorkDir.WriteFile(base, shader); err != nil {
    36  		return "", fmt.Errorf("unable to write shader to disk: %w", err)
    37  	}
    38  
    39  	var cmd *exec.Cmd
    40  	switch target {
    41  	case "glsl":
    42  		cmd = exec.Command(spirv.Bin,
    43  			"--no-es",
    44  			"--version", version,
    45  		)
    46  	case "es":
    47  		cmd = exec.Command(spirv.Bin,
    48  			"--es",
    49  			"--version", version,
    50  		)
    51  	case "hlsl":
    52  		cmd = exec.Command(spirv.Bin,
    53  			"--hlsl",
    54  			"--shader-model", version,
    55  		)
    56  	default:
    57  		return "", fmt.Errorf("unknown target %q", target)
    58  	}
    59  	cmd.Args = append(cmd.Args, "--no-420pack-extension", base)
    60  
    61  	out, err := cmd.CombinedOutput()
    62  	if err != nil {
    63  		return "", fmt.Errorf("%s\nfailed to run %v: %w", out, cmd.Args, err)
    64  	}
    65  	s := string(out)
    66  	if target != "hlsl" {
    67  		// Strip Windows \r in line endings.
    68  		s = unixLineEnding(s)
    69  	}
    70  
    71  	return s, nil
    72  }
    73  
    74  // Metadata extracts metadata for a SPIR-V shader.
    75  func (spirv *SPIRVCross) Metadata(path, variant string, shader []byte) (Metadata, error) {
    76  	base := spirv.WorkDir.Path(filepath.Base(path), variant)
    77  
    78  	if err := spirv.WorkDir.WriteFile(base, shader); err != nil {
    79  		return Metadata{}, fmt.Errorf("unable to write shader to disk: %w", err)
    80  	}
    81  
    82  	cmd := exec.Command(spirv.Bin,
    83  		base,
    84  		"--reflect",
    85  	)
    86  
    87  	out, err := cmd.Output()
    88  	if err != nil {
    89  		return Metadata{}, fmt.Errorf("failed to run %v: %w", cmd.Args, err)
    90  	}
    91  
    92  	meta, err := parseMetadata(out)
    93  	if err != nil {
    94  		return Metadata{}, fmt.Errorf("%s\nfailed to parse metadata: %w", out, err)
    95  	}
    96  
    97  	return meta, nil
    98  }
    99  
   100  func parseMetadata(data []byte) (Metadata, error) {
   101  	var reflect struct {
   102  		Types map[string]struct {
   103  			Name    string `json:"name"`
   104  			Members []struct {
   105  				Name   string `json:"name"`
   106  				Type   string `json:"type"`
   107  				Offset int    `json:"offset"`
   108  			} `json:"members"`
   109  		} `json:"types"`
   110  		Inputs []struct {
   111  			Name     string `json:"name"`
   112  			Type     string `json:"type"`
   113  			Location int    `json:"location"`
   114  		} `json:"inputs"`
   115  		Textures []struct {
   116  			Name    string `json:"name"`
   117  			Type    string `json:"type"`
   118  			Set     int    `json:"set"`
   119  			Binding int    `json:"binding"`
   120  		} `json:"textures"`
   121  		UBOs []struct {
   122  			Name      string `json:"name"`
   123  			Type      string `json:"type"`
   124  			BlockSize int    `json:"block_size"`
   125  			Set       int    `json:"set"`
   126  			Binding   int    `json:"binding"`
   127  		} `json:"ubos"`
   128  	}
   129  	if err := json.Unmarshal(data, &reflect); err != nil {
   130  		return Metadata{}, fmt.Errorf("failed to parse reflection data: %w", err)
   131  	}
   132  
   133  	var m Metadata
   134  
   135  	for _, input := range reflect.Inputs {
   136  		dataType, dataSize, err := parseDataType(input.Type)
   137  		if err != nil {
   138  			return Metadata{}, fmt.Errorf("parseReflection: %v", err)
   139  		}
   140  		m.Inputs = append(m.Inputs, driver.InputLocation{
   141  			Name:          input.Name,
   142  			Location:      input.Location,
   143  			Semantic:      "TEXCOORD",
   144  			SemanticIndex: input.Location,
   145  			Type:          dataType,
   146  			Size:          dataSize,
   147  		})
   148  	}
   149  
   150  	sort.Slice(m.Inputs, func(i, j int) bool {
   151  		return m.Inputs[i].Location < m.Inputs[j].Location
   152  	})
   153  
   154  	blockOffset := 0
   155  	for _, block := range reflect.UBOs {
   156  		m.Uniforms.Blocks = append(m.Uniforms.Blocks, driver.UniformBlock{
   157  			Name:    block.Name,
   158  			Binding: block.Binding,
   159  		})
   160  		t := reflect.Types[block.Type]
   161  		// By convention uniform block variables are named by prepending an underscore
   162  		// and converting to lowercase.
   163  		blockVar := "_" + strings.ToLower(block.Name)
   164  		for _, member := range t.Members {
   165  			dataType, size, err := parseDataType(member.Type)
   166  			if err != nil {
   167  				return Metadata{}, fmt.Errorf("failed to parse reflection data: %v", err)
   168  			}
   169  			m.Uniforms.Locations = append(m.Uniforms.Locations, driver.UniformLocation{
   170  				Name:   fmt.Sprintf("%s.%s", blockVar, member.Name),
   171  				Type:   dataType,
   172  				Size:   size,
   173  				Offset: blockOffset + member.Offset,
   174  			})
   175  		}
   176  		blockOffset += block.BlockSize
   177  	}
   178  	m.Uniforms.Size = blockOffset
   179  
   180  	for _, texture := range reflect.Textures {
   181  		m.Textures = append(m.Textures, driver.TextureBinding{
   182  			Name:    texture.Name,
   183  			Binding: texture.Binding,
   184  		})
   185  	}
   186  
   187  	//return m, fmt.Errorf("not yet!: %+v", reflect)
   188  	return m, nil
   189  }
   190  
   191  func parseDataType(t string) (driver.DataType, int, error) {
   192  	switch t {
   193  	case "float":
   194  		return driver.DataTypeFloat, 1, nil
   195  	case "vec2":
   196  		return driver.DataTypeFloat, 2, nil
   197  	case "vec3":
   198  		return driver.DataTypeFloat, 3, nil
   199  	case "vec4":
   200  		return driver.DataTypeFloat, 4, nil
   201  	case "int":
   202  		return driver.DataTypeInt, 1, nil
   203  	case "int2":
   204  		return driver.DataTypeInt, 2, nil
   205  	case "int3":
   206  		return driver.DataTypeInt, 3, nil
   207  	case "int4":
   208  		return driver.DataTypeInt, 4, nil
   209  	default:
   210  		return 0, 0, fmt.Errorf("unsupported input data type: %s", t)
   211  	}
   212  }