flask部署pytorch-服务端

举报
Nikolas 发表于 2021/01/03 21:31:00 2021/01/03
【摘要】 用flask部署pytorch

## 1.导入依赖包


```python
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from torchvision.models import resnet50
```

## 2.初始化一个flask


```python
app = flask.Flask(__name__)
model = None
use_gpu = False  # 是否使用GPU训练模型

with open('./data/class_map.txt', 'r') as f:
    label_map = eval(f.read())  # 转化成字典
```

## 3.加载模型


```python
def load_model():
    global model
    model = resnet50(pretrained=True)
    model.eval()  # 不启用 BatchNormalization 和 Dropout
    if use_gpu:
        model.cuda()  # 将模型加载到GPU上
```

## 4.处理接收到的图片


```python
def prepare_image(image, target_size):
    if image.mode != 'RGB':
        image = image.convert('RGB')  # 使用'RGB'模式读取图片

    image = T.Resize(target_size)(image)
    image = T.ToTensor()(image)

    image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    image = image[None]
    if use_gpu:
        image = image.cuda()
    return torch.autograd.Variable(image, volatile=True)  # 自动微分变量
```

## 5.定义路由


```python
@app.route('/predict', methods=['POST'])
def predict():
    data = {'success': False}

    if flask.request.method == 'POST':
        if flask.request.files.get('image'):
            image = flask.request.files['image'].read()
            image = Image.open(io.BytesIO(image))  # 将字节对象转为Byte字节流数据

            image = prepare_image(image, target_size=(224, 224))

            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)  # 返回Tensor中的前k个元素以及元素对应的索引值
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())  # 把tensor转换成numpy的格式

            data['predictions'] = list()

            for prob, label in zip(results[0][0], results[1][0]):
                label_name = label_map[label]
                r = {'label': label_name, 'probability': float(prob)}
                data['predictions'].append(r)

            data['success'] = True

    return flask.jsonify(data)  # 将字典转成json字符串
```

## 6.主函数


```python
if __name__ == '__main__':
    load_model()
    app.run()
```

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。