跳至主要内容

轉呀轉呀七彩迭代器

為你自己學 Python

在 Python 裡迭代器(Iterator)使用的頻率很高,讓我們可以不用像其它程式語言一樣用 for 迴圈就能遍歷各種「容器」,而且這所謂的容器還不是只有串列,字典、字串、範圍(Range)這些都能用類似的操作,這個章節我們就來看看迭代器是怎麼實作的。

迭代器協議

在「為你自己學 Python」的物件導向程式設計 - 進階篇有提到三個看起來有點像的東西,分別是 Iteration、Iterable 以及 Iterator,很快的複習一下:

  • Iteration,「迭代」,名詞,指的是遍歷某個物件裡面所有元素的過程。
  • Iterable,「可迭代的」,形容詞,是指可以被進行迭代的物件,本章節提到的「可迭代物件」就是指它。
  • Iterator,「迭代器」,名詞,有點像是一個容器,我們可以用特定的方法遍歷這個容器中的每個元素。

根據 Python 對迭代器的定義,只要有實作「迭代器協議(Iterator Protocol)」的物件,就能被稱之迭代器。迭代器協議的內容也很簡單,只要有實作 __iter__() 以及 __next__() 這兩個魔術方法就可以了。不過這是在 Python 層級的定義,我們來看看在 CPython 是怎麼實作的。

在 Python 要建立一個迭代器,可以使用內建函數 iter()

iter([9, 5, 2, 7])

所以我們先從這個函數的實作原始碼看起:

檔案:Python/clinic/bltinmodule.c.h
static PyObject *
builtin_iter(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
{
PyObject *return_value = NULL;
PyObject *object;
PyObject *sentinel = NULL;

if (!_PyArg_CheckPositional("iter", nargs, 1, 2)) {
goto exit;
}
object = args[0];
if (nargs < 2) {
goto skip_optional;
}
sentinel = args[1];
skip_optional:
return_value = builtin_iter_impl(module, object, sentinel);

exit:
return return_value;
}

看起來真正實作的函數是 builtin_iter_impl()

檔案:Python/bltinmodule.c
static PyObject *
builtin_iter_impl(PyObject *module, PyObject *object, PyObject *sentinel)
{
if (sentinel == NULL)
return PyObject_GetIter(object);
if (!PyCallable_Check(object)) {
PyErr_SetString(PyExc_TypeError,
"iter(object, sentinel): object must be callable");
return NULL;
}
return PyCallIter_New(object, sentinel);
}

還滿容易懂的,如果沒有帶「哨兵(sentinel)」的話就直接呼叫 PyObject_GetIter(),否則就是 PyCallIter_New()。但這裡的哨兵是什麼意思?

站住,口令,誰!

其實哨兵的意思是如果迭代器回傳的值等於哨兵的話就停止迭代,舉個例子:

from random import randint

numbers = iter(lambda: randint(1, 10), 7)

for num in numbers:
print(num)

因為我在 iter() 的第二個參數帶了 7,所以上面這段程式碼會不斷的產生 1 到 10 之間的隨機數字,直到數字等於 7 為止。如果沒有帶哨兵的話就會一直迭代下去。

我們先從比較簡單的 PyCallIter_New() 開始看:

檔案:Objects/iterobject.c
PyObject *
PyCallIter_New(PyObject *callable, PyObject *sentinel)
{
calliterobject *it;
it = PyObject_GC_New(calliterobject, &PyCallIter_Type);
if (it == NULL)
return NULL;
it->it_callable = Py_NewRef(callable);
it->it_sentinel = Py_NewRef(sentinel);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}

這裡會建立一個 calliterobject 結構的物件,這個結構是用來儲存迭代器的資訊:

檔案:Objects/iterobject.c
typedef struct {
PyObject_HEAD
PyObject *it_callable;
PyObject *it_sentinel;
} calliterobject;

還滿單純的,我們順便看一下 PyCallIter_Type 的結構:

檔案:Objects/iterobject.c
PyTypeObject PyCallIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"callable_iterator", /* tp_name */
sizeof(calliterobject), /* tp_basicsize */
0, /* tp_itemsize */
// ... 略 ...
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)calliter_iternext, /* tp_iternext */
calliter_methods, /* tp_methods */
};

這裡的重點應該在 tp_itertp_iternexttp_iter 會回傳迭代器物件自己,而 tp_iternext 則應該要回傳下一個元素,其實這就是我們前面提到迭代器協議需要實作的 __iter__ 以及 __next__ 方法。

檔案:Objects/object.c
PyObject *
PyObject_SelfIter(PyObject *obj)
{
return Py_NewRef(obj);
}

就真的是回傳自己這個迭代器物件而已,很簡單,再看看 calliter_iternext() 函數:

檔案:Objects/iterobject.c
static PyObject *
calliter_iternext(calliterobject *it)
{
PyObject *result;
// ... 錯誤處理 ...

result = _PyObject_CallNoArgs(it->it_callable);
if (result != NULL && it->it_sentinel != NULL){
int ok;

ok = PyObject_RichCompareBool(it->it_sentinel, result, Py_EQ);
if (ok == 0) {
return result;
}

if (ok > 0) {
Py_CLEAR(it->it_callable);
Py_CLEAR(it->it_sentinel);
}
}
else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
PyErr_Clear();
Py_CLEAR(it->it_callable);
Py_CLEAR(it->it_sentinel);
}
Py_XDECREF(result);
return NULL;
}

中間使用 PyObject_RichCompareBool() 函數比對是不是等於哨兵,如果是的話就停止迭代,並且把 it_callable 以及 it_sentinel 的值清空,不然的話就回傳元素。

這裡可以另外看一個小亮點,在這裡也可以看到如果迭代器回傳的值是 StopIteration 的時候,會呼叫 PyErr_Clear() 清除當前執行緒的錯誤狀態並停止迭代,這也是為什麼一般我們用 next() 拿下一個拿到沒東西的時候會拋出 StopIteration 例外,但在 for 迴圈或是串列推導式裡不會出錯的原因。

整個看起來算滿簡單的。我們再看另一個 PyObject_GetIter() 函數,這個就稍微複雜一點:

檔案:Objects/abstract.c
PyObject *
PyObject_GetIter(PyObject *o)
{
PyTypeObject *t = Py_TYPE(o);
getiterfunc f;

f = t->tp_iter;
if (f == NULL) {
if (PySequence_Check(o))
return PySeqIter_New(o);
return type_error("'%.200s' object is not iterable", o);
}
else {
PyObject *res = (*f)(o);
if (res != NULL && !PyIter_Check(res)) {
PyErr_Format(PyExc_TypeError,
"iter() returned non-iterator "
"of type '%.100s'",
Py_TYPE(res)->tp_name);
Py_SETREF(res, NULL);
}
return res;
}
}

先看看有沒有實作 tp_iter 成員,如果有就呼叫它:

PyObject *res = (*f)(o);

這行程式碼就是在做這件事,不過這有一些有趣的細節待會看。

如果沒有沒有實作 tp_iter 成員的話沒關係,就再檢查是不是一種序列,如果是的話就用這個序列建立一個泛用型的 PySeqIter_New() 的迭代器,再看看這個 PySeqIter_New() 在做什麼:

檔案:Objects/iterobject.c
PyObject *
PySeqIter_New(PyObject *seq)
{
seqiterobject *it;

if (!PySequence_Check(seq)) {
PyErr_BadInternalCall();
return NULL;
}
it = PyObject_GC_New(seqiterobject, &PySeqIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_seq = Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}

這個函數會建立一個 seqiterobject 結構的物件,這個結構的設計滿簡單的:

檔案:Objects/iterobject.c
typedef struct {
PyObject_HEAD
Py_ssize_t it_index;
PyObject *it_seq;
} seqiterobject;

it_index 記錄的是目前迭代的位置,而 it_seq 則是指向被迭代的序列物件。再看一下 PySeqIter_Type 的結構:

檔案:Objects/iterobject.c
PyTypeObject PySeqIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"iterator", /* tp_name */
sizeof(seqiterobject), /* tp_basicsize */
// ... 略 ...
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
iter_iternext, /* tp_iternext */
seqiter_methods, /* tp_methods */
0, /* tp_members */
};

這個 tp_iter 成員也是回傳自己而已,來看看 tp_iternext 的實作:

檔案:Objects/iterobject.c
static PyObject *
iter_iternext(PyObject *iterator)
{
seqiterobject *it;
PyObject *seq;
PyObject *result;

assert(PySeqIter_Check(iterator));
it = (seqiterobject *)iterator;
seq = it->it_seq;
// ... 錯誤處理 ...

result = PySequence_GetItem(seq, it->it_index);
if (result != NULL) {
it->it_index++;
return result;
}
if (PyErr_ExceptionMatches(PyExc_IndexError) ||
PyErr_ExceptionMatches(PyExc_StopIteration))
{
PyErr_Clear();
it->it_seq = NULL;
Py_DECREF(seq);
}
return NULL;
}

PySequence_GetItem() 函數根據 it_index 索引值去拿序列中的元素,如果拿到的話就回傳,不然的話就停止迭代。這裡也同樣也可以看到如果拿到 IndexError 或是 StopIteration 例外的時候就停止迭代,而且有 PyErr_Clear() 所以不會引發錯誤。

不同的迭代器?

如果在 iter() 函數裡傳入不同的可迭代物件,會得到不同的迭代器物件,我們來看看不同的迭代器物件長什麼樣子:

>>> iter([])
<list_iterator object>

>>> iter(range(0))
<range_iterator object>

>>> iter({})
<dict_keyiterator object>

>>> iter('hello')
<str_ascii_iterator object>

>>> iter('七龍珠')
<str_iterator object>

怎麼這麼多種?這是因為不同的可迭代物件有不同的迭代器實作,在剛才的 PyObject_GetIter() 函數裡的這一行:

PyObject *res = (*f)(o);

就是呼叫 tp_iter 成員的實作函數,並且把目前這個可迭代物件傳進去。不同的資料型態,可能有著不同的 tp_iter 實作,例如串列:

檔案:Objects/listobject.c
static PyObject *
list_iter(PyObject *seq)
{
_PyListIterObject *it;

// ... 錯誤處理 ...
it = PyObject_GC_New(_PyListIterObject, &PyListIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_seq = (PyListObject *)Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}

這裡會做出一個 PyListIter_Type 類型的迭代器物件,這個迭代器物件的 tp_iternext 成員實作如下:

檔案:Objects/listobject.c
static PyObject *
listiter_next(_PyListIterObject *it)
{
PyListObject *seq;
PyObject *item;

// ... 錯誤處理 ...

if (it->it_index < PyList_GET_SIZE(seq)) {
item = PyList_GET_ITEM(seq, it->it_index);
++it->it_index;
return Py_NewRef(item);
}

it->it_seq = NULL;
Py_DECREF(seq);
return NULL;
}

串列的 tp_iternext 比較簡單。如果傳入的可迭代物件是字串,就會看看這個字全部都是 ASCII 還是有其它編碼而決定會建立哪一種迭代器,字典、範圍都是一樣的做法。也就是因為這樣,所以在上面會看到不同的迭代器物件。

有興趣的話,可再順著同樣的思路去追看看由範圍、字串跟字典這些物件所建立的迭代器物件的 tp_iternext 實作,這樣就能了解不同的迭代器物件是怎麼運作的。

同樣都是可以被 next() 函數操作的物件,跟上個章節介紹的產生器比起來,迭代器的實作簡單多了 :)

工商服務

想學 Python 嗎?我教你啊 :)

想要成為軟體工程師嗎?這不是條輕鬆的路,除了興趣之外,還需要足夠的決心、設定目標並持續學習,我們的ASTROCamp 軟體工程師培訓營提供專業的前後端課程培訓,幫助你在最短時間內建立正確且扎實的軟體開發技能,有興趣而且不怕吃苦的話不妨來試試看!