简介
综合转载:
写在前面:本文不讨论任何涉及涉事人员、企业的话题,援引网传聊天记录内容,仅从技术角度剖析相关安全问题和应对方法。由于本身是做AI框架和NLP的,所以其实在AI安全领域并不专业,如有错误希望大家批评指正。
解析勘误
根据网传的聊天记录,可以得到的有效信息有:
- 利用huggingface的load ckpt函数漏洞注入代码
- 动态修改
- 具体的修改内容,包括optimizer、随机种子、sleep、修改梯度参数,以及随机kill进程。
这个事情其实对Pytorch和Python比较熟的人应该很自然地能想到是pickle漏洞的问题,可以通过序列化文件注入的方式,在反序列化时可以进行任何的代码执行(与执行程序同权限)。
这里首先需要勘误一个内容,这个锅呢,说是huggingface有漏洞,不如说Pytorch有漏洞,更不如说Python一直放任着pickle的问题,全部交由开发者自行维护(所以很多人吐槽企业管理问题也不无道理?)。
先来看一下,所谓的huggingface漏洞。
Huggingface->Pytorch->Pickle漏洞
当我们训练模型时,通常会将模型的权重保存到文件中,以便在检查点保存和稍后加载。最流行的格式是 PyTorch 的状态字典,它是一个 Python 字典对象,将每一层映射到其参数 tensor。我猜大多数人对以下代码片段都很熟悉:
1 | # 保存模型权重 |
然而,这种方法使用 pickle
来序列化和反序列化整个状态字典对象,引发了安全性问题。原因在于 pickle
并不安全,可能会加载具有与反序列化程序相同权限的任意代码。攻击者可以通过模型权重注入任意代码,造成严重的安全问题。一种攻击模型权重的方法是修改其 __reduce__
方法来执行任意代码。
1 | class Obj: |
如果你将此对象序列化并保存到文件中,那么加载对象时代码就会执行。也就是说,当你加载对象时,你会看到打印出的 “hello”。
涉及技术
Pickle注入
Pickle模块是Python自带的序列化和反序列化模块,其工作原理为序列化时将Python对象转换成字节流,反序列化时把字节流还原成Python对象。因此攻击者可以在序列化的数据里嵌入恶意代码,在反序列化这个数据时,这些恶意代码就会被执行。
1 | import pickle |
这个示例是从jb51的文章[2]直接引用的,其中创建了一个名为Malicious
的类。这个类的__reduce__
方法返回一个元组,第一个元素是os.system
,第二个元素是要执行的命令。当反序列化这个对象时,os.system('echo Hacked!')
会被执行,输出“Hacked!”。
可以看到我们能够很简单的借助Pickle注入,实现任意的系统调用(当前OS User权限下),这个时候,进行恶意代码执行非常方便,但是系统调用是有痕迹的,而且较容易排查,所以使用接下来的方法配合才是“王炸”——对源代码没有任何修改(知乎回答中提到的代码审核不会起到任何作用)。
援引一下“AI+Security”第二期《AI/机器学习供应链攻击》[1]分享的slides,其中详细地分析了pickle.load反序列化漏洞:
Monkey Patching
Monkey Patching, a dynamic technique in Python, allows developers and quality assurance engineers (QAEs) to make runtime modifications to classes, objects, or even built-in modules. With its ability to address common pain points in test automation, Monkey Patching in Python has become an invaluable tool in the arsenal of testers and developers.[3]
如引用所示,Monkey Patching是一个被广泛用于自动化测试的技术,可以通过运行时对Python模块、类、对象(函数也是一种对象)等替换,达到灵活测试各种case的目的。
But!反过来说,我可以在运行时随意替换Python代码,且不修改任何源码。举个简单例子:
1 | import torch |
这个例子看到我们轻易地就可以在运行时把torch.add方法进行替换,让其结果+1,得到错误的输出6。
有了这个概念,我们基本上可以操控程序的许多部分,包括导入的库和本地变量。我提供了两个典型场景,展示如何中断训练过程以及篡改模型权重的算术正确性。
场景 1:自动终止训练过程
如上例中的 “hello” 一样,恶意代码可以编写为一个代码字符串。同样,我们可以准备如下代码字符串,创建一个新线程,该线程在 5 秒后终止父进程。此线程在后台运行,因此用户不会注意到任何异常,而 os.kill
不返回错误日志,这使得检测恶意代码变得更加困难。
1 | AUTO_SHUTDOWN = """ |
接下来,我们需要将此代码注入到状态字典对象中。结果是,当我们从磁盘加载模型权重时,代码会执行,训练过程会被中断。
1 | def inject_malicious_code(obj, code_str): |
场景 2:在集合通信中引入错误
类似地,如果我们想修改集合通信操作的行为,可以在计算过程中引入错误,使得分布式训练中的梯度永远不正确。我们可以准备如下代码字符串,劫持 all_reduce
函数。这个代码字符串对 torch.distributed
模块中的 all_reduce
API 进行猴子补丁,并对 tensor 执行加 1 操作。结果是 all-reduce 的结果会比预期结果大。
1 | HIJACK_ALL_REDUCE = """ |
例如,如果你有两个进程,每个进程持有 tensor [0, 1, 2, 3]
,all-reduce 操作会将各个进程的 tensor 相加,结果是 [0, 2, 4, 6]
。然而,如果攻击者注入了恶意代码,结果将变为 [2, 4, 6, 8]
。
而对于Python而言,其实一直在官网文档上写着安全警告:
由于Pytorch直接使用了Python原生的pickle模块进行模型的序列化,因此一直以来Pytorch模型的序列化保存(.bin/.pth文件)都是有风险的,在hugginface网站托管的部分模型也有通过pickle反序列化漏洞的恶意模型。
那为什么我要先做这个勘误呢?首先,huggingface很早就注意到这个问题,并且针对性推出了safetensors,以保证权重文件序列化的安全性(即只保存读取Tensor和对应key)。
Pytorch的改进
其次,Pytorch实际上也做了改进,在torch.load
接口提供了weights_only
参数,用来规避此类问题。
所以追根溯源,我们更应该称之为Pickle注入漏洞。
调用栈回溯
其实读到这里大家肯定会有个感觉,明显有可以规避的方案,为什么没有用呢?原因大概率是因为分布式训练代码没人会去动导致的。从huggingface transformers源码入手我们来回溯一下恶意代码到底是怎么注入进去的。
在此之前先科普一个概念,模型训练过程中序列化保存的中间Checkpoint,不仅仅包含模型权重,还包括Optimizer States、Learning rate schecduler和相关的超参配置,问题正是出在了这里。
transformers的trainer在保存时,对这几部分内容是分别实现的,save_model方法中是可以配置保存为safetensors的:
但是optimizer和scheduler的保存则不然(倒是和网传聊天记录的内容相符,从optimizer的ckpt下手的概率最大):
可以看到分了几个不同训练框架的分支,默认使用torch.save,这里以Pytorch原生的fsdp为例,继续step in看看里面是怎么调用的:
继续查看DefaultSavePlanner
源码可以发现,其使用的序列化方法就是torch.save
:
这时我们再反过来看load时对应的Planner:
可以看到Pytorch明确使用了weights_only=False
,因此哪怕是模型权重保存为Safetensor,但是Optimizer和Scheduler默认都是使用原生Pickle,并且完全没有对Pickle注入的防范措施。
产生的影响
对模型收敛性/最终精度的影响,包括但不限于:
- 改变/扰动梯度下降方向造成的loss突刺和NAN
- Optimizer States修改造成的断点续训失效
- 确定性计算不统一
对集群训练的影响,包括但不限于:
- 随机sleep造成的进程通信超时
- 随机kill造成的节点宕机假象
暂时只能猜到这些,看起来能在内网群当“内鬼”能操作的还有很多,这里就交给专业的人去分析吧,我还是从框架和Python用户的视角,针对网传聊天记录的有效信息,对前两个方法进行深入解析。
解决方案
前面把技术点简单解释了下,还是得看看有什么好的办法解决或者阻拦该类事件的发生。(当然,从源头开始的话,把控好人的权限才是最重要的 )。
替换序列化模块
- Safetensors+json:huggingface已经提供了绝佳的解决方案,虽然序列化时optimizer要比model保存的内容要多一些,但其实也就是key(str): value(Tensor)构成的dict,以及超参数dict,完全可以将其拆分为safetensor和json保存,二者都没有什么代码注入隐患。但是需要自己动手改造checkpoint保存部分的逻辑。
- Protobuf:全称Google Protocol Buffers,是一种轻便、高效的结构化数据存储格式,可以用于结构化数据序列化。Protobuf是语言无关的序列化框架,独立于Python,也没有代码注入漏洞。(PS:部分框架,如Tensorflow、MindSpore选择此方案)。
Checkpoint校验
除了在序列化反序列化上直接封死注入漏洞外,如果已有的训练已经进行,短时间无法改造,则应该从Checkpoint文件下手,校验文件是否被修改(仅限于类似于字节事件场景下,破坏者需要直接对ckpt文件进行修改的情况,此时文件md5必然改变)。
find_class
重载
这是一个非常好的拦截手段,很难想象Pytorch和huggingface都没有在这里做任何拦截(西方人独有的chill感?)。这里我们打开torch.load源码看一下:
在pickle反序列化过程中,要执行python代码实际上需要Unpickler进行一个find_class
的操作,用来找到要执行的模块或对象,这里其实完全可以重载进行过滤。而Pytorch虽然重载了,但是完全没做任何安全性过滤(事实上torch序列化文件需要的反序列化执行函数完全可以统计为白名单)。
这里接前文的例子,我们重载find_class
,将可以反序列化的数据类型进行限制,看看效果:
1 | import pickle |
可以看到只要是没有在白名单内的数据类型,可以直接抛错。
Monkey Patching检测
最后再来说一下Monkey Patching检测,这里我们要做一个强假设,即:执行的代码任意位置都可能被修改,且完全不可预知。
因此需要一个全局的检测机制,这里我们可以利用 sys.settrace() 来实时跟踪代码[5]:
1 | import sys |
可以看到字节码frame对应的name和filename都可以记录,在有monkey patching风险的情况下可以进行全局执行log的排查,看到问题方法。不过这个方法会产生大量log,还是建议最多一个step进行check。