0


联邦学习开源平台Flower:在移动设备上运行联邦学习

Flower开源平台的使用

Flower官网:https://flower.ai/
官网doc:https://flower.ai/docs/
Flower github:https://github.com/adap/flower

为什么用Flower

Flower可以将预先写好的集中式机器学习代码以联邦学习的方式运行(只需要少量的修改)。并且他可以在windows环境下模拟联邦学习场景,非常适合实验。

Flower开源了很多联邦学习的基线算法的示例,可以轻松入门。

本文主要将一下如何使用Flower,通过对官网给出的quickstart示例,结合其他csdn,给出运行Flower的步骤以及在android上运行联邦学习的示例。

Flower安装

Flower安装需要至少python 3.8版本以上,推荐使用python 3.10 及以上,本文将使用python 3.8 运行下列的示例。
官网安装教程

创建自己的虚拟环境:conda

通过anaconda创建虚拟环境:

conda create -n flwr python=3.8

激活环境

conda activate flwr

直接安装(稳定版):

pip install flwr

注意:直接pip安装的是flower的稳定版本,与官网的版本可能不同。

查看Flower的安装版本

python -c "import flwr;print(flwr.__version__)"

我安装的

1.10.0

版本。
后面的示例请移至对应版本的doc官网,具体版本的官网在doc官网的左下角

Versions

可以找到
Flower V-1.10.0的doc地址

Flower实例运行(quickstart pytorch和quickstart tensorflow)

quickstart pytorch

代码参考csdn

编写代码

创建两个文件:

client.py

server.py
client.py

:编写客户端运行文件,主要包括传统的机器学习流程代码和Flower客户端类实现。具体如下

# 传统的机器学习流程代码from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

DEVICE = torch.device("cuda:0"if torch.cuda.is_available()else"cpu")# 定义神经网络模型classNet(nn.Module):def__init__(self)->None:super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)defforward(self, x: torch.Tensor)-> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))return self.fc3(x)# 定义模型训练流程deftrain(net, trainloader, epochs):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for _ inrange(epochs):for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()# 定义模型推理流程deftest(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss =0,0.0with torch.no_grad():for images, labels in tqdm(testloader):
            outputs = net(images.to(DEVICE))
            labels = labels.to(DEVICE)
            loss += criterion(outputs, labels).item()
            correct +=(torch.max(outputs.data,1)[1]== labels).sum().item()
    accuracy = correct /len(testloader.dataset)return loss, accuracy

# 定义数据集的获取defload_data():"""Load CIFAR-10 (training and test set)."""
    trf = Compose([ToTensor(), Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
    testset = CIFAR10("./data", train=False, download=True, transform=trf)return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)# 生成模型对象,实际获取训练与测试数据集

net = Net().to(DEVICE)
trainloader, testloader = load_data()
# 客户端类实现classFlowerClient(fl.client.NumPyClient):# 获取本地模型对应的参数defget_parameters(self, config):return[val.cpu().numpy()for _, val in net.state_dict().items()]# 接收模型参数,并更新本地模型defset_parameters(self, parameters):
        params_dict =zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v)for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)# 本地模型训练,会先调用 set_parameters() 基于收到的全局模型参数更新本地模型deffit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)return self.get_parameters(config={}),len(trainloader.dataset),{}# 基于测试数据集进行测试defevaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)return loss,len(testloader.dataset),{"accuracy": accuracy}# 启动 Flower 客户端

fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient(),)

客户端类会继承flwr的

NumPyClient

类,当服务器选择一个特定的客户端进行训练时,他会通过网络发送训练指令。

这里服务器与客户端在同一个主机上运行,因此server_address就可以用本地回环地址

127.0.0.1

,FL服务器默认端口使用

8080

。如果服务器和客户端不是同一台主机,则可以使用真实的IP地址。

server.py

from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics

# 定义指标聚合方法defweighted_average(metrics: List[Tuple[int, Metrics]])-> Metrics:
    accuracies =[num_examples * m["accuracy"]for num_examples, m in metrics]
    examples =[num_examples for num_examples, _ in metrics]return{"accuracy":sum(accuracies)/sum(examples)}# 定义模型聚合策略

strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)# 启动 Flower 服务端

fl.server.start_server(
    server_address="0.0.0.0:8080",# 服务器地址
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,)

在服务器代码中定义聚合策略。

运行python文件

先进入文件所在的文件夹,并激活虚拟环境。

直接运行

server.py
python server.py

重新打开一个终端运行

client.py
python client.py

服务器默认最少的客户端数量是2,所以要运行两个client才能进行联邦学习训练。
(重新打开一个终端,再运行一次

client.py

/ 一共三个终端)

运行结果

server

在这里插入图片描述

clinet1

在这里插入图片描述

client2

在这里插入图片描述

quickstart tensorflow

quickstart pytorch

类似,创建两个python文件:

client.py

server.py

,代码参考官网给出的代码,但需要稍微调整一下。

client.py

:需要注意的是server的地址,可以用上面那个地址

import flwr as fl
import tensorflow as tf
# 加载数据(x_train, y_train),(x_test, y_test)= tf.keras.datasets.cifar10.load_data()# 加载模型,10分类模型MobilNetV2
model = tf.keras.applications.MobileNetV2((32,32,3), classes=10, weights=None)
model.compile("adam","sparse_categorical_crossentropy", metrics=["accuracy"])# 定义client类,也就是flower客户端classCifarClient(fl.client.NumPyClient):defget_parameters(self, config):return model.get_weights()deffit(self, parameters, config):
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)return model.get_weights(),len(x_train),{}defevaluate(self, parameters, config):
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)return loss,len(x_test),{"accuracy":float(accuracy)}
    
fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())# 启动flower客户端
server.py

:注意server的地址

import flwr as fl
if __name__ =='__main__':
    fl.server.start_server(server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=3))# 启动flower服务器

如果按照官网里的那个地址,我运行会报错,显示连接不到地址
在这里插入图片描述

运行结果

server

在这里插入图片描述

client1

在这里插入图片描述

client2

在这里插入图片描述

Flower实例运行(quickstart android)

github官网安卓示例
在安卓客户端上使用TFLite进行CIFAR10的联邦学习,将CIFAR-10数据集随机分配给10个客户端,服务器用python运行,客户端运行在安卓上。后台线程是通过安卓的WorkManager库建立的,因此它可以在8到13的安卓版本上运行。

首先需要有安卓虚拟机,可以下载

android studio

,本文将不在赘述。

下载源码:从github上下载example里的源码。源码里有示例的apk文件,需要先把这个apk文件下载下来(注册个账号,直接下载到本地)
https://www.dropbox.com/s/ii0vwrjrpupifiv/flower-client.apk?dl=0

源码中需要最少四个android设备才能运行联邦学习,当然,这个可以在

server.py

文件中更改(我的主机运行不了那么多虚拟设备,所以就测试2个),具体如下:

        min_fit_clients=2,# 根据自己需要更改最少客户端
        min_evaluate_clients=2,
        min_available_clients=2,

创建两个Android studio虚拟设备,打开这两个设备,将刚下载的

apk文件

拖拽到虚拟机里,虚拟机会自动下载apk应用。下载完会有一个flower的应用,点开如下图所示:
在这里插入图片描述
在虚拟机app里输入

client id

server IP / port
server.py

代码里的id是

0.0.0.0:8080

,这里就输入真实的ip就行了,Port就是

8080

在这里插入图片描述
激活虚拟环境并下载依赖项:

pip install -r requirements.txt

运行服务器:

python server.py

依次点击虚拟设备app里的三个黄色按钮

运行结果:

server

在这里插入图片描述

client1

在这里插入图片描述

client2

在这里插入图片描述

我目前只能运行apk已有的示例,examples里的android项目没有搭建成功

在Android studio中构建案例中的android的环境时,报错了:无法找到依赖项TFLite;位置在app文件夹下的

build.gradle


在这里插入图片描述

标签: 开源 学习 笔记

本文转载自: https://blog.csdn.net/weixin_47954484/article/details/141428561
版权归原作者 哥不在辉煌--借用长辈的网名 所有, 如有侵权,请联系我们删除。

“联邦学习开源平台Flower:在移动设备上运行联邦学习”的评论:

还没有评论