隐私计算实训营第二期第十课:基于SPU机器学习建模实践

隐私计算实训营第二期-第十课

  • 第十课:基于SPU机器学习建模实践
    • 1 隐私保护机器学习背景
      • 1.1 机器学习中隐私保护的需求
      • 1.2 PPML提供的技术解决方案
    • 2 SPU架构
      • 2.1 SPU前端
      • 2.2 SPU编译器
      • 2.3 SPU运行时
      • 2.4 SPU目标
    • 3 密态训练与推理
      • 3.1 四个基本问题
      • 3.2 解决数据来源问题
      • 3.3 解决数据安全问题
      • 3.4 解决模型计算问题
      • 3.5 解决密态计算问题
      • 3.6 如何应对更复杂的模型
      • 3.7 已有模型的复用
    • 4 作业实践
      • 4.1 基础NN模型作业
      • 4.2 进阶Transformer模型作业

第十课:基于SPU机器学习建模实践

首先必须感谢蚂蚁集团及隐语社区带来的隐私计算实训第二期的学习机会!
本节课由蚂蚁隐私计算部算法工程师吴豪奇老师讲解。

在这里插入图片描述
本节课主要内容为:

  • 隐私保护机器学习背景
  • SPU架构简介
  • NN密态训练/推理示例

1 隐私保护机器学习背景

1.1 机器学习中隐私保护的需求

本节课前两个小节的内容,我们这之前的课程中已有一些了解,
本节课可以回顾一下。
数据和模型的隐私保护需求是产生隐私保护机器学习的根因。

在这里插入图片描述

1.2 PPML提供的技术解决方案

MPC提供了隐私保护的技术解决方案。

在这里插入图片描述

使用MPC结合机器学习,为模型训练和推理提供隐私保护。
问题
我们是否可以直接以 MPC 的方式高效地运行已有的机器学习程序?

在这里插入图片描述

2 SPU架构

SPU架构我们在之前已经学习过,宏观上主要分为三部分:

  1. 前端部分
  2. 编译器
  3. 运行时

2.1 SPU前端

SPU前端尽量支持原生的AI编程方式,支持JAX、TensorFlow,Pytorch
等典型的AI编程框架。

在这里插入图片描述

2.2 SPU编译器

SPU的编译器以优化方式生成SPU的密态中间语言。

SPu

2.3 SPU运行时

SPU的运行时支持多种并行模式(数据并行+指令并行),多种MPC协议
以及多种部署模式。

在这里插入图片描述

2.4 SPU目标

SPU的最终目标是实现易用、可扩展和高性能的密态计算虚拟设备。

在这里插入图片描述

3 密态训练与推理

3.1 四个基本问题

密态的训练和推理需要解决的四个问题:

  • 数据从哪来?
  • 如何加密保护数据?
  • 如何定义模型计算?
  • 如何执行密态模型计算?

在这里插入图片描述

3.2 解决数据来源问题

数据由数据个参与方以密态的形式提供。

在这里插入图片描述

3.3 解决数据安全问题

数据安全通过MPC协议或者同态加密等外部模式解决。

在这里插入图片描述

3.4 解决模型计算问题

NN模型的计算问题通过JAX实现前向和反向传播。

在这里插入图片描述

3.5 解决密态计算问题

NN模型的密态计算SPU的编译器转换为密态算子,然后按照MPC协议
进行计算。

在这里插入图片描述

密态的计算过程与明文类似,通过SPU密态计算配置实现密态训练。
在这里插入图片描述

3.6 如何应对更复杂的模型

对于复杂模型,使用stax和flax来进行实现。

在这里插入图片描述

3.7 已有模型的复用

已有模型的复用问题,根据明文实现来进行密态计算的迁移。
比如,明文实现的GPT2模型。
在这里插入图片描述

然后进行密态迁移:

在这里插入图片描述

在支持不同的模型方面,SPU还需要更新和优化自己的实现以满足不同
模型的需求。

在这里插入图片描述

4 作业实践

4.1 基础NN模型作业

本次课程有两个作业,一个是基础的NN模型。另一个是进阶的Transformer
模型。

在这里插入图片描述

完成步骤如下:

1、加载数据集

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    x, y = load_breast_cancer(return_X_y=True)
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            if party_id == 1:
                return x_train[:, :15], _
            else:
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test

2、定义模型

from typing import Sequence
import flax.linen as nn


FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

3、定义训练参数

import jax.numpy as jnp


def predict(params, x):
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)


def loss_func(params, x, y):
    pred = predict(params, x)

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)


def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    x = jnp.concatenate((x1, x2), axis=1)
    xs = jnp.array_split(x, len(x) / n_batch, axis=0)
    ys = jnp.array_split(y, len(y) / n_batch, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        for x, y in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
    return params


def model_init(n_batch=10):
    model = MLP(FEATURES)
    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))

4、验证参数

from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):
    y_pred = predict(params, X_test)
    return roc_auc_score(y_test, y_pred)

5、开始明文训练

import jax

# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameter
n_batch = 10
n_epochs = 10
step_size = 0.01

# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)

# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

这里输出的明文训练结果为:

在这里插入图片描述

6、开始密文训练

import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)

device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

7、检查参数

params_spu = spu(train_auto_grad)(x1_, x2_, y_, init_params)
params = sf.reveal(params_spu)
print(params)

8、输出训练结果

X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')

密文训练输出结果为:

在这里插入图片描述
可以看出,密文训练和明文训练的效果相同,本作业结束,

4.2 进阶Transformer模型作业

完成步骤如下:
1、安装Transformer模型

import sys
!{sys.executable} -m pip install transformers[flax] -i https://pypi.tuna.tsinghua.edu.cn/simple

2、设置镜像huggingface

import os
import sys
!{sys.executable} -m pip install huggingface_hub
os.environ['HF_ENDPOINT']='https://hf-mirror.com'

3、加载模型

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

4、定义文本生成函数

def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)

    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids

5、进行明文的文本生成

import jax.numpy as jnp

inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)

print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

生成的输出结果为:

在这里插入图片描述
6、进行密文训练

import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob', 'carol'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

output_token_ids = spu(text_generation)(input_token_ids_, model_params_)

这里由于机器配置不够,内存不足,被系统kill进程,导致无法完成训练。小伙伴们机器好的应该可以跑完。

在这里插入图片描述

7、输出密文训练结果

outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

至此,本次作业全部结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/773976.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

二叉搜索树(BST)

目录 一、概念 二、代码实现 1.框架 2.查找 3.插入 4.删除 5.递归的写法 三、应用 一、概念 二、代码实现 1.框架 #pragma oncenamespace utoKey {//结点template<class K>struct BinarySearchTreeNode{//结点的typedeftypedef BinarySearchTreeNode Node;//Nod…

利用pg_rman进行备份与恢复操作

文章目录 pg_rman简介一、安装配置pg_rman二、创建表与用户三、备份与恢复 pg_rman简介 pg_rman 是 PostgreSQL 的在线备份和恢复工具。类似oracle 的 rman pg_rman 项目的目标是提供一种与 pg_dump 一样简单的在线备份和 PITR 方法。此外&#xff0c;它还为每个数据库集群维护…

AIGC时代,“人”的核心价值在何处?

随着科技的浪潮汹涌向前&#xff0c;人工智能生成内容&#xff08;AIGC&#xff09;已悄然渗透至我们生活的每一个角落&#xff0c;从创意设计到信息传播&#xff0c;其影响力与变革力愈发显著。在这一由算法驱动的新纪元里&#xff0c;人类社会运作模式、学习途径及职业形态均…

眼动追踪技术 | 眼动的分类和模型

摘要 灵长类动物用于调整中央凹位置的正常眼动&#xff0c;几乎都可以归结为五种基本类型的组合&#xff1a;扫视、平稳追踪、聚散、前庭眼震和生理性眼震(与注视相关的微小运动)。聚散运动用于将双眼聚焦于远处的目标(深度知觉)。其他运动(如适应和聚焦)指的是眼动的非位置变…

LMT加仿真,十一届大唐杯全国总决赛

这次省赛带了太多个省一了&#xff0c;并且很多都进入了国赛总决赛&#xff0c;具体可看下面的图片&#xff0c;只放了一部分。目前只有B组是只有一个商用设备赛也就是LMT&#xff0c;A组和高职组都是仿真实践赛加上商用设备赛。 针对商用设备赛有对应的资料&#xff…

【深度学习】第3章——回归模型与求解分析

一、回归分析 1.定义 分析自变量与因变量之间定量的因果关系&#xff0c;根据已有的数据拟合出变量之间的关系。 2.回归和分类的区别和联系 3.线性模型 4.非线性模型 5.线性回归※ 面对回归问题&#xff0c;通常分三步解决 第一步&#xff1a;选定使用的model&#xff0c;…

CFS三层内网渗透——第二层内网打点并拿下第三层内网(三)

目录 八哥cms的后台历史漏洞 配置socks代理 ​以我的kali为例,手动添加 socks配置好了&#xff0c;直接sqlmap跑 ​登录进后台 蚁剑配置socks代理 ​ 测试连接 ​编辑 成功上线 上传正向后门 生成正向后门 上传后门 ​内网信息收集 ​进入目标二内网机器&#xf…

SAP-SD同一物料下单价格确不同

业务说明&#xff1a; 业务部门反馈&#xff0c;同一物料下销售订单时&#xff0c;价格确不同。 那么这个价格是怎么取到的呢&#xff1f; 逻辑说明&#xff1a; 1、首先查看销售订单 可以看到相同物料价格是不同的&#xff0c;条件类型都是ZPR5&#xff0c;但是客户是不同…

相关款式1111

一、花梨木迎客松 1. 风速打单 发现只有在兄弟店铺有售卖 六月份成交订单数有62笔 2. 生意参谋 兄弟店铺商品访客数&#xff1a;3548&#xff0c;支付件数&#xff1a;95件 二. 竹节茶刷&#xff08;引流&#xff09; 1. 风速打单 六月订单数有165笔 兄弟&#xff1a;…

揭秘数据之美:【Seaborn】在现代【数学建模】中的革命性应用

目录 已知数据集 tips 生成数据集并保存为CSV文件 数据预览&#xff1a; 导入和预览数据 步骤1&#xff1a;绘制散点图&#xff08;Scatter Plot&#xff09; 步骤2&#xff1a;添加回归线&#xff08;Regression Analysis&#xff09; 步骤3&#xff1a;分类变量分析&…

Mall,正在和年轻人重新对话

【潮汐商业评论/原创】 结束了一下午的苦闷培训&#xff0c;当Cindy赶到重庆十字大道时&#xff0c;才发现十字路口上的巨大“飞行棋”在前两天就已经撤展了。 “来了又错过&#xff0c;就会觉得遗憾&#xff0c;毕竟这样的路口不多&#xff0c;展陈又不可能会返场。” 飞行棋…

藏文作文写作业推荐什么学习工具?《藏文翻译词典》App值得你使用,一款好用准确的藏语词汇查询辞典!

探索藏语的奥秘&#xff0c;体验藏族文化的魅力&#xff0c;尽在《藏文翻译词典》App。这款App是藏汉翻译的神器&#xff0c;也是藏语学习者的必备工具。在学习过程中遇到不会的藏语单词&#xff0c;可以使用《藏文翻译词典》App进行查询&#xff01; 主要特性&#xff1a; 藏…

SCT612404通道,高效高集成,摄像头模组电源集成芯片

集成三路降压变换器&#xff0c;1CH高压BUCK,2CH低压Buck >HVBuck1:输入电压4.0V-20V,输出电流1.2A,Voo300mV/500mV >LVBuck2:输入电压2.7V-5V,输出电流0.6A , 固定1.8V输出 ;LVBuck3:输λ2.7V-5V,输出电流1.2A,可设定固定输出&#xff1a; 1 . 1 V / 1 . 2 V / 1 . 3 …

Intellj idea无法启动

个人电脑上安装的是2024.01版本的intellj idea作为开发工具&#xff0c;引入了javaagent作为工具包 但是在一次invaliad cache操作后&#xff0c;intellj idea就无法启动了&#xff0c;双击无响应。 重装了idea后也无效&#xff08;这个是有原因的&#xff0c;下面会讲&#…

开发人员使用的10大主流任务进度管理工具

本文将分享10大优质任务管理软件&#xff1a;Worktile、PingCode、Asana、Todoist、ClickUp、HubSpot Task Management、Hitask、Smartsheet、ProjectManager、Microsoft To Do。 任务管理软件不仅帮助个人和团队跟踪日常任务&#xff0c;还优化了工作流程&#xff0c;确保项目…

Linux/Ubuntu访问局域网共享文件夹

文件夹中找到“Other Location”&#xff0c;输入“smb:IP地址/共享文件夹名称”&#xff0c;然后点击connect后者直接回车即可&#xff01; End&#xff01;

Redis 主从,哨兵,cluster集群

概述 主从复制 主从复制是高可用Redis的基础&#xff0c;哨兵和集群都是在主从复制基础上实现高可用的。 主从复制主要实现了数据的多机备份&#xff0c;以及对于读操作的负载均衡和简单的故障恢复。 缺陷&#xff1a;故障恢复无法自动化&#xff1b;写操作无法负载均衡&am…

联合查询(多表查询)

多表查询是对多张表的数据取笛卡尔积&#xff08;关联查询可以对关联表使用别名&#xff09; 数据准备 insert into classes(name, desc) values (计算机系2019级1班, 学习了计算机原理、C和Java语言、数据结构和算法), (中文系2019级3班,学习了中国传统文学), (自动化2019级5…

Mysql 的第二次作业

一、数据库 1、登陆数据库 2、创建数据库zoo 3、修改数据库zoo字符集为gbk 4、选择当前数据库为zoo 5、查看创建数据库zoo信息 6、删除数据库zoo 1&#xff09;登陆数据库。 打开命令行&#xff0c;输入登陆用户名和密码。 mysql -uroot -p123456 ​ 2&#xff09;切换数据库…

前端修改audio背景色

1.查看浏览器设置Show user agent shadow DOM是否打开 2.打开可以查看audio Dom /** 去掉默认的背景颜色 */ audio::-webkit-media-controls-enclosure{background-color:unset; } 3.效果图
最新文章