github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/checksum.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"fmt"
     8  )
     9  
    10  type CsumChunkKind int
    11  
    12  const (
    13  	CsumChunkArg CsumChunkKind = iota
    14  	CsumChunkConst
    15  )
    16  
    17  type CsumInfo struct {
    18  	Kind   CsumKind
    19  	Chunks []CsumChunk
    20  }
    21  
    22  type CsumChunk struct {
    23  	Kind  CsumChunkKind
    24  	Arg   Arg    // for CsumChunkArg
    25  	Value uint64 // for CsumChunkConst
    26  	Size  uint64 // for CsumChunkConst
    27  }
    28  
    29  func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) {
    30  	var inetCsumFields, pseudoCsumFields []Arg
    31  
    32  	// Find all csum fields.
    33  	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
    34  		if typ, ok := arg.Type().(*CsumType); ok {
    35  			switch typ.Kind {
    36  			case CsumInet:
    37  				inetCsumFields = append(inetCsumFields, arg)
    38  			case CsumPseudo:
    39  				pseudoCsumFields = append(pseudoCsumFields, arg)
    40  			default:
    41  				panic(fmt.Sprintf("unknown csum kind %v", typ.Kind))
    42  			}
    43  		}
    44  	})
    45  
    46  	if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 {
    47  		return nil, nil
    48  	}
    49  
    50  	// Build map of each field to its parent struct.
    51  	parentsMap := make(map[Arg]Arg)
    52  	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
    53  		if _, ok := arg.Type().(*StructType); ok {
    54  			for _, field := range arg.(*GroupArg).Inner {
    55  				parentsMap[InnerArg(field)] = arg
    56  			}
    57  		}
    58  	})
    59  
    60  	csumMap := make(map[Arg]CsumInfo)
    61  	csumUses := make(map[Arg]struct{})
    62  
    63  	// Calculate generic inet checksums.
    64  	for _, arg := range inetCsumFields {
    65  		typ, _ := arg.Type().(*CsumType)
    66  		csummedArg := findCsummedArg(arg, typ, parentsMap)
    67  		csumUses[csummedArg] = struct{}{}
    68  		chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0}
    69  		csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}}
    70  	}
    71  
    72  	// No need to continue if there are no pseudo csum fields.
    73  	if len(pseudoCsumFields) == 0 {
    74  		return csumMap, csumUses
    75  	}
    76  
    77  	// Extract ipv4 or ipv6 source and destination addresses.
    78  	var ipSrcAddr, ipDstAddr Arg
    79  	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
    80  		groupArg, ok := arg.(*GroupArg)
    81  		if !ok {
    82  			return
    83  		}
    84  		// syz_csum_* structs are used in tests
    85  		switch groupArg.Type().TemplateName() {
    86  		case "ipv4_header", "syz_csum_ipv4_header":
    87  			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4)
    88  		case "ipv6_packet_t", "syz_csum_ipv6_header":
    89  			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16)
    90  		}
    91  	})
    92  	if ipSrcAddr == nil || ipDstAddr == nil {
    93  		panic("no ipv4 nor ipv6 header found")
    94  	}
    95  
    96  	// Calculate pseudo checksums.
    97  	for _, arg := range pseudoCsumFields {
    98  		typ, _ := arg.Type().(*CsumType)
    99  		csummedArg := findCsummedArg(arg, typ, parentsMap)
   100  		protocol := uint8(typ.Protocol)
   101  		var info CsumInfo
   102  		if ipSrcAddr.Size() == 4 {
   103  			info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol)
   104  		} else {
   105  			info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol)
   106  		}
   107  		csumMap[arg] = info
   108  		csumUses[csummedArg] = struct{}{}
   109  		csumUses[ipSrcAddr] = struct{}{}
   110  		csumUses[ipDstAddr] = struct{}{}
   111  	}
   112  
   113  	return csumMap, csumUses
   114  }
   115  
   116  func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg {
   117  	if typ.Buf == ParentRef {
   118  		csummedArg := parentsMap[arg]
   119  		if csummedArg == nil {
   120  			panic(fmt.Sprintf("%q for %q is not in parents map", ParentRef, typ.Name()))
   121  		}
   122  		return csummedArg
   123  	}
   124  	for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
   125  		// TODO(dvyukov): support template argument names as in size calculation.
   126  		if typ.Buf == parent.Type().Name() {
   127  			return parent
   128  		}
   129  	}
   130  	panic(fmt.Sprintf("csum field %q references non existent field %q", typ.Name(), typ.Buf))
   131  }
   132  
   133  func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
   134  	info := CsumInfo{Kind: CsumInet}
   135  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
   136  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
   137  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(protocol))), 2})
   138  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(tcpPacket.Size()))), 2})
   139  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0})
   140  	return info
   141  }
   142  
   143  func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
   144  	info := CsumInfo{Kind: CsumInet}
   145  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
   146  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
   147  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(tcpPacket.Size()))), 4})
   148  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(protocol))), 4})
   149  	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0})
   150  	return info
   151  }
   152  
   153  func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) {
   154  	srcAddr := getFieldByName(arg, "src_ip")
   155  	dstAddr := getFieldByName(arg, "dst_ip")
   156  	if srcAddr.Size() != size || dstAddr.Size() != size {
   157  		panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size))
   158  	}
   159  	return srcAddr, dstAddr
   160  }
   161  
   162  func getFieldByName(arg *GroupArg, name string) Arg {
   163  	typ := arg.Type().(*StructType)
   164  	for i, field := range arg.Inner {
   165  		if typ.Fields[i].Name == name {
   166  			return field
   167  		}
   168  	}
   169  	panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name()))
   170  }