# 深度学习基础~自制框架

## 概述

当我们学习深度学习并接触到 Pytorch 之类的框架时，我们会思考，框架语法为什么应该那样写？

- 实现了前向传播的过程，给了训练数据，为什么模型就可以自动训练了？反向传播是如何实现的？
- 当我们把一部分官方实现的方法替换成我们自己的方法时，模型还能不能正确计算？什么情况下会报错？
- 其他人将 `numpy`、`scikit-learn` 和 `pytorch` 混合使用，什么样的混合使用是允许的，我们去使用会不会出错？
- 其他人在训练时手动控制 `cuda` 的计算，自己定义一些算子，我们想要修改应该如何？

本文的内容来自于《深度学习入门 2：自制框架》，这本书搭建了一个简易的深度学习框架 `DeZero`，语法类似 `pytorch`，通过一步步深入浅出的教导，我们知晓了 `pytorch` 之类的框架是如何从代码层面实现的，理解各种语法为什么要那样写，后面修改自己的模型时也会更加底气十足。

本人学习深度学习相关知识已经很久，理论知识掌握了许多，但是涉及到代码层面，总是会陷入**去记忆别人写的代码**这种怪圈，说到底其实本人并不懂这个框架，只是会用而已，只是一种别人代码这么用了所以我也可以这么用。

这本书一共 500 多页，如果有基础的话，全部翻看一遍其实并不会太久，几天时间就可以了，看完会有一种醍醐灌顶的感觉，赞叹 `pytorch` 之类的深度学习框架居然可以如此精妙，赞叹作者居然可以如此恰到好处的描绘出来。这本书从做工来看是非常精致的，内容的衔接也是做足了功夫，可见作者其实是倾注了很多心血的，作者花费的时间是远比本人看书的这几天时间要多的多的，不过因为作者的付出，世界上成千上万个像本人这样的学习者才能快速掌握相关的知识。

## 基础搭建

首先，需要构建名为 `Variable` 的类，对标的是 `pytorch` 中的 `Tensor`

```python
class Variable:  
    def __init__(self, data):  
        self.data = data
```

然后，需要构建名为 `Function` 的类：

```python
class Function:  
    def __call__(self, input):  
        x = input.data  
        y = self.forward(x)  
        output = Variable(y)   
        return output
  
    def forward(self, x):  
        raise NotImplementedError()  
```

## 反向传播

### 求导方法

计算机程序求导的方法主要有 3 种：

- **数值微分**：就是在下图中，让 $h$ 取一个极小的值计算出来的微分值的方法。使用计算机浮点数计算，存在**精度丢失**的缺点；神经网络参数众多，存在**计算成本高**的缺点。

![](https://img.papergate.top:5000/i/2025/12/694fe5f27d911.webp)


- **符号微分**：使用导数公式求导的方法，输入是式子，输出也是式子，被用在 Mathematica 和 MATLAB 等软件中。式子会变的臃肿，神经网络参数众多，计算成本高。
- **自动微分**：采用链式法则求导的方法。自动微分可以大体分为两种：前向模式的自动微分和反向模式的自动微分。反向传播相当于反向模式的自动微分。

### 链式法则

假设有一个函数 $y = F(x)$ ，这个函数 $F$ 由 3 个函数组成：$a=A(x)$ 、 $b=B(a)$ 和 $y=C(b)$

![](https://img.papergate.top:5000/i/2025/12/694fe82131800.webp)

$y$ 对 $x$ 的导数可以表示为：

$$
\frac{\mathrm{d} y}{\mathrm{~d} x}=\frac{\mathrm{d} y}{\mathrm{~d} y} \frac{\mathrm{~d} y}{\mathrm{~d} b} \frac{\mathrm{~d} b}{\mathrm{~d} a} \frac{\mathrm{~d} a}{\mathrm{~d} x}
$$

可以按照下图的顺序依次计算：

![](https://img.papergate.top:5000/i/2025/12/694fe89e21199.webp)

于是可以构建出下面这张计算图：

![](https://img.papergate.top:5000/i/2025/12/694fe8cc64614.webp)

从 $\frac{\mathrm{d} y}{\mathrm{~d} y}(=1)$ 开始，计算它与 $\frac{\mathrm{d} y}{\mathrm{~d} b}$ 的乘积。这里的 $\frac{\mathrm{d} y}{\mathrm{~d} b}$ 是函数 $y=C(b)$ 的导数。

因此，如果用 $C^{\prime}$ 表示函数 $C$ 的导函数，我们就可以把式子写成 $\frac{\mathrm{d} y}{\mathrm{~d} b}=C^{\prime}(b)$ 。

同样，有 $\frac{\mathrm{d} b}{\mathrm{~d} a}=B^{\prime}(a), \frac{d a}{d x}=A^{\prime}(x)$ ，于是计算图可以简化成下图：

![](https://img.papergate.top:5000/i/2025/12/694fe9d214f90.webp)

> [!info]  信息
> 有以下两点需要注意：
> 
> - 反向传播中， $y$ 对各变量的导数**从右向左传播**，$y$ 是“重要人物”
> - 进行反向传播时需要用到正向传播中使用的数据

#### 框架搭建

增加反向传播的内容，需要在 `Variable` 中增加梯度：

```python
class Variable:  
    def __init__(self, data):  
        self.data = data
        self.grad = None  # 增加梯度
```

`Function` 中需要保存输入的变量，增加反向传播方法：

```python
class Function:  
    def __call__(self, input):  
        x = input.data  
        y = self.forward(x)  
        output = Variable(y)  
        self.input = input  # 保存输入的变量  
        return output
  
    def forward(self, x):  
        raise NotImplementedError()  
  
    def backward(self, gy):  # 新增反向传播方法
        raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/694ff31519a59.webp)

### 动态计算图

从函数的角度来看，变量是以输入和输出的形式存在的，函数的变量包括“输入变量”（input）和“输出变量”（output）。从变量的角度来看，变量是由函数“创造”的。也就是说，函数是变量的“父母”，是creator（创造者）。

![](https://img.papergate.top:5000/i/2025/12/694ff416d3214.webp)

计算图由函数和变量之间的“连接”构成，这个“连接”是在计算实际发生的时候形成的，称为动态计算图（Define-by-Run）。

![](https://img.papergate.top:5000/i/2025/12/694ff5a4011db.webp)

> [!info]  信息
> 因为形成了上面的连接，我们就可以使用计算机中的递归思想，自动构建反向传播的过程，无需每次手动反向传播。

#### 框架搭建

变量类进行如下修改：：

```python
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        self.creator = func
        
    def backward(self):  
        funcs = [self.creator]  
        while funcs:  
            f = funcs.pop()  # 获取函数  
            x, y = f.input, f.output  # 获取函数的输入  
            x.grad = f.backward(y.grad)  # backward调用backward方法  
            if x.creator is not None:  
                funcs.append(x.creator)  # 将前一个函数添加到列表中
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, input):  
        x = input.data  
        y = self.forward(x)  
        output = Variable(y)  
        output.set_creator(self)  # 让输出变量保存创造者信息  
        self.input = input  
        self.output = output  # 也保存输出变量  
        return output  
  
    def forward(self, x):  
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/694ff87893e43.webp)

## 框架优化

### 函数类包装

构建函数 `function` 对函数类 `Function` 进行包装：

#### 框架搭建

```python
def square(x):
    return Square()(x)

def exp(x):
    return Exp()(x)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/694ffada88f95.webp)

### 简化 `backward` 方法

省略掉 `y.grad = np.array(1.0)`

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        self.data = data  
        self.grad = None  
        self.creator = None  
  
    def set_creator(self, func):  
        self.creator = func  
  
    def backward(self):  
        if self.grad is None:  # 自动初始化 self.grid
            self.grad = np.ones_like(self.data)  
        funcs = [self.creator]  
        while funcs:  
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:  
                funcs.append(x.creator)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/694ffc1e592e8.webp)

### 只支持 `ndarray`

限制 `Variable` 的输入只能是 `ndarray` 类型

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  # 只支持 ndarray
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
  
    def set_creator(self, func):  
        self.creator = func  
  
    def backward(self):  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = [self.creator]  
        while funcs:  
            f = funcs.pop()  
            x, y = f.input, f.output  
            x.grad = f.backward(y.grad)  
            if x.creator is not None:  
                funcs.append(x.creator)
```

函数类进行如下修改：

```python
def as_array(x):  
    if np.isscalar(x):  
        return np.array(x)  
    return x

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(as_array(y))
        output.set_creator(self)
        self.input = input
        self.output = output
        return output

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()
```

> [!info]  信息
> 如果输入 `x` 是 0 维的 `ndarray`，输出 `y` 会变成非 `ndarray`，比如 `np. float64`

## 多参数反向传播

输入、输出都可以是多个。

### 修改类

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
  
    def set_creator(self, func):  
        self.creator = func  
  
    def backward(self):  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = [self.creator]  
        while funcs:  
            f = funcs.pop()  
            gys = [output.grad for output in f.outputs]  # 将输出变量 outputs 的导数汇总在列表中  
            gxs = f.backward(*gys)  # 调用了函数 f 的反向传播  
            if not isinstance(gxs, tuple):  # 当 gxs 不是元组时，将其转换为元组  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  # 将反向传播中传播的导数设置为 Variable 的实例变量 grad                
            	x.grad = gx  
                if x.creator is not None:  
                    funcs.append(x.creator)
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, *inputs):  # 添加星号  
        xs = [x.data for x in inputs]  
        ys = self.forward(*xs)  # 使用星号解包  
        if not isinstance(ys, tuple):  # 对非元组情况的额外处理  
            ys = (ys,)  
        outputs = [Variable(as_array(y)) for y in ys]  
        for output in outputs:  
            output.set_creator(self)  
        self.inputs = inputs  
        self.outputs = outputs  
        # 如果列表中只有一个元素，则返回第1 个元素  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, *xs):  
    raise NotImplementedError()  
  
	def backward(self, *gys):  
	    raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/695005d346201.webp)

### 重复使用同一个变量

![](https://img.papergate.top:5000/i/2025/12/695006a58942d.webp)

重复使用统一个变量时，梯度值应该累加，而不是覆盖。

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
  
    def set_creator(self, func):  
        self.creator = func  
  
    def backward(self):  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = [self.creator]  
        while funcs:  
            f = funcs.pop()  
            gys = [output.grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  # 梯度值累加  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    funcs.append(x.creator)
```

### 重置梯度

进行多次反向传播时，之前修改梯度值为累加，所以梯度值不会自动清除，需要添加清除函数。

#### 框架搭建

变量类进行如下修改：

```python
class Variable:
    def __init__(self, data):
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):  # 清除梯度函数
        self.creator = func

    def cleargrad(self):
        self.grad = None

    def backward(self):
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                if x.creator is not None:
                    funcs.append(x.creator)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/695009a3c1583.webp)

## 复杂计算图

对于复杂的计算图，比如下面这张计算图：

![](https://img.papergate.top:5000/i/2025/12/69500a4a9a2ec.webp)

正确的反向传播顺序如下：

![](https://img.papergate.top:5000/i/2025/12/69500a71d7a1d.webp)

但是按照我们之前创建的程序，反向传播的顺序如下：

![](https://img.papergate.top:5000/i/2025/12/69500aafd6003.webp)

按照程序默认的顺序进行反向传播无法得出正确的结果。

### 函数优先级

在反向传播时，如果按照从后代到先代的顺序处理，就可以保证“子辈”在“父辈”之前被取出，如下图：

![](https://img.papergate.top:5000/i/2025/12/69500b46060da.webp)

增加辈分变量就可以解决，如下图：

![](https://img.papergate.top:5000/i/2025/12/69500c993c3f9.webp)

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
        self.generation = 0  # 增加辈分变量  
  
    def set_creator(self, func):  
        self.creator = func  
        self.generation = func.generation + 1  # 辈分增加  
  
    def cleargrad(self):  
        self.grad = None  
  
    def backward(self):  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = []  
        seen_set = set()  
  
        # 按照辈分对 funcs 进行排序  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            gys = [output.grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    add_func(x.creator)
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, *inputs):  
        xs = [x.data for x in inputs]  
        ys = self.forward(*xs)  
        if not isinstance(ys, tuple):  
            ys = (ys,)  
        outputs = [Variable(as_array(y)) for y in ys]  
        self.generation = max([x.generation for x in inputs])  # 取最大的辈分  
        for output in outputs:  
            output.set_creator(self)  
        self.inputs = inputs  
        self.outputs = outputs  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, *xs):  
        raise NotImplementedError()  
  
    def backward(self, *gys):  
        raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69500ffea5f5c.webp)

![](https://img.papergate.top:5000/i/2025/12/69500fe5ebe6f.webp)

## 框架优化
### 循环引用

Python 会自动从内存中删除不再需要的对象，但是，如果代码写得不好，就可能出现内存泄漏或内存不足等情况。

Python 使用两种方式管理内存：一种是引用计数，另一种是分代垃圾回收。这里我们把后者称为 GC（Garbage Collection，垃圾回收）。

引用计数的机制很简单，每个对象在被创建时的引用计数为 0，当它被另一个对象引用时，引用计数加 1，当引用停止时，引用计数减 1。最终，当引用计数变为 0 时，Python 解释器会回收该对象。

![](https://img.papergate.top:5000/i/2025/12/69501197c3abf.webp)

右图中的 a、b、c 的引用计数均为 1。这时用户已无法访问这 3 个对象，如果只设置了 `a = b = c =None`，那么此时因为循环引用，引用计数不会为 0，对象也不会从内存中释放出来。这时就需要使用 GC 了。

GC 能够正确处理循环引用。因此在使用 Python 编程时，我们通常不需要关心循环引用。不过，使用 GC 推迟内存释放会导致程序整体的内存使用量增加。内存是机器学习，尤其是神经网络运算时的重要资源，因此建议避免循环引用。

![](https://img.papergate.top:5000/i/2025/12/6950135742bc8.webp)

Function 实例引用了输入和输出的 Variable 实例，同时， Variable 实例也引用了作为创建者的 Function 实例。这时，Function 实例和 variable 实例之间就存在循环引用关系。我们可以使用 Python 标准模块 weakref 来避免循环引用。

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
        self.generation = 0  
  
    def set_creator(self, func):  
        self.creator = func  
        self.generation = func.generation + 1  
  
    def cleargrad(self):  
        self.grad = None  
  
    def backward(self):  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = []  
        seen_set = set()  
  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            # gys = [output.grad for output in f.outputs]  
            gys = [output().grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    add_func(x.creator)
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, *inputs):  
        xs = [x.data for x in inputs]  
        ys = self.forward(*xs)  
        if not isinstance(ys, tuple):  
            ys = (ys,)  
        outputs = [Variable(as_array(y)) for y in ys]  
        self.generation = max([x.generation for x in inputs])  
        for output in outputs:  
            output.set_creator(self)  
        self.inputs = inputs  
        self.outputs = [weakref.ref(output) for output in outputs]  # 增加弱引用  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, *xs):  
        raise NotImplementedError()  
  
    def backward(self, *gys):  
        raise NotImplementedError()
```

### 是否保留梯度

在反向传播时，只有输入变量的梯度才需要保留。

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def __init__(self, data):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.grad = None  
        self.creator = None  
        self.generation = 0  
  
    def set_creator(self, func):  
        self.creator = func  
        self.generation = func.generation + 1  
  
    def cleargrad(self):  
        self.grad = None  
  
    def backward(self, retain_grad=False):  # 增加保留梯度参数  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = []  
        seen_set = set()  
  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            gys = [output().grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    add_func(x.creator)  
            if not retain_grad:  # 保留梯度
                for y in f.outputs:  
                    y().grad = None  # y 是弱引用
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/6950162ba0c38.webp)

### 是否反向传播

只有训练的时候需要反向传播，验证和测试的时候不需要。

#### 框架搭建

创建 `Config` 类

```python
import contextlib

class Config:  # 创建Config 类，定义是否启用反向传播
    enable_backprop = True

# 创建一个 With 上下文环境，在该上下文中，对 Config 进行修改
@contextlib.contextmanager  
def using_config(name, value):  
    old_value = getattr(Config, name)  
    setattr(Config, name, value)  
    try:  
        yield  
    finally:  
        setattr(Config, name, old_value)
        
# 创建 no_grad 函数，返回 enable_backprop 为 false 的上下文环境
def no_grad():  
    return using_config('enable_backprop', False)
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, *inputs):  
        xs = [x.data for x in inputs]  
        ys = self.forward(*xs)  
        if not isinstance(ys, tuple):  
            ys = (ys,)  
        outputs = [Variable(as_array(y)) for y in ys]  
        if Config.enable_backprop:  # 是否启用反向传播
            self.generation = max([x.generation for x in inputs])  
            for output in outputs:  
                output.set_creator(self)  
            self.inputs = inputs  
            self.outputs = [weakref.ref(output) for output in outputs]  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, *xs):  
        raise NotImplementedError()  
  
    def backward(self, *gys):  
        raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/6950202eb44a6.webp)

### 变量类属性方法

给变量类增加 `name` 、`shape`、`ndim`、`size`、`dtype` 属性， `len`、`print` 方法

#### 框架搭建

```python
class Variable:  
    def __init__(self, data, name=None):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.name = name  # 增加 name 属性  
        self.grad = None  
        self.creator = None  
        self.generation = 0  
  
    @property  
    def shape(self):  # 增加 shape 属性  
        return self.data.shape  
  
    @property  
    def ndim(self):  # 增加 ndim 属性  
        return self.data.ndim  
  
    @property  
    def size(self):  # 增加 size 属性  
        return self.data.size  
  
    @property  
    def dtype(self):  # 增加 dtype 属性  
        return self.data.dtype  
  
    def __len__(self):  # 增加 len 方法  
        return len(self.data)  
  
    def __repr__(self):  # 增加 print 方法  
        if self.data is None:  
            return 'variable(None)'  
  
        p = str(self.data).replace('\n', '\n' + ' ' * 9)  
        return 'variable(' + p + ')'  
  
    def set_creator(self, func):  
        self.creator = func  
        self.generation = func.generation + 1  
  
    def cleargrad(self):  
        self.grad = None  
  
    def backward(self, retain_grad=False):  # 增加梯度保留参数  
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = []  
        seen_set = set()  
  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            gys = [output().grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    add_func(x.creator)  
            if not retain_grad:  
                for y in f.outputs:  
                    y().grad = None
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/695023a1ca009.webp)

### 运算符重载

重载运算符，实现类似 $a * b + c$ 的效果，并且左右两侧有一边是 `int`、`float` 或者 `ndarray` 时都支持运算。

#### 框架搭建

创建运算符 `Function`：

```python
class Mul(Function):  
    def forward(self, x0, x1):  
        y = x0 * x1  
        return y  
  
    def backward(self, gy):  
        x0, x1 = self.inputs[0].data, self.inputs[1].data  
        return gy * x1, gy * x0  
  
  
class Neg(Function):  
    def forward(self, x):  
        return -x  
  
    def backward(self, gy):  
        return -gy  
  
  
class Sub(Function):  
    def forward(self, x0, x1):  
        y = x0 - x1  
        return y  
  
    def backward(self, gy):  
        return gy, -gy  
  
  
class Div(Function):  
    def forward(self, x0, x1):  
        y = x0 / x1  
        return y  
  
    def backward(self, gy):  
        x0, x1 = self.inputs[0].data, self.inputs[1].data  
        gx0 = gy / x1  
        gx1 = gy * (-x0 / x1 ** 2)  
        return gx0, gx1  
  
  
class Pow(Function):  
    def __init__(self, c):  
        self.c = c  
  
    def forward(self, x):  
        y = x ** self.c  
        return y  
  
    def backward(self, gy):  
        x = self.inputs[0].data  
        c = self.c  
        gx = c * x ** (c - 1) * gy  
        return gx  
  

def add(x0, x1):  
    x0 = as_array(x0)  
    x1 = as_array(x1)  
    return Add()(x0, x1)  
  
  
def mul(x0, x1):  
    x0 = as_array(x0)  
    x1 = as_array(x1)  
    return Mul()(x0, x1)  
  
  
def neg(x):  
    return Neg()(x)  
  
  
def sub(x0, x1):  
    x1 = as_array(x1)  
    return Sub()(x0, x1)  
  
  
def rsub(x0, x1):  
    x1 = as_array(x1)  
    return Sub()(x1, x0)  
  
  
def div(x0, x1):  
    x1 = as_array(x1)  
    return Div()(x0, x1)  
  
  
def rdiv(x0, x1):  
    x1 = as_array(x1)  
    return Div()(x1, x0)  
  
  
def pow(x, c):  
    return Pow(c)(x)
    
def setup_variable():  
    Variable.__add__ = add  
    Variable.__radd__ = add  
    Variable.__mul__ = mul  
    Variable.__rmul__ = mul  
    Variable.__neg__ = neg  
    Variable.__sub__ = sub  
    Variable.__rsub__ = rsub  
    Variable.__truediv__ = div  
    Variable.__rtruediv__ = rdiv  
    Variable.__pow__ = pow
    
setup_variable()
```

变量类进行如下修改：

```python
class Variable:  
    __array_priority__ = 200  # 优先级 200,Variable 右侧运算的优先级要高于 ndarray 左侧运算的优先级  
  
    def __init__(self, data, name=None):  
        if data is not None:  
            if not isinstance(data, np.ndarray):  
                raise TypeError('{} is not supported'.format(type(data)))  
        self.data = data  
        self.name = name  
        self.grad = None  
        self.creator = None  
        self.generation = 0  
  
    @property  
    def shape(self):  
        return self.data.shape  
  
    @property  
    def ndim(self):  
        return self.data.ndim  
  
    @property  
    def size(self):  
        return self.data.size  
  
    @property  
    def dtype(self):  
        return self.data.dtype  
  
    def __len__(self):  
        return len(self.data)  
  
    def __repr__(self):  
        if self.data is None:  
            return 'variable(None)'  
  
        p = str(self.data).replace('\n', '\n' + ' ' * 9)  
        return 'variable(' + p + ')'  
  
    def set_creator(self, func):  
        self.creator = func  
        self.generation = func.generation + 1  
  
    def cleargrad(self):  
        self.grad = None  
  
    def backward(self, retain_grad=False): 
        if self.grad is None:  
            self.grad = np.ones_like(self.data)  
        funcs = []  
        seen_set = set()  
  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            gys = [output().grad for output in f.outputs]  
            gxs = f.backward(*gys)  
            if not isinstance(gxs, tuple):  
                gxs = (gxs,)  
            for x, gx in zip(f.inputs, gxs):  
                if x.grad is None:  
                    x.grad = gx  
                else:  
                    x.grad = x.grad + gx  
                if x.creator is not None:  
                    add_func(x.creator)  
            if not retain_grad:  
                for y in f.outputs:  
                    y().grad = None
```

定义函数 `as_variable`，当运算符一侧不是 `Variable` 时转换为 `Variable`：

```python
def as_variable(obj):  
    if isinstance(obj, Variable):  
        return obj  
    return Variable(obj)
```

函数类进行如下修改：

```python
class Function:  
    def __call__(self, *inputs):  
        inputs = [as_variable(x) for x in inputs]  # 转成 Variable
        xs = [x.data for x in inputs]  
        ys = self.forward(*xs)  
        if not isinstance(ys, tuple):  
            ys = (ys,)  
        outputs = [Variable(as_array(y)) for y in ys]  
        if Config.enable_backprop:  
            self.generation = max([x.generation for x in inputs])  
            for output in outputs:  
                output.set_creator(self)  
            self.inputs = inputs  
            self.outputs = [weakref.ref(output) for output in outputs]  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, *xs):  
        raise NotImplementedError()  
  
    def backward(self, *gys):  
        raise NotImplementedError()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/695032f0dab67.webp)

### 框架打包

#### 框架搭建

将之前的代码打包到 `dezero/core_simple.py` 文件，并创建 `dezero/__init__.py` 文件：

```python
is_simple_core = True  
if is_simple_core:  
    from dezero.core_simple import Variable  
    from dezero.core_simple import Function  
    from dezero.core_simple import using_config  
    from dezero.core_simple import no_grad  
    from dezero.core_simple import as_array  
    from dezero.core_simple import as_variable  
    from dezero.core_simple import setup_variable  
  
setup_variable()
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/6950393d8d96a.webp)

> [!info]  信息
> 到这儿就有一些 `pytorch` 的雏形了。

### 计算图可视化

使用 `Graphviz` 将计算图可视化

```shell
brew install graphviz
```

`Graphviz` 可以将 DOT 语言翻译成图像，将下面的文字存储成 `xxx.dot` 文件：

```dot
digraph g {
1 [label="x", color=orange, style=filled]
2 [label="y", color=orange, style=filled]
3 [label="Exp", color=lightblue, style=filled, shape=box]
1 -> 3
3 -> 2
}
```

![](https://img.papergate.top:5000/i/2025/12/69503cad5ce6c.webp)

使用方法是：

```shell
dot xxx.dot -T png -o xxx.png
```

其中 `-T` 选项后指定要输出的文件扩展名，扩展名可以指定为pdf、svg 等。

#### 架构搭建

在 `dezero/utils.py` 中实现 `get_dot_graph` 函数：

```python
import os  
import subprocess  
  
  
def _dot_var(v, verbose=False):  
    name = '' if v.name is None else v.name  
    if verbose and v.data is not None:  
        if v.name is not None:  
            name += ': '  
        name += f'{v.shape} {v.dtype}'  
    dot_var = f'{id(v)} [label="{name}", color=orange, style=filled]\n'  
  
    return dot_var.format(id(v), name)  
  
  
def _dot_func(f):  
    txt = f'{id(f)} [label="{f.__class__.__name__}", color=lightblue, style=filled,shape=box]\n'  
    for x in f.inputs:  
        txt += f'{id(x)} -> {id(f)}\n'  
    for y in f.outputs:  
        txt += f'{id(f)} -> {id(y())}\n'  # y是weakref  
    return txt  
  
  
def get_dot_graph(output, verbose=True):  
    txt = ''  
    funcs = []  
    seen_set = set()  
  
    def add_func(f):  
        if f not in seen_set:  
            funcs.append(f)  
            seen_set.add(f)  
  
    add_func(output.creator)  
    txt += _dot_var(output, verbose)  
    while funcs:  
        func = funcs.pop()  
        txt += _dot_func(func)  
        for x in func.inputs:  
            txt += _dot_var(x, verbose)  
            if x.creator is not None:  
                add_func(x.creator)  
    return 'digraph g {\n' + txt + '}'  
  
  
def plot_dot_graph(output, verbose=True, to_file='graph.png'):  
    dot_graph = get_dot_graph(output, verbose)  
    graph_path = './sample.dot'  
    with open(graph_path, 'w') as f:  
        f.write(dot_graph)  
    extension = os.path.splitext(to_file)[1][1:]  
    cmd = f'dot {graph_path} -T {extension} -o {to_file}'  
    subprocess.run(cmd, shell=True)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/695045a7652dc.webp)

![](https://img.papergate.top:5000/i/2025/12/695045bd5fce8.webp)

## 高阶导数

### 梯度下降法

求解 `Rosenbrock` 函数的最小值所在位置，其表达式为：

$$
y=100\left(x_1-x_0^2\right)^2+\left(x_0-1\right)^2
$$

#### 框架搭建

```python
import numpy as np  
from dezero import Variable

def rosenbrock(x0, x1):
    y = 100 * (x1 - x0 ** 2) ** 2 + (x0 - 1) ** 2
    return y

x0 = Variable(np.array(0.0))
x1 = Variable(np.array(2.0))
lr = 0.001 # 学习率
iters = 1000 # 迭代次数
for i in range(iters):
    if not (i+1) % 100:
        print(x0, x1)
    y = rosenbrock(x0, x1)
    x0.cleargrad()
    x1.cleargrad()
    y.backward()
    x0.data -= lr * x0.grad
    x1.data -= lr * x1.grad
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69504969592f6.webp)

### 高阶导数

对于 $y=sin(x)$ 有如下计算图  `x -> y`：

![](https://img.papergate.top:5000/i/2025/12/69504c7028038.webp)

在反向传播的过程中，会调用 `y.backward()`，对于 $y=sin(x)$，就是计算 `gx = gy * np.cos(x)` 的过程。所以说， `gx` 是 `x` 的函数，上述过程也可以构建出一张计算图，如下：

![](https://img.papergate.top:5000/i/2025/12/69504fe1d1fb7.webp)

在这张  `x -> gx` 的计算图中，调用 `gx.backward()`，就可以计算 `x` 的二阶导数。

不过，通常计算图只在正向传播的过程中建立，在反向传播时不会被创建。如果我们在反向传播的过程中也建立计算图，就可以求高阶导数。

#### 框架搭建

变量类进行如下修改：

```python
class Variable:  
    def backward(self, retain_grad=False, create_graph=False):  
        if self.grad is None:  
            # self.grad = np.ones_like(self.data)  
            self.grad = Variable(np.ones_like(self.data))  
        funcs = []  
        seen_set = set()  
  
        def add_func(f):  
            if f not in seen_set:  
                funcs.append(f)  
                seen_set.add(f)  
                funcs.sort(key=lambda x: x.generation)  
  
        add_func(self.creator)  
  
        while funcs:  
            f = funcs.pop()  
            gys = [output().grad for output in f.outputs]  
            with using_config('enable_backprop', create_graph):  
                gxs = f.backward(*gys)  # 主要的backward处理  
                if not isinstance(gxs, tuple):  
                    gxs = (gxs,)  
                for x, gx in zip(f.inputs, gxs):  
                    if x.grad is None:  
                        x.grad = gx  
                    else:  
                        x.grad = x.grad + gx  # 这个计算也是对象  
                    if x.creator is not None:  
                        add_func(x.creator)  
            if not retain_grad:  
                for y in f.outputs:  
                    y().grad = None
```

> [!info]  信息
> 将梯度修改成 `Variable` 变量类，设置 `enable_backprop` 是否反向传播，可以控制梯度是否建立计算图。

运算符进行如下修改：

```python
class Mul(Function):  
    def forward(self, x0, x1):  
        y = x0 * x1  
        return y  
  
    def backward(self, gy):  
        x0, x1 = self.inputs  
        return gy * x1, gy * x0

    
class Div(Function):  
    def forward(self, x0, x1):  
        y = x0 / x1  
        return y  
  
    def backward(self, gy):  
        x0, x1 = self.inputs  
        gx0 = gy / x1  
        gx1 = gy * (-x0 / x1 ** 2)  
        return gx0, gx1


class Pow(Function):  
    def __init__(self, c):  
        self.c = c  
  
    def forward(self, x):  
        y = x ** self.c  
        return y  
  
    def backward(self, gy):  
        x, = self.inputs 
        c = self.c  
        gx = c * x ** (c - 1) * gy  
        return gx
```

> [!info]  信息
> `Variable` 变量类的运算符已经重载过，修改反向传播函数即可适配 `Variable` 变量类。

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69505f13aa40d.webp)

## 框架优化

### 高级函数

在 `dezero/functions.py` 中创建高级函数：

```python
import numpy as np  
from dezero.core import Function  
  
  
class Sin(Function):  
    def forward(self, x):  
        y = np.sin(x)  
        return y  
  
    def backward(self, gy):  
        x, = self.inputs  
        gx = gy * cos(x)  
        return gx  
  
  
def sin(x):  
    return Sin()(x)  
  
  
class Cos(Function):  
    def forward(self, x):  
        y = np.cos(x)  
        return y  
  
    def backward(self, gy):  
        x, = self.inputs  
        gx = gy * -sin(x)  
        return gx  
  
  
def cos(x):  
    return Cos()(x)  
  
  
class Tanh(Function):  
    def forward(self, x):  
        y = np.tanh(x)  
        return y  
  
    def backward(self, gy):  
        y = self.outputs[0]()  
        gx = gy * (1 - y * y)  
        return gx  
  
  
def tanh(x):  
    return Tanh()(x)
```

### `reshape`

#### 框架搭建

在 `dezero/function.py` 中进行如下修改：

```python
class Reshape(Function):  
    def __init__(self, shape):  
        self.shape = shape  
  
    def forward(self, x):  
        self.x_shape = x.shape  
        y = x.reshape(self.shape)  
        return y  
  
    def backward(self, gy):  
        return reshape(gy, self.x_shape)  
  
  
def reshape(x, shape):  
    if x.shape == shape:  
        return as_variable(x)  
    return Reshape(shape)(x)
```

在 `dezero/core.py` 中进行如下修改：

```python
def reshape(self, *shape):  
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):  
        shape = shape[0]  
    return dezero.functions.reshape(self, shape)
```

### `transpose`

#### 框架搭建

在 `dezero/function.py` 中进行如下修改：

```python
class Transpose(Function):  
    def __init__(self, axes=None):  
        self.axes = axes  
  
    def forward(self, x):  
        y = x.transpose(self.axes)  
        return y  
  
    def backward(self, gy):  
        if self.axes is None:  
            return transpose(gy)  
  
        axes_len = len(self.axes)  
        inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes]))  
        return transpose(gy, inv_axes)  
  
  
def transpose(x, axes=None):  
    return Transpose(axes)(x)
```

> [!info]  信息
> 对 `axes` 进行 `np.argsort` 操作，正好可以得到 `inv_axes`，`transpose` 两次正好可以还原回去，`% axes_len` 是为了保证 `ax` 为正，因为 `ax` 可以传入 -1 之类的复数。

在 `dezero/core.py` 中进行如下修改：

```python
import dezero

class Variable:
	def transpose(self, *axes):  
    if len(axes) == 0:  
        axes = None  
    elif len(axes) == 1:  
        if isinstance(axes[0], (tuple, list)) or axes[0] is None:  
            axes = axes[0]  
    return dezero.functions.transpose(self, axes)  
  
	@property  
	def T(self):  
	    return dezero.functions.transpose(self)  
```

### `sum & boardcast`

在求和 `sum` 函数中， ` axis ` 用于指定求和时的轴，如下：

![](https://img.papergate.top:5000/i/2025/12/6950fa3391e9a.webp)

`keepdims` 用于指定输入和输出是否应具有相同维度（轴的数量）。

![](https://img.papergate.top:5000/i/2025/12/6950fb0567e27.webp)

广播 `boardcast` 函数，复制输入的元素，并将其形状变为 `shape` 的形状：

![](https://img.papergate.top:5000/i/2025/12/69510405260cb.webp)

> [!info]  信息
> 广播从**最后一维**开始对齐，两个维度必须相等或者广播前的维度为 1。

广播 `boardcast` 的反向传播，需要实现名为 `sum_to` 的函数，其求和输入的元素，将其形状变为 `shape` 的形状：

![](https://img.papergate.top:5000/i/2025/12/6950fcea7096c.webp)

求和 `sum_to` 的反向传播，就是 `boardcast_to`：

![](https://img.papergate.top:5000/i/2025/12/6950fd5fe3513.webp)

同样，求和 `sum` 的反向传播，也是 `boardcast_to`。

#### 框架搭建

在 `dezero/function.py` 中进行如下修改：

```python
from dezero import utils

class Sum(Function):  
    def __init__(self, axis, keepdims):  
        self.axis = axis  
        self.keepdims = keepdims  
  
    def forward(self, x):  
        self.x_shape = x.shape  
        y = x.sum(axis=self.axis, keepdims=self.keepdims)  
        return y  
  
    def backward(self, gy):  
        gy = utils.reshape_sum_backward(gy, self.x_shape, self.axis, self.keepdims)  
        gx = broadcast_to(gy, self.x_shape)  
        return gx  


def sum(x, axis=None, keepdims=False):  
    return Sum(axis, keepdims)(x)

class SumTo(Function):  
    def __init__(self, shape):  
        self.shape = shape  
  
    def forward(self, x):  
        self.x_shape = x.shape  
        y = utils.sum_to(x, self.shape)  
        return y  
  
    def backward(self, gy):  
        gx = broadcast_to(gy, self.x_shape)  
        return gx  
  
  
def sum_to(x, shape):  
    if x.shape == shape:  
        return as_variable(x)  
  
    return SumTo(shape)(x)

class BroadcastTo(Function):  
    def __init__(self, shape):  
        self.shape = shape  
  
    def forward(self, x):  
        self.x_shape = x.shape  
        y = np.broadcast_to(x, self.shape)  
        return y  
  
    def backward(self, gy):  
        gx = sum_to(gy, self.x_shape)  
        return gx  
  
  
def broadcast_to(x, shape):  
    if x.shape == shape:  
        return as_variable(x)  
    return BroadcastTo(shape)(x)
```

需要在 `dezero/utils.py` 中实现 `reshape_sum_backward` 方法和 `sum_to` 方法：

```python
def sum_to(x, shape):  
    ndim = len(shape)  
    lead = x.ndim - ndim  # 前导维度：broadcast过程中自动补的维度
    lead_axis = tuple(range(lead))  # 前导维度所在轴
  	
  	# 维度为 1 的轴
    axis = tuple([i + lead for i, sx in enumerate(shape) if sx == 1])  
    # 对指定的轴（lead_axis 和 axis）进行求和
    y = x.sum(lead_axis + axis, keepdims=True)  
    if lead > 0: 
        y = y.squeeze(lead_axis)  # 去掉前导维度  
    return y  
  
  
def reshape_sum_backward(gy, x_shape, axis, keepdims):  
    ndim = len(x_shape)
    tupled_axis = axis
    if axis is None:  
        tupled_axis = None  
    elif not isinstance(axis, tuple):  
        tupled_axis = (axis,)  
  
    if not (ndim == 0 or tupled_axis is None or keepdims):
    	# sum 过程中消失的轴
        actual_axis = [a % ndim for a in tupled_axis]  
        shape = list(gy.shape)  
        for a in sorted(actual_axis):  
            shape.insert(a, 1)  # 消失的轴填充 1,用来广播
    else:  
        shape = gy.shape
```

在变量类 `Variable` 中添加 `sum` 方法：

```python
class Variable:
	def sum(self, axis=None, keepdims=False):  
    	return dezero.functions.sum(self, axis, keepdims)
```

## 张量

### 链式法则

对于函数 $\boldsymbol{y}=F(\boldsymbol{x})$ ，其中 $\boldsymbol{x}$ 和 $\boldsymbol{y}$ 是向量，假设这两个向量的元素数都是 $n$ 。 $\boldsymbol{y}$ 对 $\boldsymbol{x}$ 的导数可通过以下式子定义：

$$
\begin{aligned}
\frac{\partial y}{\partial x}
&=
\left(
\begin{array}{cccc}
\frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} & \cdots & \frac{\partial y_1}{\partial x_n} \\
\frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2} & \cdots & \frac{\partial y_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial y_n}{\partial x_1} & \frac{\partial y_n}{\partial x_2} & \cdots & \frac{\partial y_n}{\partial x_n}
\end{array}
\right)
\end{aligned}
$$

这个矩阵称为雅可比矩阵。

如果 $\boldsymbol{y}$ 不是向量而是标量，那么 $\boldsymbol{y}$ 对 $\boldsymbol{x}$ 的导数就是下面这样。

$$
\frac{\partial y}{\partial x}=\left(\begin{array}{llll}
\frac{\partial y}{\partial x_1} & \frac{\partial y}{\partial x_2} & \cdots & \frac{\partial y}{\partial x_n}
\end{array}\right)
$$


这是一个 $1 \times n$ 的雅可比矩阵，可以将它看作一个行向量。

接下来思考复合函数。假设有复合函数 $\boldsymbol{y}=F(\boldsymbol{x})$ ，它由 3 个函数复合而成，分别是 $\boldsymbol{a}=A(\boldsymbol{x}), \boldsymbol{b}=B(\boldsymbol{a}), y=C(\boldsymbol{b})$ 。假设变量 $\boldsymbol{x}$ 、$\boldsymbol{a}$ 、 $\boldsymbol{b}$ 都是向量，它们的元素数为 $n$ ，只有最终的输出 $y$ 是标量。那么，基于链式法则，$y$ 对 $x$ 的导数可以表示如下。

$$
\frac{\partial y}{\partial x}=\frac{\partial y}{\partial b} \frac{\partial b}{\partial a} \frac{\partial a}{\partial x}
$$

### 矩阵乘积

假设有向量 $\boldsymbol{a}=\left(a_1, \cdots, a_n\right)$ 和向量 $\boldsymbol{b}=\left(b_1, \cdots, b_n\right)$ 。向量的内积可以定义为：

$$
\boldsymbol{a} \boldsymbol{b}=a_1 b_1+a_2 b_2+\cdots+a_n b_n
$$

![](https://img.papergate.top:5000/i/2025/12/69510a52f1a64.webp)

矩阵乘积的计算方法是先分别求出左侧矩阵水平方向的向量和右侧矩阵垂直方向的向量的内积，然后将结果存储在新矩阵的相应元素中。

下面以 $\boldsymbol{y}=\boldsymbol{x} \boldsymbol{W}$ 为例介绍矩阵乘积的反向传播。在该计算中， $\boldsymbol{x}$ 、 $\boldsymbol{W}$ 和 $\boldsymbol{y}$ 的形状分别为 $1 \times D$ 、 $D \times H$ 和 $1 \times H$ ：

![](https://img.papergate.top:5000/i/2025/12/69510ad915031.webp)

假定计算最终输出的标量是 $L$（通过反向传播求 $L$ 对每个变量的导数），此时，$L$ 对 $\boldsymbol{x}$ 的第 $i$ 个元素的导数 $\frac{\partial L}{\partial x_i}$ 的式子如下所示：

$$
\frac{\partial L}{\partial x_i}=\sum_j \frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial x_i}
$$

 $\frac{\partial L}{\partial x_i}$ 表示当 $x_i$ 发生微小的变化时 $L$ 的变化程度。当 $x_i$ 发生变化时，向量 $\boldsymbol{y}$ 的所有元素也会发生改变， $\boldsymbol{y}$ 的每个元素的改变也会使 $L$ 最终发生变化。因此，从 $x_i$ 到 $L$ 有多条链式法则的路径，其总和为 $\frac{\partial L}{\partial x_i}$ 。

展开 $\boldsymbol{y}$ 的第 $j$ 个元素，有 

$$
y_j=x_1 W_{1 j}+x_2 W_{2 j}+\cdots+x_i W_{i j}+\cdots+x_H W_{H j}
$$

 由此可知，$\frac{\partial y_j}{\partial x_i}=W_{i j}$ ，于是：

$$
\frac{\partial L}{\partial x_i}=\sum_j \frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial x_i}=\sum_j \frac{\partial L}{\partial y_j} W_{i j}
$$


也就是说， $\frac{\partial L}{\partial x_i}$ 可通过向量 $\frac{\partial L}{\partial \boldsymbol{y}}$ 和 $\boldsymbol{W}$ 的第 $i$ 行向量的内积求出，由此我们可以推导出以下式子：

$$
\frac{\partial L}{\partial \boldsymbol{x}}=\frac{\partial L}{\partial \boldsymbol{y}} \boldsymbol{W}^{\mathrm{T}}
$$

![](https://img.papergate.top:5000/i/2025/12/69510db8529d7.webp)

再次思考 $\boldsymbol{y}=\boldsymbol{x} \boldsymbol{W}$ 这个矩阵乘积的计算，这次 $\boldsymbol{x}$ 、 $\boldsymbol{W}$ 和 $\boldsymbol{y}$ 的形状分别为 $N \times D$ 、 $D \times H$ 和 $N \times H$ ，此时反向传播的计算图如下：

![](https://img.papergate.top:5000/i/2025/12/69510e3960de0.webp)

通过矩阵的形状，我们可以推导出如下的式子：

![](https://img.papergate.top:5000/i/2025/12/69510ec6587d7.webp)

同样，上面的式子也可以通过计算每个矩阵的元素并比较两边的结果推导出来，这里不再进行推导。

#### 框架搭建

在 `dezero/functions.py` 中实现矩阵乘积如下：

```python
class MatMul(Function):  
    def forward(self, x, W):  
        y = x.dot(W)  
        return y  
  
    def backward(self, gy):  
        x, W = self.inputs  
        gx = matmul(gy, W.T)  
        gW = matmul(x.T, gy)  
        return gx, gW  
      
  
def matmul(x, W):  
    return MatMul()(x, W)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/6951131d9c762.webp)

### 线性回归

构建简单的数据集，进行线性回归。

#### 框架搭建

在 `dezero/funtion.py` 中实现均方误差如下：

```python
class MeanSquaredError(Function):  
    def forward(self, x0, x1):  
        diff = x0 - x1  
        y = (diff ** 2).sum() / len(diff)  
        return y  
  
    def backward(self, gy):  
        x0, x1 = self.inputs  
        diff = x0 - x1  
        gx0 = gy * diff * (2. / len(diff))  
        gx1 = -gx0  
        return gx0, gx1  
  
  
def mean_squared_error(x0, x1):  
    return MeanSquaredError()(x0, x1)
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69512a8905688.webp)

## 模型
### 线性变换 `Linear`

输入 $\boldsymbol{x}$ 和参数 $\boldsymbol{W}$ 之间的矩阵乘积，然后加上 $\boldsymbol{b}$ 的结果，称为叫作线性变换（linear transformation）或仿射变换（affine transformation ）。

![](https://img.papergate.top:5000/i/2025/12/69512bf2b9d89.webp)

左图的实现方式使用了 `DeZero` 的 `matmul` 函数和 `add` 函数，`matmul` 函数的输出作为 `Variable` 实例记录在计算图中。

右图的实现方式是继承 `Function` 类后实现 `Linear` 类。在使用这种方式下，中间结果没有作为 `Variable` 实例存储在内存中，所以正向传播中使用的数据在正向传播完成后会立即被删除。因此从内存效率的角度考虑，需要使用第二种实现方式。

#### 框架搭建

在 `dezero/function.py` 中实现线性变换：

```python
class Linear(Function):  
    def forward(self, x, W, b):  
        y = x.dot(W)  
        if b is not None:  
            y += b  
        return y  
  
    def backward(self, gy):  
        x, W, b = self.inputs  
        gb = None if b.data is None else sum_to(gy, b.shape)  
        gx = matmul(gy, W.T)  
        gW = matmul(x.T, gy)  
        return gx, gW, gb  
  
  
def linear(x, W, b=None):  
    return Linear()(x, W, b)
```

### 激活函数

线性变换指对输入数据进行线性的变换，而神经网络则对线性变换的输出进行非线性的变换。这种非线性变换叫作激活函数，典型的激活函数有 `ReLU` 和 `sigmoid` 函数。

#### 框架搭建

在 `dezero/function.py` 中实现激活函数：

```python
class Sigmoid(Function):  
    def forward(self, x):  
        y = np.tanh(x * 0.5) * 0.5 + 0.5  
        return y  
  
    def backward(self, gy):  
        y = self.outputs[0]()  
        gx = gy * y * (1 - y)  
        return gx  
  
  
def sigmoid(x):  
    return Sigmoid()(x)  
  
  
class ReLU(Function):  
    def forward(self, x):  
        y = np.maximum(x, 0.0)  
        return y  
  
    def backward(self, gy):  
        x, = self.inputs  
        mask = x.data > 0  
        gx = gy * mask  
        return gx  
  
  
def relu(x):  
    return ReLU()(x)
```

### 层 Layer

在实现结构更加复杂的网络时，参数处理将更加复杂。通过构建 `Layer`，可以实现参数的自动处理。

#### 框架搭建

在 `dezero/core.py` 中新建 `Parameter` 类如下：

```python
class Parameter(Variable):  
    pass
```

在 `dezero/layers.py` 中新建 `Layer` 类如下：

```python
import weakref  
from dezero.core import Parameter  
  
  
class Layer:  
    def __init__(self):  
        self._params = set()  
  
    def __setattr__(self, name, value):  
        if isinstance(value, Parameter):  
            self._params.add(name)  
        super().__setattr__(name, value)  
  
    def __call__(self, *inputs):  
        outputs = self.forward(*inputs)  
        if not isinstance(outputs, tuple):  
            outputs = (outputs,)  
        self.inputs = [weakref.ref(x) for x in inputs]  
        self.outputs = [weakref.ref(y) for y in outputs]  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, inputs):  
        raise NotImplementedError()  
  
    def params(self):  
        for name in self._params:  # 所有的实例变量都以字典的形式存储在实例变量__dict__中  
            yield self.__dict__[name]  
  
    def cleargrads(self):  
        for param in self.params():  
            param.cleargrad()
```

在 `dezero/layers.py` 中新建 `Linear` 类如下：

```python
class Linear(Layer):  
    def __init__(self, out_size, nobias=False, dtype=np.float32, in_size=None):  
        super().__init__()  
        self.in_size = in_size  
        self.out_size = out_size  
        self.dtype = dtype  
        self.W = Parameter(None, name='W')  
        if self.in_size is not None:  # 如果没有指定 in_size，则延后处理  
            self._init_W()  
  
        if nobias:  
            self.b = None  
        else:  
            # 偏置初始化为 0            
            self.b = Parameter(np.zeros(out_size, dtype=dtype), name='b')  
  
    def _init_W(self):  
        I, O = self.in_size, self.out_size  
        # 权重 Xavier 初始化  
        W_data = np.random.randn(I, O).astype(self.dtype) * np.sqrt(1 / I)  
        self.W.data = W_data  
  
    def forward(self, x):  
        if self.W.data is None:  
            self.in_size = x.shape[1]  
            self._init_W()  
        y = F.linear(x, self.W, self.b)  
        return y
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69515e13a07ca.webp)

### 模型 `Model`

将多层的 `Layer` 合并就可以得到一个 `Model`。

#### 框架搭建

修改 `dezero/layer.py` 中的层 `Layer` 类：

```python
class Layer:  
    def __init__(self):  
        self._params = set()  
  
    def __setattr__(self, name, value):  
        if isinstance(value, (Parameter, Layer)):  # 再增加 Layer
        	self._params.add(name)  
        super().__setattr__(name, value)  
  
    def __call__(self, *inputs):  
        outputs = self.forward(*inputs)  
        if not isinstance(outputs, tuple):  
            outputs = (outputs,)  
        self.inputs = [weakref.ref(x) for x in inputs]  
        self.outputs = [weakref.ref(y) for y in outputs]  
        return outputs if len(outputs) > 1 else outputs[0]  
  
    def forward(self, inputs):  
        raise NotImplementedError()  
  
    def params(self):  
        for name in self._params:  
            obj = self.__dict__[name]  
            if isinstance(obj, Layer):  # 从 Layer 取出参数  
                yield from obj.params()  
            else:  
                yield obj  
  
    def cleargrads(self):  
        for param in self.params():  
            param.cleargrad()
```

在 `dezero/models.py` 中新建 `Model` 类：

```python
from dezero import Layer  
from dezero import utils

class Model(Layer):
    def plot(self, *inputs, to_file='model.png'):
        y = self.forward(*inputs)
        return utils.plot_dot_graph(y, verbose=True, to_file=to_file)
```

在 `dezero/__init__.py` 中新增：

```python
from dezero.layers import Layer  
from dezero.models import Model
```

#### 示例

搭建的两层模型如图：

![](https://img.papergate.top:5000/i/2025/12/695166357b995.webp)

### 全连接网络 MLP

实现一个更通用的全连接层的网络。

#### 框架搭建

```python
class MLP(Model):  
    def __init__(self, fc_output_sizes, activation=F.sigmoid):  
        super().__init__()  
        self.activation = activation  
        self.layers = []  
  
        for i, out_size in enumerate(fc_output_sizes):  
            layer = L.Linear(out_size)  
            setattr(self, 'l' + str(i), layer)  
            self.layers.append(layer)  
  
    def forward(self, x):  
        for l in self.layers[:-1]:  
            x = self.activation(l(x))  
        return self.layers[-1](x)
```

## 优化器

上面框架中在反向传播时，还需要手动使用梯度下降法更新参数，这里使用优化器自动对参数进行更新。

#### 框架搭建

`dezero/optimizers.py` 中创建优化器类：

```python
class Optimizer:  
    def __init__(self):  
        self.target = None  
        self.hooks = []  
  
    def setup(self, target):  
        self.target = target  
        return self  
  
    def update(self):  
        params = [p for p in self.target.params() if p.grad is not None]  
        # 预处理（可选）  
        for f in self.hooks:  
            f(params)  
  
        # 更新参数  
        for param in params:  
            self.update_one(param)  
  
    def update_one(self, param):  
        raise NotImplementedError()  
  
    def add_hook(self, f):  
        self.hooks.append(f)
```

创建梯度下降法：

```python
class SGD(Optimizer):  
    def __init__(self, lr=0.01):  
        super().__init__()  
        self.lr = lr  
  
    def update_one(self, param):  
        param.data -= self.lr * param.grad.data
```

创建动量法：

$$
\boldsymbol{v} \leftarrow \alpha \boldsymbol{v}-\eta \frac{\partial L}{\partial \boldsymbol{W}}
$$

$$
\boldsymbol{W} \leftarrow \boldsymbol{W}+\boldsymbol{v}
$$

```python
class MomentumSGD(Optimizer):  
    def __init__(self, lr=0.01, momentum=0.9):  
        super().__init__()  
        self.lr = lr  
        self.momentum = momentum  
        self.vs = {}  
  
    def update_one(self, param):  
        v_key = id(param)  
        if v_key not in self.vs:  
            self.vs[v_key] = np.zeros_like(param.data)  
        v = self.vs[v_key]  
        v *= self.momentum  
        v -= self.lr * param.grad.data  
        param.data += v
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/69516d69d416d.webp)

> [!info]  信息
> 可以看到，到这里已经比较接近 `pytorch` 的代码体验了。

## 多分类
### 切片

增加切片函数，将多维数组中的一些数据原封不动地传递出去，反向传播只在被提取的部分设置梯度，其余部分梯度为 0。

#### 框架搭建

在 `dezero/functions.py` 增加 `GetItem` 类：

```python
class GetItem(Function):  
    def __init__(self, slices):  
        self.slices = slices  
  
    def forward(self, x):  
        y = x[self.slices]  
        return y  
  
    def backward(self, gy):  
        x, = self.inputs  
        f = GetItemGrad(self.slices, x.shape)  
        return f(gy)  
  
  
class GetItemGrad(Function):  
    def __init__(self, slices, in_shape):  
        self.slices = slices  
        self.in_shape = in_shape  
  
    def forward(self, gy):  
        gx = np.zeros(self.in_shape, dtype=gy.dtype)  
        np.add.at(gx, self.slices, gy)  # 只在被切出来的位置累加梯度，其余位置为 0
        return gx  
  
    def backward(self, ggx):  
        return get_item(ggx, self.slices)  
  
  
def get_item(x, slices):  
    f = GetItem(slices)  
    return f(x)
```

### `Softmax`

`Softmax` 将神经网络输出的数值转换为概率。

$$
p_k=\frac{\exp \left(y_k\right)}{\sum_{i=1}^n \exp \left(y_i\right)}
$$

![](https://img.papergate.top:5000/i/2025/12/695206408f76a.webp)

`Softmax` 有平移不变性：

$$
\mathrm{softmax}(x)=\mathrm{softmax}(x−c)
$$

取 $c$ 为 $\max (x)$ 可以防止 $\exp$ 溢出。

#### 框架搭建

在 `dezero/functions.py` 增加 `Softmax` 类：

```python
class Softmax(Function):  
    def __init__(self, axis=1):  
        self.axis = axis  
  
    def forward(self, x):  
        y = x - x.max(axis=self.axis, keepdims=True)  
        y = np.exp(y)  
        y /= y.sum(axis=self.axis, keepdims=True)  
        return y  
  
    def backward(self, gy):  
        y = self.outputs[0]()  
        gx = y * gy  
        sumdx = gx.sum(axis=self.axis, keepdims=True)  
        gx -= y * sumdx  
        return gx  
  
  
def softmax(x, axis=1):  
    return Softmax(axis)(x)
```

### 交叉熵

在线性回归中，我们使用均方误差作为损失函数，但在进行多分类时，需要使用专用的损失函数。最常用的是交叉熵误差（cross entropy error），对于单个样本，有：

$$
L=-\sum_i^C y_{i} \log p_i
$$
其中 $C$ 为类别数量，样本的真实类别为 $t$， $y_i= \begin{cases}1 & i=t \\ 0 & i \neq t\end{cases}$

在多分类问题中，交叉熵误差的 $p_k$ 使用 ` Softmax ` 函数的输出，可以将 `Softmax` 函数和交叉熵误差合二为一来实现，合并后的函数复杂度会更低，计算更加稳定。

$$
L=-\sum_{i=1}^C y_i \log \left(\frac{e^{x_i}}{\sum_k e^{x_k}}\right)
$$

 所以: 

$$
L=-\sum_i y_i x_i+\sum_i y_i \log \sum_k e^{x_k}
$$

因为 $y$ 是 `one-hot`，$\sum_i y_i=1$， $\sum_i y_ix_i=x_t$，所以：

$$
L=-x_t+\log \sum_{k=1}^C e^{x_k}
$$
扩展到多样本：

$$
L=\frac{1}{N} \sum_{i=1}^N\left(-x_{i, t_i}+\log \sum_k e^{x_{i k}}\right)
$$

对第 $i$ 个样本求梯度：

$$
L_i=-x_{i, t_i}+\log \sum_{k=1}^C e^{x_{i k}}
$$

对于第一项：

$$
\frac{\partial\left(-x_{i, t_i}\right)}{\partial x_{i j}}=\left\{\begin{array}{ll}
-1 & j=t_i \\
0 & j \neq t_i
\end{array}=-\left(t_{-} \text {onehot }\right)_{i j}\right.
$$

对于第二项：

$$
\frac{\partial}{\partial x_{i j}} \log \sum_k e^{x_{i k}}=\frac{e^{x_{i j}}}{\sum_k e^{x_{i k}}}=\left(\operatorname{softmax}\left(x_i\right)\right)_j
$$

于是：

$$
\frac{\partial L_i}{\partial x_{i j}}=\left(\operatorname{softmax}\left(x_i\right)\right)_j-\left(t_{-} o n e h o t\right)_{i j}
$$

所以对整体求梯度：

$$
\frac{\partial L}{\partial x_{i j}}=\frac{1}{N}\left[\left(\operatorname{softmax}\left(x_i\right)\right)_j-\left(t_{-} o n e h o t\right)_{i j}\right]
$$

#### 框架搭建

在 `dezero/utils.py` 增加 `logsumexp` 类：

```python
def logsumexp(x, axis=1):  
    m = x.max(axis=axis, keepdims=True)  
    y = x - m  
    np.exp(y, out=y)  
    s = y.sum(axis=axis, keepdims=True)  
    np.log(s, out=s)  
    m += s  
    return m
```

> [!info]  信息
> 通常维度 0 表示不同的样本，维度 1 表示不同的类别，因而默认的轴为 1。

在 `dezero/functions.py` 增加 `SoftmaxCrossEntropy` 类：

```python
class SoftmaxCrossEntropy(Function):  
    def forward(self, x, t):  
        N = x.shape[0]  
        log_z = utils.logsumexp(x, axis=1)  
        log_p = x - log_z
        # np.arange(N) 从 0 到 N-1
        # t.ravel() 把真实标签索引展平成一维向量
        # 从log_p中取出每一个log_p[x,y]对应的元素
        log_p = log_p[np.arange(N), t.ravel()]  
        y = -log_p.sum() / np.float32(N)  
        return y
  
    def backward(self, gy):  
        x, t = self.inputs  
        N, CLS_NUM = x.shape  
  
        gy *= 1 / N  
        y = softmax(x)
        # 通过 t.data 获取真实标签的索引，从单位矩阵中取出对应的行
        t_onehot = np.eye(CLS_NUM, dtype=t.dtype)[t.data]  
        y = (y - t_onehot) * gy  
        return y  
  
  
def softmax_cross_entropy(x, t):  
    return SoftmaxCrossEntropy()(x, t)
```

## 数据集

###  `Dataset` 类

Dataset 类是作为基类实现的。我们让用户实际使用的数据集类继承 Dataset 类。

#### 框架搭建

在 `dezero/dataset.py` 中创建 `Dataset` 类：

```python
class Dataset:  
    def __init__(self, train=True, transform=None, target_transform=None):  
        self.train = train  
        self.transform = transform  
        self.target_transform = target_transform  
        if self.transform is None:  
            self.transform = lambda x: x  
        if self.target_transform is None:  
            self.target_transform = lambda x: x  
  
        self.data = None  
        self.label = None  
        self.prepare()  
  
    def __getitem__(self, index):  
        assert np.isscalar(index)  
        if self.label is None:  
            return self.transform(self.data[index]), None  
        else:  
            return self.transform(self.data[index]), self.target_transform(self.label[index])  
  
    def __len__(self):  
        return len(self.data)  
  
    def prepare(self):  
        pass
```

###  `transform` 函数

#### `Normalize`

`Normalize` 对数据进行正则化处理。

```python
class Normalize:  
    def __init__(self, mean=0, std=1):  
        self.mean = mean  
        self.std = std  
  
    def __call__(self, array):  
        mean, std = self.mean, self.std  
  		# 正则化图像的处理步骤
        if not np.isscalar(mean):  
        	# 构造全 1 的 mshape，长度等于数组维度
            mshape = [1] * array.ndim
            # 这里考虑的数组是 [C H W]，维度0是通道
            mshape[0] = len(array) if len(self.mean) == 1 else len(self.mean)  
            mean = np.array(self.mean, dtype=array.dtype).reshape(*mshape)  
        if not np.isscalar(std): 
        	# 构造全 1 的 mshape，长度等于数组维度 
            rshape = [1] * array.ndim
            # 这里考虑的数组是 [C H W]，维度0是通道
            rshape[0] = len(array) if len(self.std) == 1 else len(self.std) 
            std = np.array(self.std, dtype=array.dtype).reshape(*rshape)  
        return (array - mean) / std
```

> [!info]  信息
> 由于框架的 `transform` 函数是在 `__getitem__` 的时候执行的，所以样本数都是 1，因而输入数据不是 `[N C H W]` 这种结构。

#### `Flatten`

`Flatten` 将数据展平成一维。

```python
class Flatten:  
    def __call__(self, array):  
        return array.flatten()
```

#### `ToFloat`

`ToFloat` 将数据的类型转成 `np.float32`

```python
class AsType:  
    def __init__(self, dtype=np.float32):  
        self.dtype = dtype  
  
    def __call__(self, array):  
        return array.astype(self.dtype)  
  
  
ToFloat = AsType
```

#### `Compose`

`Compose` 类按顺序从头开始连续进行多个转换。

```python
class Compose:  
  
    def __init__(self, transforms=[]):  
        self.transforms = transforms  
  
    def __call__(self, img):  
        if not self.transforms:  
            return img  
        for t in self.transforms:  
            img = t(img)  
        return img
```

### `DataLoader` 类

 `DataLoader` 从 `Dataset` 中创建小批量数据，实现数据集重排等工作。

#### 框架搭建

```python
import math  
import numpy as np  
  
  
class DataLoader:  
    def __init__(self, dataset, batch_size, shuffle=True):  
        self.dataset = dataset  
        self.batch_size = batch_size  
        self.shuffle = shuffle  
        self.data_size = len(dataset)  
        self.max_iter = math.ceil(self.data_size / batch_size)  
  
        self.reset()  
  
    def reset(self):  
        self.iteration = 0  
        if self.shuffle:  
            self.index = np.random.permutation(len(self.dataset))  
        else:  
            self.index = np.arange(len(self.dataset))  
  
    def __iter__(self):  
        return self  
  
    def __next__(self):  
        if self.iteration >= self.max_iter:  
            self.reset()  
            raise StopIteration  
  
        i, batch_size = self.iteration, self.batch_size  
        batch_index = self.index[i * batch_size:(i + 1) * batch_size]  
        batch = [self.dataset[i] for i in batch_index]  
        x = np.array([example[0] for example in batch])  
        t = np.array([example[1] for example in batch])  
  
        self.iteration += 1  
        return x, t  
  
    def next(self):  
        return self.__next__()
```

### 准确率

添加一个用于评估识别精度的函数 `accuracy`。

#### 框架搭建

在 `dezero/functions.py` 中添加 `accuracy` 函数

```python
def accuracy(y, t):  
    y, t = as_variable(y), as_variable(t)  
    pred = y.data.argmax(axis=1).reshape(t.shape)  
    result = (pred == t.data)  
    acc = result.mean()  
    return Variable(as_array(acc))
```

### 文件下载

添加一个下载文件的工具函数。

#### 框架搭建

在 `dezero/utils.py` 中添加 `show_progress` 函数和 `get_file` 函数。

```python
def show_progress(block_num, block_size, total_size):  
    bar_template = "\r[{}] {:.2f}%"  
  
    downloaded = block_num * block_size  
    p = downloaded / total_size * 100  
    i = int(downloaded / total_size * 30)  
    if p >= 100.0: p = 100.0  
    if i >= 30: i = 30  
    bar = "#" * i + "." * (30 - i)  
    print(bar_template.format(bar, p), end='')  
  
  
def get_file(url, file_name=None):  
    if file_name is None:  
        file_name = url[url.rfind('/') + 1:]  
    file_path = f'./{file_name}'  
  
    if os.path.exists(file_path):  
        return file_path  
  
    print("Downloading: " + file_name)  
    try:  
        urllib.request.urlretrieve(url, file_path, show_progress)  
    except (Exception, KeyboardInterrupt) as e:  
        if os.path.exists(file_path):  
            os.remove(file_path)  
        raise  
    print(" Done")  
  
    return file_path
```

### `MNIST` 数据集

使用 `Dataset` 类构建一个 `MNIST` 数据集。

#### 框架搭建

```python
class MNIST(Dataset):  
  
    def __init__(self, train=True,  
                 transform=Compose([Flatten(), ToFloat(),  
                                    Normalize(0., 255.)]),  
                 target_transform=None):  
        super().__init__(train, transform, target_transform)  
  
    def prepare(self):  
        url = 'https://ossci-datasets.s3.amazonaws.com/mnist/'  # mirror site  
        train_files = {'target': 'train-images-idx3-ubyte.gz',  
                       'label': 'train-labels-idx1-ubyte.gz'}  
        test_files = {'target': 't10k-images-idx3-ubyte.gz',  
                      'label': 't10k-labels-idx1-ubyte.gz'}  
  
        files = train_files if self.train else test_files  
        data_path = get_file(url + files['target'])  
        label_path = get_file(url + files['label'])  
  
        self.data = self._load_data(data_path)  
        self.label = self._load_label(label_path)  
  
    def _load_label(self, filepath):  
        with gzip.open(filepath, 'rb') as f:  
            labels = np.frombuffer(f.read(), np.uint8, offset=8)  
        return labels  
  
    def _load_data(self, filepath):  
        with gzip.open(filepath, 'rb') as f:  
            data = np.frombuffer(f.read(), np.uint8, offset=16)  
        data = data.reshape(-1, 1, 28, 28)  
        return data  
  
    def show(self, row=10, col=10):  
        H, W = 28, 28  
        img = np.zeros((H * row, W * col))  
        for r in range(row):  
            for c in range(col):  
                img[r * H:(r + 1) * H, c * W:(c + 1) * W] = self.data[  
                    np.random.randint(0, len(self.data) - 1)].reshape(H, W)  
        plt.imshow(img, cmap='gray', interpolation='nearest')  
        plt.axis('off')  
        plt.show()  
  
    @staticmethod  
    def labels():  
        return {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}
```

#### 示例

![](https://img.papergate.top:5000/i/2025/12/6952db1ca1134.webp)

> [!info]  信息
> 到这里已经和 `pytorch` 的使用体验基本一样了。

## 支持 GPU



---

> 作者: Aphros  
> URL: https://blog.papergate.top/posts/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%9F%BA%E7%A1%80~%E8%87%AA%E5%88%B6%E6%A1%86%E6%9E%B6/  

