Image Segmentation

Posted by Nikolas on 2021/01/10 22:43:48
[Abstract] Segmenting images with TensorFlow

## 1. Install the environment

```python
!pip install -q git+https://github.com/tensorflow/examples.git
!pip install -q -U tfds-nightly
```

## 2. Import dependencies


```python
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow_datasets as tfds
from IPython.display import clear_output
import matplotlib.pyplot as plt

tfds.disable_progress_bar()
```

## 3. Load the dataset


```python
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
```
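
`tfds.load` returns the dataset splits together with a `DatasetInfo` object. As an optional sanity check (a small sketch, not part of the original post), you can confirm what was downloaded:

```python
# Oxford-IIIT Pet (3.x) ships 'train' and 'test' splits; each example contains
# an image and a per-pixel segmentation_mask, among other features.
print(info.splits['train'].num_examples)
print(info.splits['test'].num_examples)
print(info.features)
```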


```python
def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0  # scale pixel values to [0, 1]
    input_mask -= 1  # shift mask labels from {1, 2, 3} to {0, 1, 2}
    return input_image, input_mask
```


```python
@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], [128, 128])

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)  # randomly mirror the image horizontally
        input_mask = tf.image.flip_left_right(input_mask)  # apply the same flip to the mask

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask
```


```python
def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask
```

## 4. Split into training and test sets


```python
train_length = info.splits['train'].num_examples
batch_size = 64
buffer_size = 1000
steps_per_epoch = train_length // batch_size

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

train_dataset = train.cache().shuffle(buffer_size).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(batch_size)
```
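
As a quick check of the input pipeline (a sketch, not in the original post), pulling one batch should yield image and mask tensors of the expected shapes:

```python
# One batch from the training pipeline: 64 images of 128x128x3 and 64 masks of 128x128x1.
for images, masks in train_dataset.take(1):
    print(images.shape)  # expected: (64, 128, 128, 3)
    print(masks.shape)   # expected: (64, 128, 128, 1)
```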

## 5. View a raw image and its segmentation mask


```python
def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()
```


```python
for image, mask in train.take(1):
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])
```

Each pixel is classified into one of three classes (the pet, the border of the pet, and the background), so the model needs 3 output channels.


```python
output_channels = 3
```

## 6. Use a pretrained MobileNetV2 as the encoder/downsampler


```python
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
layer_names = [
    'block_1_expand_relu'
    , 'block_3_expand_relu'
    , 'block_6_expand_relu'
    , 'block_13_expand_relu'
    , 'block_16_project'
]
layers = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False
```
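
The selected layers produce feature maps at progressively coarser resolutions; these become the skip connections of the U-Net. A rough shape check (a sketch; the exact channel counts come from the MobileNetV2 architecture):

```python
# Feed a dummy batch through the frozen encoder and print the skip-connection shapes.
dummy = tf.random.uniform([1, 128, 128, 3])
for feat in down_stack(dummy):
    print(feat.shape)
# Expected roughly: (1, 64, 64, 96), (1, 32, 32, 144), (1, 16, 16, 192),
#                   (1, 8, 8, 576), (1, 4, 4, 320)
```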

## 7. Use pix2pix as the decoder/upsampler


```python
up_stack = [
    pix2pix.upsample(512, 3)
    , pix2pix.upsample(256, 3)
    , pix2pix.upsample(128, 3)
    , pix2pix.upsample(64, 3)
]
```
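
Each `pix2pix.upsample(filters, size)` block doubles the spatial resolution (it is built around a strided `Conv2DTranspose`). A minimal sketch of what the first block does to the encoder's deepest output:

```python
# The 4x4x320 bottleneck from the encoder becomes an 8x8x512 feature map.
bottleneck = tf.random.uniform([1, 4, 4, 320])
print(up_stack[0](bottleneck).shape)  # expected: (1, 8, 8, 512)
```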


```python
def unet_model(output_channels):
    inputs = tf.keras.Input(shape=[128, 128, 3])
    x = inputs

    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2
        , padding='same'
    )
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
```

## 8. Build the model


```python
model = unet_model(output_channels)
model.compile(optimizer='adam'
              , loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
              , metrics=['accuracy'])
```

## 9. Inspect the model architecture


```python
tf.keras.utils.plot_model(model, show_shapes=True)
```
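
`plot_model` needs `pydot` and Graphviz installed. If they are not available, `model.summary()` prints the same architecture as text:

```python
# Text view of the U-Net: frozen MobileNetV2 encoder, upsample blocks and skip connections.
model.summary()
```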

## 10. View the model's predictions before training


```python
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]
```


```python
def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])


show_predictions()
```

## 11. Callback function


```python
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print('Epoch {}'.format(epoch + 1))


epochs = 20
val_subsplits = 5
validation_steps = info.splits['test'].num_examples // batch_size // val_subsplits
```

## 12. Train the model


```python
history = model.fit(train_dataset, epochs=epochs
                    , steps_per_epoch=steps_per_epoch
                    , validation_steps=validation_steps
                    , validation_data=test_dataset
                    , callbacks=[DisplayCallback()]
                    )
```
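
After training it can be useful to evaluate on the full test set (a sketch, not part of the original post; `model.evaluate` returns the loss and the accuracy metric configured in `model.compile`):

```python
test_loss, test_acc = model.evaluate(test_dataset)
print('test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_acc))
```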

## 13. Visualize the loss


```python
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = history.epoch

plt.figure()
plt.plot(epochs, loss, 'g', label='loss')
plt.plot(epochs, val_loss, 'yo', label='val_loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.ylim([0, 1])
plt.legend()
plt.show()
```
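
The accuracy curves can be plotted the same way (a sketch; assumes the TF 2.x metric names `accuracy`/`val_accuracy` recorded by `model.fit` above):

```python
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

plt.figure()
plt.plot(epochs, acc, 'g', label='accuracy')
plt.plot(epochs, val_acc, 'yo', label='val_accuracy')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
```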

## 14. Make predictions


```python
show_predictions(test_dataset, 3)
```
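
If you want to reuse the trained model later, it can be saved (a sketch; the path is arbitrary):

```python
# Saves the model in TensorFlow's SavedModel format.
model.save('pet_segmentation_model')
```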
