github.com/google/grumpy@v0.0.0-20171122020858-3ec87959189c/lib/itertools.py (about)

     1  # Copyright 2016 Google Inc. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Utilities for iterating over containers."""
    16  
    17  import _collections
    18  import sys
    19  
    20  class chain(object):
    21  
    22    def from_iterable(cls, iterables):
    23      for it in iterables:
    24        for element in it:
    25          yield element
    26  
    27    from_iterable = classmethod(from_iterable)
    28  
    29    def __init__(self, *iterables):
    30      if not iterables:
    31        self.iterables = iter([[]])
    32      else:
    33        self.iterables = iter(iterables)
    34      self.curriter = iter(next(self.iterables))
    35  
    36    def __iter__(self):
    37      return self
    38  
    39    def next(self):
    40      flag = True
    41      while flag:
    42        try:
    43          ret = next(self.curriter)
    44          flag = False
    45        except StopIteration:
    46          self.curriter = iter(next(self.iterables))
    47      return ret
    48  
    49  
    50  def compress(data, selectors):
    51    return (d for d,s in izip(data, selectors) if s)
    52  
    53  
    54  def count(start=0, step=1):
    55    n = start
    56    while True:
    57      yield n
    58      n += step
    59  
    60  
    61  def cycle(iterable):
    62    saved = []
    63    for element in iterable:
    64      yield element
    65      saved.append(element)
    66    while saved:
    67      for element in saved:
    68        yield element
    69  
    70  
    71  def dropwhile(predicate, iterable):
    72    iterable = iter(iterable)
    73    for x in iterable:
    74      if not predicate(x):
    75        yield x
    76        break
    77    for x in iterable:
    78      yield x
    79  
    80  
    81  class groupby(object):
    82    # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
    83    # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
    84    def __init__(self, iterable, key=None):
    85      if key is None:
    86        key = lambda x: x
    87      self.keyfunc = key
    88      self.it = iter(iterable)
    89      self.tgtkey = self.currkey = self.currvalue = object()
    90  
    91    def __iter__(self):
    92      return self
    93  
    94    def next(self):
    95      while self.currkey == self.tgtkey:
    96        self.currvalue = next(self.it)    # Exit on StopIteration
    97        self.currkey = self.keyfunc(self.currvalue)
    98      self.tgtkey = self.currkey
    99      return (self.currkey, self._grouper(self.tgtkey))
   100    
   101    def _grouper(self, tgtkey):
   102      while self.currkey == tgtkey:
   103        yield self.currvalue
   104        self.currvalue = next(self.it)    # Exit on StopIteration
   105        self.currkey = self.keyfunc(self.currvalue)
   106  
   107  
   108  def ifilter(predicate, iterable):
   109    if predicate is None:
   110      predicate = bool
   111    for x in iterable:
   112      if predicate(x):
   113         yield x
   114  
   115  
   116  def ifilterfalse(predicate, iterable):
   117    if predicate is None:
   118      predicate = bool
   119    for x in iterable:
   120      if not predicate(x):
   121         yield x
   122  
   123  
   124  def imap(function, *iterables):
   125    iterables = map(iter, iterables)
   126    while True:
   127      args = [next(it) for it in iterables]
   128      if function is None:
   129        yield tuple(args)
   130      else:
   131        yield function(*args)
   132  
   133  
   134  def islice(iterable, *args):
   135    s = slice(*args)
   136    it = iter(xrange(s.start or 0, s.stop or sys.maxint, s.step or 1))
   137    nexti = next(it)
   138    for i, element in enumerate(iterable):
   139      if i == nexti:
   140        yield element
   141        nexti = next(it)
   142  
   143  
   144  def izip(*iterables):
   145    iterators = map(iter, iterables)
   146    while iterators:
   147      yield tuple(map(next, iterators))
   148  
   149  
   150  class ZipExhausted(Exception):
   151    pass
   152  
   153  
   154  def izip_longest(*args, **kwds):
   155    # izip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
   156    fillvalue = kwds.get('fillvalue')
   157    counter = [len(args) - 1]
   158    def sentinel():
   159      if not counter[0]:
   160        raise ZipExhausted
   161      counter[0] -= 1
   162      yield fillvalue
   163    fillers = repeat(fillvalue)
   164    iterators = [chain(it, sentinel(), fillers) for it in args]
   165    try:
   166      while iterators:
   167        yield tuple(map(next, iterators))
   168    except ZipExhausted:
   169      pass
   170  
   171  
   172  def product(*args, **kwds):
   173    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
   174    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
   175    pools = map(tuple, args) * kwds.get('repeat', 1)
   176    result = [[]]
   177    for pool in pools:
   178      result = [x+[y] for x in result for y in pool]
   179    for prod in result:
   180      yield tuple(prod)
   181  
   182  
   183  def permutations(iterable, r=None):
   184    pool = tuple(iterable)
   185    n = len(pool)
   186    r = n if r is None else r
   187    for indices in product(range(n), repeat=r):
   188      if len(set(indices)) == r:
   189        yield tuple(pool[i] for i in indices)
   190  
   191  
   192  def combinations(iterable, r):
   193    pool = tuple(iterable)
   194    n = len(pool)
   195    for indices in permutations(range(n), r):
   196      if sorted(indices) == list(indices):
   197        yield tuple(pool[i] for i in indices)
   198  
   199  
   200  def combinations_with_replacement(iterable, r):
   201    pool = tuple(iterable)
   202    n = len(pool)
   203    for indices in product(range(n), repeat=r):
   204      if sorted(indices) == list(indices):
   205        yield tuple(pool[i] for i in indices)
   206  
   207  
   208  def repeat(object, times=None):
   209    if times is None:
   210      while True:
   211        yield object
   212    else:
   213      for i in xrange(times):
   214        yield object
   215  
   216  
   217  def starmap(function, iterable):
   218    for args in iterable:
   219      yield function(*args)
   220  
   221  
   222  def takewhile(predicate, iterable):
   223    for x in iterable:
   224      if predicate(x):
   225        yield x
   226      else:
   227        break
   228  
   229  
   230  def tee(iterable, n=2):
   231    it = iter(iterable)
   232    deques = [_collections.deque() for i in range(n)]
   233    def gen(mydeque):
   234      while True:
   235        if not mydeque:
   236          newval = next(it)
   237          for d in deques:
   238            d.append(newval)
   239        yield mydeque.popleft()
   240    return tuple(gen(d) for d in deques)
   241