github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/types/duration.go (about)

     1  /*
     2  Copyright 2020 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package types
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	"time"
    23  
    24  	"github.com/gravitational/trace"
    25  
    26  	"github.com/gravitational/teleport/api/constants"
    27  )
    28  
    29  // Duration is a wrapper around duration to set up custom marshal/unmarshal
    30  type Duration time.Duration
    31  
    32  // Duration returns time.Duration from Duration typex
    33  func (d Duration) Duration() time.Duration {
    34  	return time.Duration(d)
    35  }
    36  
    37  // Value returns time.Duration value of this wrapper
    38  func (d Duration) Value() time.Duration {
    39  	return time.Duration(d)
    40  }
    41  
    42  // MarshalJSON marshals Duration to string
    43  func (d Duration) MarshalJSON() ([]byte, error) {
    44  	return json.Marshal(d.Duration().String())
    45  }
    46  
    47  // UnmarshalJSON interprets the given bytes as a Duration value
    48  func (d *Duration) UnmarshalJSON(data []byte) error {
    49  	if len(data) == 0 {
    50  		return nil
    51  	}
    52  	var stringVar string
    53  	if err := json.Unmarshal(data, &stringVar); err != nil {
    54  		return trace.Wrap(err)
    55  	}
    56  	if stringVar == constants.DurationNever {
    57  		*d = Duration(0)
    58  		return nil
    59  	}
    60  	out, err := parseDuration(stringVar)
    61  	if err != nil {
    62  		return trace.BadParameter(err.Error())
    63  	}
    64  	*d = out
    65  	return nil
    66  }
    67  
    68  // MarshalYAML marshals duration into YAML value,
    69  // encodes it as a string in format "1m"
    70  func (d Duration) MarshalYAML() (interface{}, error) {
    71  	return fmt.Sprintf("%v", d.Duration()), nil
    72  }
    73  
    74  // UnmarshalYAML unmarshals duration from YAML value.
    75  func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
    76  	var stringVar string
    77  	if err := unmarshal(&stringVar); err != nil {
    78  		return trace.Wrap(err)
    79  	}
    80  	if stringVar == constants.DurationNever {
    81  		*d = Duration(0)
    82  		return nil
    83  	}
    84  	out, err := parseDuration(stringVar)
    85  	if err != nil {
    86  		return trace.BadParameter(err.Error())
    87  	}
    88  	*d = out
    89  	return nil
    90  }
    91  
    92  // MaxDuration returns the maximum duration value
    93  func MaxDuration() Duration {
    94  	return NewDuration(1<<63 - 1)
    95  }
    96  
    97  // NewDuration converts the given time.Duration value to a duration
    98  func NewDuration(d time.Duration) Duration {
    99  	return Duration(d)
   100  }
   101  
   102  // leadingInt consumes the leading [0-9]* from s.
   103  func leadingInt(s string) (x int64, rem string, err error) {
   104  	i := 0
   105  	for ; i < len(s); i++ {
   106  		c := s[i]
   107  		if c < '0' || c > '9' {
   108  			break
   109  		}
   110  		if x > (1<<63-1)/10 {
   111  			// overflow
   112  			return 0, "", trace.BadParameter("time: bad [0-9]*")
   113  		}
   114  		x = x*10 + int64(c) - '0'
   115  		if x < 0 {
   116  			// overflow
   117  			return 0, "", trace.BadParameter("time: bad [0-9]*")
   118  		}
   119  	}
   120  	return x, s[i:], nil
   121  }
   122  
   123  // leadingFraction consumes the leading [0-9]* from s.
   124  // It is used only for fractions, so does not return an error on overflow,
   125  // it just stops accumulating precision.
   126  func leadingFraction(s string) (x int64, scale float64, rem string) {
   127  	i := 0
   128  	scale = 1
   129  	overflow := false
   130  	for ; i < len(s); i++ {
   131  		c := s[i]
   132  		if c < '0' || c > '9' {
   133  			break
   134  		}
   135  		if overflow {
   136  			continue
   137  		}
   138  		if x > (1<<63-1)/10 {
   139  			// It's possible for overflow to give a positive number, so take care.
   140  			overflow = true
   141  			continue
   142  		}
   143  		y := x*10 + int64(c) - '0'
   144  		if y < 0 {
   145  			overflow = true
   146  			continue
   147  		}
   148  		x = y
   149  		scale *= 10
   150  	}
   151  	return x, scale, s[i:]
   152  }
   153  
   154  var unitMap = map[string]int64{
   155  	"ns": int64(time.Nanosecond),
   156  	"us": int64(time.Microsecond),
   157  	"µs": int64(time.Microsecond), // U+00B5 = micro symbol
   158  	"μs": int64(time.Microsecond), // U+03BC = Greek letter mu
   159  	"ms": int64(time.Millisecond),
   160  	"s":  int64(time.Second),
   161  	"m":  int64(time.Minute),
   162  	"h":  int64(time.Hour),
   163  	"d":  int64(time.Hour * 24),
   164  	"mo": int64(time.Hour * 24 * 30),
   165  	"y":  int64(time.Hour * 24 * 365),
   166  }
   167  
   168  // parseDuration parses a duration string.
   169  // A duration string is a possibly signed sequence of
   170  // decimal numbers, each with optional fraction and a unit suffix,
   171  // such as "300ms", "-1.5h" or "2h45m".
   172  // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
   173  func parseDuration(s string) (Duration, error) {
   174  	// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
   175  	orig := s
   176  	var d int64
   177  	neg := false
   178  
   179  	// Consume [-+]?
   180  	if s != "" {
   181  		c := s[0]
   182  		if c == '-' || c == '+' {
   183  			neg = c == '-'
   184  			s = s[1:]
   185  		}
   186  	}
   187  	// Special case: if all that is left is "0", this is zero.
   188  	if s == "0" {
   189  		return 0, nil
   190  	}
   191  	if s == "" {
   192  		return 0, trace.BadParameter("time: invalid duration " + orig)
   193  	}
   194  	for s != "" {
   195  		var (
   196  			v, f  int64       // integers before, after decimal point
   197  			scale float64 = 1 // value = v + f/scale
   198  		)
   199  
   200  		var err error
   201  
   202  		// The next character must be [0-9.]
   203  		if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
   204  			return 0, trace.BadParameter("time: invalid duration " + orig)
   205  		}
   206  		// Consume [0-9]*
   207  		pl := len(s)
   208  		v, s, err = leadingInt(s)
   209  		if err != nil {
   210  			return 0, trace.BadParameter("time: invalid duration " + orig)
   211  		}
   212  		pre := pl != len(s) // whether we consumed anything before a period
   213  
   214  		// Consume (\.[0-9]*)?
   215  		post := false
   216  		if s != "" && s[0] == '.' {
   217  			s = s[1:]
   218  			pl := len(s)
   219  			f, scale, s = leadingFraction(s)
   220  			post = pl != len(s)
   221  		}
   222  		if !pre && !post {
   223  			// no digits (e.g. ".s" or "-.s")
   224  			return 0, trace.BadParameter("time: invalid duration " + orig)
   225  		}
   226  
   227  		// Consume unit.
   228  		i := 0
   229  		for ; i < len(s); i++ {
   230  			c := s[i]
   231  			if c == '.' || '0' <= c && c <= '9' {
   232  				break
   233  			}
   234  		}
   235  		if i == 0 {
   236  			return 0, trace.BadParameter("time: missing unit in duration " + orig)
   237  		}
   238  		u := s[:i]
   239  		s = s[i:]
   240  		unit, ok := unitMap[u]
   241  		if !ok {
   242  			return 0, trace.BadParameter("time: unknown unit " + " in duration " + orig)
   243  		}
   244  		if v > (1<<63-1)/unit {
   245  			// overflow
   246  			return 0, trace.BadParameter("time: invalid duration " + orig)
   247  		}
   248  		v *= unit
   249  		if f > 0 {
   250  			// float64 is needed to be nanosecond accurate for fractions of hours.
   251  			// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
   252  			v += int64(float64(f) * (float64(unit) / scale))
   253  			if v < 0 {
   254  				// overflow
   255  				return 0, trace.BadParameter("time: invalid duration " + orig)
   256  			}
   257  		}
   258  		d += v
   259  		if d < 0 {
   260  			// overflow
   261  			return 0, trace.BadParameter("time: invalid duration " + orig)
   262  		}
   263  	}
   264  
   265  	if neg {
   266  		d = -d
   267  	}
   268  	return Duration(d), nil
   269  }