
     1  // Copyright 2018 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package main
    17  import (
    18  	"bytes"
    19  	"flag"
    20  	"fmt"
    21  	"io"
    22  	"math"
    23  	"os"
    24  )
// Command-line flags. -i is always required; -o, -s and -v are required
// unless -dump is given (validated in main).
var (
	input  = flag.String("i", "", "input file")
	output = flag.String("o", "", "output file")
	symbol = flag.String("s", "", "symbol to inject into")
	from   = flag.String("from", "", "optional existing value of the symbol for verification")
	value  = flag.String("v", "", "value to inject into symbol")

	dump = flag.Bool("dump", false, "dump the symbol table for copying into a test")
)
// maxUint64 is returned as both the offset and the size by findSymbol when
// the lookup fails.
var maxUint64 uint64 = math.MaxUint64
// cantParseError wraps an error from a format-specific parser or dumper that
// could not handle the input. openFile and dumpSymbols treat it as "try the
// next object file format" rather than as a fatal error.
type cantParseError struct {
	error
}
    42  func main() {
    43  	flag.Parse()
    45  	usageError := func(s string) {
    46  		fmt.Fprintln(os.Stderr, s)
    47  		flag.Usage()
    48  		os.Exit(1)
    49  	}
    51  	if *input == "" {
    52  		usageError("-i is required")
    53  	}
    55  	if !*dump {
    56  		if *output == "" {
    57  			usageError("-o is required")
    58  		}
    60  		if *symbol == "" {
    61  			usageError("-s is required")
    62  		}
    64  		if *value == "" {
    65  			usageError("-v is required")
    66  		}
    67  	}
    69  	r, err := os.Open(*input)
    70  	if err != nil {
    71  		fmt.Fprintln(os.Stderr, err.Error())
    72  		os.Exit(2)
    73  	}
    74  	defer r.Close()
    76  	if *dump {
    77  		err := dumpSymbols(r)
    78  		if err != nil {
    79  			fmt.Fprintln(os.Stderr, err.Error())
    80  			os.Exit(6)
    81  		}
    82  		return
    83  	}
    85  	w, err := os.OpenFile(*output, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
    86  	if err != nil {
    87  		fmt.Fprintln(os.Stderr, err.Error())
    88  		os.Exit(3)
    89  	}
    90  	defer w.Close()
    92  	file, err := openFile(r)
    93  	if err != nil {
    94  		fmt.Fprintln(os.Stderr, err.Error())
    95  		os.Exit(4)
    96  	}
    98  	err = injectSymbol(file, w, *symbol, *value, *from)
    99  	if err != nil {
   100  		fmt.Fprintln(os.Stderr, err.Error())
   101  		os.Remove(*output)
   102  		os.Exit(5)
   103  	}
   104  }
   106  func openFile(r io.ReaderAt) (*File, error) {
   107  	file, err := elfSymbolsFromFile(r)
   108  	if elfError, ok := err.(cantParseError); ok {
   109  		// Try as a mach-o file
   110  		file, err = machoSymbolsFromFile(r)
   111  		if _, ok := err.(cantParseError); ok {
   112  			// Try as a windows PE file
   113  			file, err = peSymbolsFromFile(r)
   114  			if _, ok := err.(cantParseError); ok {
   115  				// Can't parse as elf, macho, or PE, return the elf error
   116  				return nil, elfError
   117  			}
   118  		}
   119  	}
   120  	if err != nil {
   121  		return nil, err
   122  	}
   124  	file.r = r
   126  	return file, err
   127  }
   129  func injectSymbol(file *File, w io.Writer, symbol, value, from string) error {
   130  	offset, size, err := findSymbol(file, symbol)
   131  	if err != nil {
   132  		return err
   133  	}
   135  	if uint64(len(value))+1 > size {
   136  		return fmt.Errorf("value length %d overflows symbol size %d", len(value), size)
   137  	}
   139  	if from != "" {
   140  		// Read the exsting symbol contents and verify they match the expected value
   141  		expected := make([]byte, size)
   142  		existing := make([]byte, size)
   143  		copy(expected, from)
   144  		_, err := file.r.ReadAt(existing, int64(offset))
   145  		if err != nil {
   146  			return err
   147  		}
   148  		if bytes.Compare(existing, expected) != 0 {
   149  			return fmt.Errorf("existing symbol contents %q did not match expected value %q",
   150  				string(existing), string(expected))
   151  		}
   152  	}
   154  	return copyAndInject(file.r, w, offset, size, value)
   155  }
   157  func copyAndInject(r io.ReaderAt, w io.Writer, offset, size uint64, value string) (err error) {
   158  	buf := make([]byte, size)
   159  	copy(buf, value)
   161  	// Copy the first bytes up to the symbol offset
   162  	_, err = io.Copy(w, io.NewSectionReader(r, 0, int64(offset)))
   164  	// Write the injected value in the output file
   165  	if err == nil {
   166  		_, err = w.Write(buf)
   167  	}
   169  	// Write the remainder of the file
   170  	pos := int64(offset + size)
   171  	if err == nil {
   172  		_, err = io.Copy(w, io.NewSectionReader(r, pos, 1<<63-1-pos))
   173  	}
   175  	if err == io.EOF {
   176  		err = io.ErrUnexpectedEOF
   177  	}
   179  	return err
   180  }
   182  func findSymbol(file *File, symbolName string) (uint64, uint64, error) {
   183  	for i, symbol := range file.Symbols {
   184  		if symbol.Name == symbolName {
   185  			// Find the next symbol (n the same section with a higher address
   186  			var n int
   187  			for n = i; n < len(file.Symbols); n++ {
   188  				if file.Symbols[n].Section != symbol.Section {
   189  					n = len(file.Symbols)
   190  					break
   191  				}
   192  				if file.Symbols[n].Addr > symbol.Addr {
   193  					break
   194  				}
   195  			}
   197  			size := symbol.Size
   198  			if size == 0 {
   199  				var end uint64
   200  				if n < len(file.Symbols) {
   201  					end = file.Symbols[n].Addr
   202  				} else {
   203  					end = symbol.Section.Size
   204  				}
   206  				if end <= symbol.Addr || end > symbol.Addr+4096 {
   207  					return maxUint64, maxUint64, fmt.Errorf("symbol end address does not seem valid, %x:%x", symbol.Addr, end)
   208  				}
   210  				size = end - symbol.Addr
   211  			}
   213  			offset := symbol.Section.Offset + symbol.Addr
   215  			return uint64(offset), uint64(size), nil
   216  		}
   217  	}
   219  	return maxUint64, maxUint64, fmt.Errorf("symbol not found")
   220  }
// File is the format-independent view of a parsed object file that the
// injector operates on.
type File struct {
	r io.ReaderAt // Contents of the file; attached by openFile and read by injectSymbol.
	// Symbols is scanned in order by findSymbol; presumably grouped by section
	// and sorted by address (TODO confirm in the format-specific parsers).
	Symbols  []*Symbol
	Sections []*Section
}
// Symbol is a named location inside one of a File's sections.
type Symbol struct {
	Name    string
	Addr    uint64 // Address of the symbol inside the section.
	Size    uint64 // Size of the symbol, if known.
	Section *Section
}
// Section describes a section of an object file and where its contents are
// stored in the file.
type Section struct {
	Name   string
	Addr   uint64 // Virtual address of the start of the section.
	Offset uint64 // Offset into the file of the start of the section.
	Size   uint64 // findSymbol uses this as the end bound for the section's last symbol.
}
   242  func dumpSymbols(r io.ReaderAt) error {
   243  	err := dumpElfSymbols(r)
   244  	if elfError, ok := err.(cantParseError); ok {
   245  		// Try as a mach-o file
   246  		err = dumpMachoSymbols(r)
   247  		if _, ok := err.(cantParseError); ok {
   248  			// Try as a windows PE file
   249  			err = dumpPESymbols(r)
   250  			if _, ok := err.(cantParseError); ok {
   251  				// Can't parse as elf, macho, or PE, return the elf error
   252  				return elfError
   253  			}
   254  		}
   255  	}
   256  	return err
   257  }