AI for Youth Academy 青少年AI研究计划

第一年 · 第16周

第十六章:方法覆盖与 super()

上周你学会了子类继承父类的所有内容。本周你将再学两个技能:如何用新版本替换(覆盖)父类方法,以及如何使用 super().__init__() 添加额外属性。这两个模式几乎在所有真实的 Python 库中都会用到——包括最流行的 AI 工具 PyTorch。

第七课时:方法覆盖与 super()

时长:60 分钟

学习目标

  • 在子类中覆盖(override)父类的方法
  • 正确编写 super().__init__() 来添加额外属性
  • 认识 PyTorch 等 AI 库中使用的继承模式

覆盖——替换行为(15分钟)

上周我们学到子类可以添加新方法。但有时我们不想添加——我们想替换父类的方法,让它做不同的事。这叫做覆盖(override)

来看一个例子。所有动物都会发出声音,但每种动物的声音不一样:

class Animal:
    def __init__(self, name, hp):
        self.name = name
        self.hp   = hp

    def speak(self):
        print(f'{self.name} 发出了声音。')

    def eat(self):
        self.hp += 10
        print(f'{self.name} 吃了食物。HP: {self.hp}')


class Dog(Animal):
    def speak(self):              # 覆盖了 Animal.speak
        print(f'{self.name} 说:汪汪!')


class Cat(Animal):
    def speak(self):              # 覆盖了 Animal.speak
        print(f'{self.name} 说:喵!')

现在把它们放进列表,让每个动物说话:

animals = [Dog('Rex', 80), Cat('Luna', 70), Dog('Buddy', 90)]
for a in animals:
    a.speak()

# Rex 说:汪汪!
# Luna 说:喵!
# Buddy 说:汪汪!

试一试

如果我们创建一个普通的 Animal 并调用 speak(),会发生什么?

mystery = Animal('???', 50)
mystery.speak()    # ??? 发出了声音。  ← 使用 Animal 的版本

没有覆盖的子类方法?也验证一下:

rex = Dog('Rex', 80)
rex.eat()        # Rex 吃了食物。HP: 90  ← Dog 没有 eat(),所以用 Animal 的

super().__init__()——最重要的模式(30分钟)

问题

上周我们学过,子类可以添加父类没有的属性。但有一个问题:谁来设置父类的属性?

如果子类重新定义了 __init__,Python 就不会自动运行父类的 __init__。你必须手动调用它。这就是 super().__init__() 的用途。

一步一步来

第 1 步:看看父类做什么——设置 namehp

class Animal:
    def __init__(self, name, hp):
        self.name = name
        self.hp   = hp

    def eat(self):
        self.hp += 10
        print(f'{self.name} 吃了食物。HP: {self.hp}')

第 2 步:子类需要额外的属性 voltage(电压)。用 super().__init__() 让父类先处理 namehp,然后自己添加 voltage

class ElectricPokemon(Animal):
    def __init__(self, name, hp, voltage):   # 多了一个参数:voltage
        super().__init__(name, hp)            # 让 Animal 先设置 name 和 hp
        self.voltage = voltage                # 然后添加自己的属性

    def thunder_shock(self, target):
        damage = self.voltage // 10
        target.hp -= damage
        print(f'{self.name} 使用十万伏特,造成 {damage} 点伤害!')
        if target.hp <= 0:
            print(f'{target.name} 倒下了!')

    def speak(self):
        print(f'{self.name}:皮卡 皮卡!')

第 3 步:测试!一个 ElectricPokemon 对象同时拥有父类的和自己的属性:

pikachu = ElectricPokemon('Pikachu', 100, 500)
print(pikachu.name)      # Pikachu   ← 来自 Animal(通过 super())
print(pikachu.hp)        # 100       ← 来自 Animal(通过 super())
print(pikachu.voltage)   # 500       ← ElectricPokemon 自己的
pikachu.speak()          # Pikachu:皮卡 皮卡!  ← 覆盖了 Animal.speak
pikachu.eat()            # Pikachu 吃了食物。HP: 110  ← 继承自 Animal

完整的继承模式——总结

每次你写一个有额外属性的子类时,都用这个模式:

class 子类(父类):
    def __init__(self, 父类的参数, 自己的参数):
        super().__init__(父类的参数)   # 第 1 步:让父类做它的事
        self.自己的属性 = 自己的参数      # 第 2 步:添加自己的东西

搭档练习

创建一个 FirePokemon(Animal) 类,要求:

  1. 有一个额外属性 fire_power(火焰威力)
  2. 使用 super().__init__() 设置 namehp
  3. 覆盖 speak() 方法
  4. 添加一个 flamethrower(target) 方法,造成等于 fire_power 的伤害
  5. 创建一个 ElectricPokemon 和一个 FirePokemon,让它们对战
显示参考答案
class FirePokemon(Animal):
    def __init__(self, name, hp, fire_power):
        super().__init__(name, hp)
        self.fire_power = fire_power

    def speak(self):
        print(f'{self.name}:嘎嘎嘎!')

    def flamethrower(self, target):
        target.hp -= self.fire_power
        print(f'{self.name} 使用喷射火焰,造成 {self.fire_power} 点伤害!')
        if target.hp <= 0:
            print(f'{target.name} 倒下了!')

# 对战!
pikachu = ElectricPokemon('Pikachu', 100, 500)
charmander = FirePokemon('Charmander', 90, 18)

pikachu.thunder_shock(charmander)   # 500//10 = 50 伤害
charmander.flamethrower(pikachu)     # 18 伤害
print(pikachu)       # Pikachu (HP: 82)
print(charmander)    # Charmander (HP: 40)

PyTorch 连接——你现在能读懂 AI 代码了!

你可能觉得 super().__init__() 看起来很高级。但其实,这就是全世界 AI 工程师每天写的代码。来看看 PyTorch —— 最流行的深度学习框架——中是怎么用的:

# 这是一段真实的 PyTorch 代码——用来定义一个神经网络
import torch.nn as nn

class MyNetwork(nn.Module):          # 继承自 nn.Module(就像继承自 Animal)
    def __init__(self):
        super().__init__()                # 调用父类的初始化(就像 super().__init__(name, hp))
        self.layer1 = nn.Linear(10, 5)   # 添加自己的属性(就像 self.voltage = voltage)
        self.layer2 = nn.Linear(5, 1)    # 再添加一个

    def forward(self, x):               # 覆盖父类的 forward 方法(就像覆盖 speak)
        x = self.layer1(x)
        x = self.layer2(x)
        return x

每一个 PyTorch 神经网络都使用这个模式:

  1. 继承一个基类
  2. 调用 super().__init__()
  3. __init__ 中添加网络层
  4. 覆盖 forward() 定义数据如何流过网络

出口票

  1. 如果在 ElectricPokemon 中忘记写 super().__init__(),会发生什么?
  2. 在上面的 PyTorch 例子中,什么扮演了 Animal 的角色?什么扮演了 ElectricPokemon 的角色?
  3. 如果一个子类没有写自己的 __init__,需要调用 super().__init__() 吗?为什么?

关键词汇