Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-123884 Tee of tee was not producing n independent iterators #124490

Merged
merged 9 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions Doc/library/itertools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -691,25 +691,36 @@ loops that truncate the stream.

def tee(iterable, n=2):
if n < 0:
raise ValueError('n must be >= 0')
iterator = iter(iterable)
shared_link = [None, None]
return tuple(_tee(iterator, shared_link) for _ in range(n))

def _tee(iterator, link):
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return

Once a :func:`tee` has been created, the original *iterable* should not be
used anywhere else; otherwise, the *iterable* could get advanced without
the tee objects being informed.
raise ValueError
rhettinger marked this conversation as resolved.
Show resolved Hide resolved
if n == 0:
return ()
iterator = _tee(iterable)
result = [iterator]
for _ in range(n - 1):
result.append(_tee(iterator))
return tuple(result)

class _tee:

def __init__(self, iterable):
it = iter(iterable)
if isinstance(it, _tee):
self.iterator = it.iterator
self.link = it.link
else:
self.iterator = it
self.link = [None, None]

def __iter__(self):
return self

def __next__(self):
link = self.link
if link[1] is None:
link[0] = next(self.iterator)
link[1] = [None, None]
value, self.link = link
return value

When the input *iterable* is already a tee iterator object, all
members of the return tuple are constructed as if they had been
Expand Down
1 change: 0 additions & 1 deletion Include/internal/pycore_global_objects_fini_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Include/internal/pycore_global_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ struct _Py_global_strings {
STRUCT_FOR_ID(__classdictcell__)
STRUCT_FOR_ID(__complex__)
STRUCT_FOR_ID(__contains__)
STRUCT_FOR_ID(__copy__)
STRUCT_FOR_ID(__ctypes_from_outparam__)
STRUCT_FOR_ID(__del__)
STRUCT_FOR_ID(__delattr__)
Expand Down
1 change: 0 additions & 1 deletion Include/internal/pycore_runtime_init_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions Include/internal/pycore_unicodeobject_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 48 additions & 36 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,10 +1249,11 @@ def test_tee(self):
self.assertEqual(len(result), n)
self.assertEqual([list(x) for x in result], [list('abc')]*n)

# tee pass-through to copyable iterator
# tee objects are independent (see bug gh-123884)
a, b = tee('abc')
c, d = tee(a)
self.assertTrue(a is c)
e, f = tee(c)
self.assertTrue(len({a, b, c, d, e, f}) == 6)

# test tee_new
t1, t2 = tee('abc')
Expand Down Expand Up @@ -1759,21 +1760,36 @@ def test_tee_recipe(self):

def tee(iterable, n=2):
if n < 0:
raise ValueError('n must be >= 0')
iterator = iter(iterable)
shared_link = [None, None]
return tuple(_tee(iterator, shared_link) for _ in range(n))
raise ValueError
rhettinger marked this conversation as resolved.
Show resolved Hide resolved
if n == 0:
return ()
iterator = _tee(iterable)
result = [iterator]
for _ in range(n - 1):
result.append(_tee(iterator))
return tuple(result)

class _tee:

def __init__(self, iterable):
it = iter(iterable)
if isinstance(it, _tee):
self.iterator = it.iterator
self.link = it.link
else:
self.iterator = it
self.link = [None, None]

def _tee(iterator, link):
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
def __iter__(self):
return self

def __next__(self):
link = self.link
if link[1] is None:
link[0] = next(self.iterator)
link[1] = [None, None]
value, self.link = link
return value

# End tee() recipe #############################################

Expand Down Expand Up @@ -1819,12 +1835,10 @@ def _tee(iterator, link):
self.assertRaises(TypeError, tee, [1,2], 'x')
self.assertRaises(TypeError, tee, [1,2], 3, 'x')

# Tests not applicable to the tee() recipe
if False:
# tee object should be instantiable
a, b = tee('abc')
c = type(a)('def')
self.assertEqual(list(c), list('def'))
# tee object should be instantiable
a, b = tee('abc')
c = type(a)('def')
self.assertEqual(list(c), list('def'))

# test long-lagged and multi-way split
a, b, c = tee(range(2000), 3)
Expand All @@ -1845,21 +1859,19 @@ def _tee(iterator, link):
self.assertEqual(len(result), n)
self.assertEqual([list(x) for x in result], [list('abc')]*n)

# tee objects are independent (see bug gh-123884)
a, b = tee('abc')
c, d = tee(a)
e, f = tee(c)
self.assertTrue(len({a, b, c, d, e, f}) == 6)

# Tests not applicable to the tee() recipe
if False:
# tee pass-through to copyable iterator
a, b = tee('abc')
c, d = tee(a)
self.assertTrue(a is c)

# test tee_new
t1, t2 = tee('abc')
tnew = type(t1)
self.assertRaises(TypeError, tnew)
self.assertRaises(TypeError, tnew, 10)
t3 = tnew(t1)
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
# test tee_new
t1, t2 = tee('abc')
tnew = type(t1)
self.assertRaises(TypeError, tnew)
self.assertRaises(TypeError, tnew, 10)
t3 = tnew(t1)
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))

# test that tee objects are weak referencable
a, b = tee(range(10))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
The output now has the promised *n* independent new iterators. Formerly,
the first iterator returned was the input iterator itself rather than an
independent copy. This would sometimes give surprising results.
36 changes: 9 additions & 27 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
{
Py_ssize_t i;
PyObject *it, *copyable, *copyfunc, *result;
PyObject *it, *to, *result;

if (n < 0) {
PyErr_SetString(PyExc_ValueError, "n must be >= 0");
Expand All @@ -1053,41 +1053,23 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
return NULL;
}

if (PyObject_GetOptionalAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) {
Py_DECREF(it);
itertools_state *state = get_module_state(module);
to = tee_fromiterable(state, it);
Py_DECREF(it);
if (to == NULL) {
Py_DECREF(result);
return NULL;
}
if (copyfunc != NULL) {
copyable = it;
}
else {
itertools_state *state = get_module_state(module);
copyable = tee_fromiterable(state, it);
Py_DECREF(it);
if (copyable == NULL) {
Py_DECREF(result);
return NULL;
}
copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
if (copyfunc == NULL) {
Py_DECREF(copyable);
Py_DECREF(result);
return NULL;
}
}

PyTuple_SET_ITEM(result, 0, copyable);
PyTuple_SET_ITEM(result, 0, to);
for (i = 1; i < n; i++) {
copyable = _PyObject_CallNoArgs(copyfunc);
if (copyable == NULL) {
Py_DECREF(copyfunc);
to = tee_copy((teeobject *)to, NULL);
if (to == NULL) {
Py_DECREF(result);
return NULL;
}
PyTuple_SET_ITEM(result, i, copyable);
PyTuple_SET_ITEM(result, i, to);
}
Py_DECREF(copyfunc);
return result;
}

Expand Down
Loading