跳至主要内容

類別繼承與家族紛爭(下)

為你自己學 Python

上個章節大概介紹過 C3 線性演算法,可以自己手算,也能透過類別本身自帶的 .mro() 方法來查看 MRO 的順序,所以這個章節我們就順著這個 .mro() 方法,看看這個演算法在 CPython 是怎麼實作的。

演算法實作

合併前準備

因為每個類別都有 .mro() 方法,所以順著 PyType_Typetype_methods 成員應該可以找到這個方法的實作:

檔案:Objects/typeobject.c
static PyObject *
type_mro_impl(PyTypeObject *self)
{
PyObject *seq;
seq = mro_implementation(self);
if (seq != NULL && !PyList_Check(seq)) {
Py_SETREF(seq, PySequence_List(seq));
}
return seq;
}

看起來 mro_implementation() 會回傳一個序列回來,應該就是它了:

檔案:Objects/typeobject.c
static PyObject *
mro_implementation(PyTypeObject *type)
{
// ... 略 ...
PyObject *bases = lookup_tp_bases(type);
Py_ssize_t n = PyTuple_GET_SIZE(bases);
for (Py_ssize_t i = 0; i < n; i++) {
PyTypeObject *base = _PyType_CAST(PyTuple_GET_ITEM(bases, i));
if (lookup_tp_mro(base) == NULL) {
// ... 錯誤處理 ...
}
assert(PyTuple_Check(lookup_tp_mro(base)));
}

// ... 略 ...
}

首先先取得這個類別的上層類別,舉例來說:

class A: pass
class B: pass
class C(A, B): pass

對於類別 C 來說,它的上層類別就是 AB,如果去追蹤 lookup_tp_bases() 的實作,接下來的 for 迴圈,透過 lookup_tp_mro() 函數檢查每個上層類別是不是都有 MRO,如果沒有就會拋出錯誤。

來看看 lookup_tp_bases()lookup_tp_mro() 這兩個函數的實作:

檔案:Objects/typeobject.c
static inline PyObject *
lookup_tp_bases(PyTypeObject *self)
{
return self->tp_bases;
}

static inline PyObject *
lookup_tp_mro(PyTypeObject *self)
{
return self->tp_mro;
}

滿單純的,就只是讀取 tp_basestp_mro 這兩個成員而已。再接著往下看:

檔案:Objects/typeobject.c
// ... 略 ...
if (n == 1) {
PyTypeObject *base = _PyType_CAST(PyTuple_GET_ITEM(bases, 0));
PyObject *base_mro = lookup_tp_mro(base);
Py_ssize_t k = PyTuple_GET_SIZE(base_mro);
PyObject *result = PyTuple_New(k + 1);
if (result == NULL) {
return NULL;
}

PyTuple_SET_ITEM(result, 0, Py_NewRef(type));
for (Py_ssize_t i = 0; i < k; i++) {
PyObject *cls = PyTuple_GET_ITEM(base_mro, i);
PyTuple_SET_ITEM(result, i + 1, Py_NewRef(cls));
}
return result;
}
// ... 略 ...

n == 1 表示上層類別只有一個,這時候就可以走比較簡單的流程。這裡做了一個空的 Tuple 出來,然後把目前的類別放在第一個位置,最後再跑個 for 迴圈把上層類別的 MRO 一個一個放進去:

檔案:Objects/typeobject.c
// ... 略 ...
PyObject **to_merge = PyMem_New(PyObject *, n + 1);
// ... 錯誤處理 ...

for (Py_ssize_t i = 0; i < n; i++) {
PyTypeObject *base = _PyType_CAST(PyTuple_GET_ITEM(bases, i));
to_merge[i] = lookup_tp_mro(base);
}
to_merge[n] = bases;

PyObject *result = PyList_New(1);
// ... 錯誤處理 ...

PyList_SET_ITEM(result, 0, Py_NewRef(type));
if (pmerge(result, to_merge, n + 1) < 0) {
Py_CLEAR(result);
}
PyMem_Free(to_merge);

return result;
// ... 略 ...

中間的 for 迴圈把每個上層類別的 MRO 都放進 to_merge 這個陣列裡,最後再把自己也放進去,待會就可以準備進行合併。

進行合併

這裡真正合併的地方是 pmerge() 函數:

檔案:Objects/typeobject.c
static int
pmerge(PyObject *acc, PyObject **to_merge, Py_ssize_t to_merge_size)
{
int res = 0;
Py_ssize_t i, j, empty_cnt;
int *remain;

remain = PyMem_New(int, to_merge_size);

// ... 錯誤處理 ...
for (i = 0; i < to_merge_size; i++)
remain[i] = 0;

again:
empty_cnt = 0;
for (i = 0; i < to_merge_size; i++) {
PyObject *candidate;

PyObject *cur_tuple = to_merge[i];

if (remain[i] >= PyTuple_GET_SIZE(cur_tuple)) {
empty_cnt++;
continue;
}

candidate = PyTuple_GET_ITEM(cur_tuple, remain[i]);
for (j = 0; j < to_merge_size; j++) {
PyObject *j_lst = to_merge[j];
if (tail_contains(j_lst, remain[j], candidate))
goto skip;
}
res = PyList_Append(acc, candidate);
if (res < 0)
goto out;

for (j = 0; j < to_merge_size; j++) {
PyObject *j_lst = to_merge[j];
if (remain[j] < PyTuple_GET_SIZE(j_lst) &&
PyTuple_GET_ITEM(j_lst, remain[j]) == candidate) {
remain[j]++;
}
}
goto again;
skip: ;
}

if (empty_cnt != to_merge_size) {
set_mro_error(to_merge, to_merge_size, remain);
res = -1;
}

out:
PyMem_Free(remain);
return res;
}

這個函數的實作不算太難懂,一開始先去要一塊空的記憶體 remain 並且先把裡面都初始化成 0,這個待會是用來儲存每個準備要進行合併的列表的索引。接著開始找第一個候選人 candidate,然後用 tail_contains() 函數判斷是否出現在其它列表的尾部,如果是的話,就 goto skip 繼續挑下一個候選人。如果是個好線頭(Good Head)的話,就把這個候選人加進去,然後把這個候選人在其它列表裡的索引都往前移一格,然後再回到 again 重新找下一個候選人。

最後,如果還有未處理完的但找不到合適的候選人,就給錯誤訊息。基本上這個流程就是我們在上個章節看到的手算過程。介紹完了 MRO 的計算,接著來看看跟 MRO 也有關的 super() 函數。

家族紛爭

先來個簡單的程式碼:

class Animal:
def sleep(self):
print("Zzzzz")

class Cat(Animal):
def sleep(self):
super().sleep()

kitty = Cat()
kitty.sleep()

假設你有其它物件導向程式語言的經驗,看到這個 super().sleep() 大概能猜到是要呼叫上層類別的 .sleep() 方法,執行之後的確會印出 "Zzzzz" 沒錯,不過 Python 的 super() 函數可能跟各位在其它語言學到的不太一樣。不知道大家有注意到這裡是 super() 而不是 super 嗎?這個我們待會再來討論,先來看看 super() 函數是怎麼實作的。

Super!

從 Bytecode 指令來看:

8           4 LOAD_GLOBAL              0 (super)
14 LOAD_DEREF 1 (__class__)
16 LOAD_FAST 0 (self)
18 LOAD_SUPER_ATTR 5 (NULL|self + sleep)
22 CALL 0
30 POP_TOP
32 RETURN_CONST 0 (None)

這裡用了一個我們沒看過的 LOAD_SUPER_ATTR 5 指令,如果追進這個指令,會發現在這個指令裡會呼叫 PyObject_Vectorcall() 函數呼叫全域的 super() 函數來建立一個實體,事實上這個 super 本質上在 Python 是一個全域範圍的函數,更精準的說它是個類別:

>>> super
<class 'super'>

它就跟 intstr 一樣,都是一種內建的類別。用我們之前學過的技巧,應該可以猜到會有一個名為 PySuper_Type 的類別,我們就順著 PySuper_Type 類別的 tp_init 成員,來看看當呼叫 super() 類別建立實體的時候發生什麼事:

檔案:Objects/typeobject.c
static int
super_init(PyObject *self, PyObject *args, PyObject *kwds)
{
// ... 略 ...
if (super_init_impl(self, type, obj) < 0) {
return -1;
}
return 0;
}

繼續往下追進實際的實作函數 super_init_impl()

檔案:Objects/typeobject.c
static inline int
super_init_impl(PyObject *self, PyTypeObject *type, PyObject *obj) {
superobject *su = (superobject *)self;
PyTypeObject *obj_type = NULL;
if (type == NULL) {
PyThreadState *tstate = _PyThreadState_GET();
_PyInterpreterFrame *frame = _PyThreadState_GetFrame(tstate);
if (frame == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"super(): no current frame");
return -1;
}
int res = super_init_without_args(frame, frame->f_code, &type, &obj);

if (res < 0) {
return -1;
}
}

// ... 略 ...
Py_XSETREF(su->type, (PyTypeObject*)Py_NewRef(type));
Py_XSETREF(su->obj, obj);
Py_XSETREF(su->obj_type, obj_type);
return 0;
}

如果單獨呼叫 super() 而不帶任何參數的話,會進入 super_init_without_args() 函數,並且把當前 Frame 跟 Code Object 傳進去:

檔案:Objects/typeobject.c
static int
super_init_without_args(_PyInterpreterFrame *cframe, PyCodeObject *co,
PyTypeObject **type_p, PyObject **obj_p)
{
// ... 略 ...
PyTypeObject *type = NULL;
int i = PyCode_GetFirstFree(co);
for (; i < co->co_nlocalsplus; i++) {
PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
if (_PyUnicode_Equal(name, &_Py_ID(__class__))) {
PyObject *cell = _PyFrame_GetLocalsArray(cframe)[i];
// ... 錯誤處理 ...
type = (PyTypeObject *) PyCell_GET(cell);
// ... 錯誤處理 ...
break;
}
}

// ... 略 ...
*type_p = type;
*obj_p = firstarg;
return 0;
}

這個函數會取得 Code Object 裡的 __class__ 變數,也就是指向當前的類別。所以我們在 Python 裡執行 super() 函數而且沒有傳任何參數給它的時候,Python 也能得知目前的類別跟物件資訊。

誰家的小孩?

所以,在我們上面的範例中,在 Cat 類別裡執行 super() 函數的時候,產生的是 Cat 類別的實體還是 Animal 類別的實體?雖然以結果來說好像是上層類別的實體,但其實都不是。super() 函數會建立的是一個 PySuper_Type 類別的實體,在這個實體裡有著目前的類別跟物件資訊,它的行為比較像是一個代理物件(Proxy Object)。如果你在過程中試著印出這顆 super() 函數建立的實體,會發現它會長這樣:

<super: <class 'Cat'>, <Cat object>>

不是 Cat 類別也不是 Animal 類別,就是 super 這個類別所建立的代理物件。當呼叫這顆代理物件身上的方法的時候,例如 super().sleep() 的時候,它會找 PySuper_Type 這個結構的 tp_getattro 成員:

檔案:Objects/typeobject.c
static PyObject *
super_getattro(PyObject *self, PyObject *name)
{
superobject *su = (superobject *)self;

if (PyUnicode_Check(name) &&
PyUnicode_GET_LENGTH(name) == 9 &&
_PyUnicode_Equal(name, &_Py_ID(__class__)))
return PyObject_GenericGetAttr(self, name);

return do_super_lookup(su, su->type, su->obj, su->obj_type, name, NULL);
}

這裡的 su 就是我們建立的代理物件,接下來是檢查是不是在查找 __class__ 這個屬性,例如 super().__class__ 這樣的語法,如果是,就會執行泛用型的 PyObject_GenericGetAttr() 函數而得到 <class 'super'>。這裡為什麼要檢查 name 的長度是不是 9?因為 __class__ 這個幾個字的長度就剛好是 9,加在這裡可以讓效能好那麼一點點,畢竟檢查字串長度是不是 9 會比檢查兩個字串是不是完全相等來的快一點。

如果不是查找 __class__ 屬性而是其它的,例如 sleep,就會進入 do_super_lookup() 函數。為了避免待會大家進到函數裡會暈船,我先整理一下目前關於 su 以及它身上的幾個屬性的狀態:

  • su 就是這顆代理物件本體。
  • su->type 這個是 super() 函數建立時的類別,因為這裡我們沒有帶參數給它,Python 根據當下 Frame 推敲出當前所在的類別,也就是 Cat
  • su->obj 是目前的物件,也就是 self
  • su->obj_type 是目前的物件的真實類別,也就是 Cat

這裡 typeobj_type 剛好是一樣的,但也有可能是不同的,例如我在原本的範例再加一層繼承關係:

class Animal:
def sleep(self):
print("Zzzzz")

class Cat(Animal):
def sleep(self):
super().sleep()

class Kitty(Cat):
pass

k = Kitty()
k.sleep()

當呼叫 k.sleep() 的時候,super() 函數建立的代理物件的狀態如下:

  • su->type 這個是 super() 函數建立時的類別,所以是 Cat
  • su->obj 是目前的物件,也就是 k 物件。
  • su->obj_type 是目前的物件的真實類別,也就是 Kitty

雖然這裡 su->type 好像跟 su->obj_type 是一樣的,但我們待會再來看看這兩個值不一樣的例子。接著來看 do_super_lookup() 函數:

檔案:Objects/typeobject.c
static PyObject *
do_super_lookup(superobject *su, PyTypeObject *su_type, PyObject *su_obj,
PyTypeObject *su_obj_type, PyObject *name, int *method)
{
// ... 略 ...
res = _super_lookup_descr(su_type, su_obj_type, name);
// ... 略 ...
}

可以看到這個函數利用我們傳進去的資訊來解析方法,追進 _super_lookup_descr() 函數就會看到:

檔案:Objects/typeobject.c
static PyObject *
_super_lookup_descr(PyTypeObject *su_type, PyTypeObject *su_obj_type, PyObject *name)
{
// ... 略 ...
mro = lookup_tp_mro(su_obj_type);
// ... 略 ...

i++;

// ... 略 ...
do {
PyObject *obj = PyTuple_GET_ITEM(mro, i);
PyObject *dict = lookup_tp_dict(_PyType_CAST(obj));
// ... 略 ...
res = PyDict_GetItemWithError(dict, name);
// ... 略 ...
i++;
} while (i < n);
}

會從 su->obj_type(在上面的範例裡就是 Kitty)開始在它的 MRO 中查找,而不是從它所在的 su->type(也就是 Cat),但這裡要特別注意的是那行 i++,這表示待會的迴圈會從 mro 的第二個元素開始找,這樣就可以避免找到自己的方法而造成無窮迴圈。

如果各位學過其它程式語言,查找順序竟然是從這個物件本身的類別開始找而不是直接從上層類別,這有點違反直覺。但是,為什麼 Python 這麼設計?

解決家族紛爭

我在「為你自己學 Python」的物件導向程式設計 - 入門篇曾經介紹到多重繼承可能造成的「鑽石問題(Diamond Problem)」,舉個例子:

class Animal:
def sleep(self):
print("Zzzzz")

class Bird(Animal):
def sleep(self):
print("我可以站著睡覺")
super().sleep()

class Fish(Animal):
def sleep(self):
print("我睡覺不用閉眼睛")
super().sleep()

class Cat(Bird, Fish):
def sleep(self):
print("呼嚕呼嚕~")
super().sleep()

kitty = Cat()
kitty.sleep() # 印出什麼?

你覺得這段程式碼最後會印出什麼?這裡比較難猜的應該是在 Bird 類別裡的 super().sleep() 指的是誰,如果你用其它程式語言的設計來猜的話,大概會猜 Animal 類別的 sleep() 方法。

想想看,在 Bird 類別裡的 super() 現在的狀態:

  • su 就是這個代理物件。
  • su->type 是根據上下文推敲出來的 Bird 類別。
  • su->obj 是目前的物件,也就是 kitty 物件。
  • su->obj_type 是目前的物件的真實類別,也就是 Cat

我們剛看原始碼,知道 Python 會從 su->obj_type 的也就是 Cat 類別的 MRO 開始找,而不是從 Bird 類別的 MRO。Cat 類別的 MRO 是 Cat -> Bird -> Fish -> Animal,Python 會在這一串 MRO 裡面找到 Bird,然後往它的下一個順位也就是 Fish 類別,所以結果會依序印出:

呼嚕呼嚕~
我可以站著睡覺
我睡覺不用閉眼睛
Zzzzz

如果要對 Python 的 super() 做個簡單的一句話解釋的話,就是「找到目前類別的 MRO,然後從這個 MRO 的下一個類別開始找」。

想想看,如果是直接往 Bird 的上層 Animal 找的話,那麼 Fish 類別的 sleep() 方法就會被跳過,這樣就破壞了原本的 MRO 繼承順序了。所以 Python 藉由 C3 線性演算法算出類別的 MRO,再搭配 super() 函數的設計來解決多重繼承可能遇到的問題。

指定 super 類別

雖然上面的範例中,super() 我們都沒帶參數給它,但其實 super() 函數是可以帶參數的,例如:

class Animal:
def sleep(self):
print("Zzzzz")

class Cat(Animal):
def sleep(self):
super(Cat, self).sleep()

這等於指定 su->typeCat,而 su->objself,這樣 Python 就不用自己從當前的 Frame 去推敲了。不過就以大部份的情況來說,不帶參數的 super() 應該就夠用了,而且可以少打幾個字。

小練習:我是誰?

最後給個小練習,大家猜猜看答案是什麼:

class Person:
name = "Walter White"
def say_my_name(self):
print(self.name)

class Heisenberg(Person):
name = "Heisenberg"
def say_my_name(self):
super().say_my_name()

heisenberg = Heisenberg()
heisenberg.say_my_name() # 這會印出什麼?

直覺你可能會猜會印出 Person 類別的 "Walter White",但執行之後會發現印出來的是 "Heisenberg"

為什麼是這個答案,你只要想一下 super() 函數建立的代理物件,裡面的 self 指的是誰就會知道答案囉 :)

工商服務

想學 Python 嗎?我教你啊 :)

想要成為軟體工程師嗎?這不是條輕鬆的路,除了興趣之外,還需要足夠的決心、設定目標並持續學習,我們的ASTROCamp 軟體工程師培訓營提供專業的前後端課程培訓,幫助你在最短時間內建立正確且扎實的軟體開發技能,有興趣而且不怕吃苦的話不妨來試試看!