github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/inet.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"errors"
    19  	"net"
    20  
    21  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    22  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    23  )
    24  
    25  // Inet is a codec for the CQL inet type. Its preferred Go type is net.IP but it can encode from and decode to
    26  // []byte as well.
    27  var Inet Codec = &inetCodec{}
    28  
    29  type inetCodec struct{}
    30  
    31  func (c *inetCodec) DataType() datatype.DataType {
    32  	return datatype.Inet
    33  }
    34  
    35  func (c *inetCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    36  	var val net.IP
    37  	if val, err = convertToIP(source); err == nil && val != nil {
    38  		dest, err = writeInet(val)
    39  	}
    40  	if err != nil {
    41  		err = errCannotEncode(source, c.DataType(), version, err)
    42  	}
    43  	return
    44  }
    45  
    46  func (c *inetCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    47  	var val net.IP
    48  	if val, wasNull, err = readInet(source); err == nil {
    49  		err = convertFromIP(val, wasNull, dest)
    50  	}
    51  	if err != nil {
    52  		err = errCannotDecode(dest, c.DataType(), version, err)
    53  	}
    54  	return
    55  }
    56  
    57  func convertToIP(source interface{}) (val net.IP, err error) {
    58  	switch s := source.(type) {
    59  	case net.IP:
    60  		val = s
    61  		val4 := val.To4()
    62  		if val4 != nil {
    63  			val = val4
    64  		}
    65  	case *net.IP:
    66  		if s != nil {
    67  			val = *s
    68  			val4 := val.To4()
    69  			if val4 != nil {
    70  				val = val4
    71  			}
    72  		}
    73  	case []byte:
    74  		val = s
    75  		val4 := val.To4()
    76  		if val4 != nil {
    77  			val = val4
    78  		}
    79  	case *[]byte:
    80  		if s != nil {
    81  			val = *s
    82  			val4 := val.To4()
    83  			if val4 != nil {
    84  				val = val4
    85  			}
    86  		}
    87  	case string:
    88  		val = compactV4(net.ParseIP(s))
    89  		if val == nil {
    90  			err = errCannotParseString(s, errors.New("net.ParseIP(text) failed"))
    91  		}
    92  	case *string:
    93  		if s != nil {
    94  			val = compactV4(net.ParseIP(*s))
    95  			if val == nil {
    96  				err = errCannotParseString(*s, errors.New("net.ParseIP(text) failed"))
    97  			}
    98  		}
    99  	case nil:
   100  	default:
   101  		err = ErrConversionNotSupported
   102  	}
   103  	if err != nil {
   104  		err = errSourceConversionFailed(source, val, err)
   105  	}
   106  	return
   107  }
   108  
   109  func convertFromIP(val net.IP, wasNull bool, dest interface{}) (err error) {
   110  	switch d := dest.(type) {
   111  	case *interface{}:
   112  		if d == nil {
   113  			err = ErrNilDestination
   114  		} else if wasNull {
   115  			*d = nil
   116  		} else {
   117  			*d = val
   118  		}
   119  	case *net.IP:
   120  		if d == nil {
   121  			err = ErrNilDestination
   122  		} else if wasNull {
   123  			*d = nil
   124  		} else {
   125  			*d = compactV4(val)
   126  		}
   127  	case *[]byte:
   128  		if d == nil {
   129  			err = ErrNilDestination
   130  		} else if wasNull {
   131  			*d = nil
   132  		} else {
   133  			*d = compactV4(val)
   134  		}
   135  	case *string:
   136  		if d == nil {
   137  			err = ErrNilDestination
   138  		} else if wasNull {
   139  			*d = ""
   140  		} else {
   141  			*d = val.String()
   142  		}
   143  	default:
   144  		err = errDestinationInvalid(dest)
   145  	}
   146  	if err != nil {
   147  		err = errDestinationConversionFailed(val, dest, err)
   148  	}
   149  	return
   150  }
   151  
   152  func writeInet(val net.IP) (dest []byte, err error) {
   153  	length := len(val)
   154  	if length == 0 {
   155  		dest = nil
   156  	} else if length == net.IPv4len || length == net.IPv6len {
   157  		dest = compactV4(val)
   158  	} else {
   159  		err = errWrongFixedLengths(net.IPv4len, net.IPv6len, length)
   160  	}
   161  	if err != nil {
   162  		err = errCannotWrite(val, err)
   163  	}
   164  	return
   165  }
   166  
   // readInet and writeInet are roughly equivalent to primitive.ReadInetAddr and primitive.WriteInetAddr.
// They favor the compact form (4-byte slice) for IPv4 addresses; compactV4 converts an address to that form.
   169  
   170  func readInet(source []byte) (val net.IP, wasNull bool, err error) {
   171  	length := len(source)
   172  	if length == 0 {
   173  		wasNull = true
   174  	} else if length == net.IPv4len {
   175  		val = net.IPv4(source[0], source[1], source[2], source[3]).To4()
   176  	} else if length == net.IPv6len {
   177  		val = source
   178  	} else {
   179  		err = errWrongFixedLengths(net.IPv4len, net.IPv6len, length)
   180  	}
   181  	if err != nil {
   182  		err = errCannotRead(val, err)
   183  	}
   184  	return
   185  }
   186  
   // compactV4 returns the 4-byte form of val when val holds an IPv4 address (including the IPv4-mapped IPv6
// form). Otherwise, including for nil, empty and IPv6 inputs, val is returned unchanged.
func compactV4(val net.IP) net.IP {
	// To4 reports nil for nil/empty input and for anything that is not an IPv4 address.
	if v4 := val.To4(); v4 != nil {
		return v4
	}
	return val
}