《深度学习:卷积神经网络从入门到精通》——3.7.4 交通路网的自动提取代码及说明

举报
华章计算机 发表于 2019/06/06 09:10:59 2019/06/06
【摘要】 本书摘自《深度学习:卷积神经网络从入门到精通》——书中第3章,第3.7.4节,作者是李玉鑑、张婷、单传辉、刘兆英等。

3.7.4 交通路网的自动提取代码及说明

在图像块分类LeNet训练好后,接下来就可以利用这个网络从遥感图像中自动提取交通路网。提取过程需要用到lenet.prototxt文件和3个Python程序。其中,lenet.prototxt文件用来根据表3.11设置LeNet的结构,通过结合已经训练好的权值和偏置,对图像块进行分类。

表3.11 图像块分类LeNet程序在lenet.prototxt文件中设置的参数值

参数 设置值

image.png

3个Python程序分别是batch_classify.py、batch_remove_scattered_point.py和draw_points.py。其中,batch_classify.py调用训练好的LeNet对待提路网遥感图像中的所有图像块进行分类,如果是正例,就把中心标注为红色,否则保持原色,得到初步的路网提取结果。batch_remove_scattered_point.py用来去除较小的红色连通区域。draw_points.py则用来显示路网提取的最终结果。

下面是batch_classify.py、batch_remove_scattered_point.py和draw_points.py的代码及说明。

1.?batch_classify.py的代码及说明

import os

import sys

import numpy as np

import matplotlib.pyplot as plt

import matplotlib.image as mpimg

from PIL import Image,ImageDraw

import caffe

def label_image(dir,image_file): # 该函数用来在图像中标记正例中心点

    image = Image.open(dir+image_file) # 读取图像

    h,w = image.size # 提取图像的高和宽

    wdSize = [25,25] # 定义图像块大小

    plist = [] # 用来存储正例中心点坐标

    for i in range(0,h-wdSize[1],1): # 按列扫描图像

        for j in range(0,w-wdSize[0],1): # 按行扫描图像

            region = (i,j,i+wdSize[0],j+wdSize[1])

            roi = image.crop(region)

            roi = np.array(roi,dtype='float32') # 转化为矩阵

            roi = roi/255 # 归一化

            prediction=net.predict([roi],oversample = False)  # 预测图像块的类别

            if 0 != prediction[0].argmax(): # 如果不是反例,添加正例图像块中心坐标

  到plist

                plist.append([i+int((wdSize[0]+1)/2),j+int((wdSize[1]+1)/2)])

    for i in range(len(plist)): # 该循环用来在遥感图像中显示正例中心点

        draw = ImageDraw.Draw(image)

        x,y=plist[i]

        draw.arc((x-1, y-1, x+1, y+1), 0, 360, fill=255)

    image.save(r'./resultImages/'+image_file) # 保存带正例中心点的结果图像

    return plist

def write_list(my_list,file_name):      # 该函数用来把列表数据存储到文件中

    file = open(file_name,'w')

    for e in my_list:

        file.writelines(str(e[0])+' '+str(e[1])+'\n')

    file.close()

model_file=r'./LeNet.prototxt'    # 提取网络模型框架

trained_model=r'./snapshots/_iter_200000.caffemodel' # 提取训练200?000次的模型参数

imageDir = r'./images/testImages/'    # 提取测试图像目录

net = caffe.Classifier(model_file,trained_model,channel_swap=(2,1,0))

caffe.set_mode_gpu()      # 选择GPU运行模式

files = os.listdir(imageDir)    # 提取测试图像目录下的所有文件

for image_file in files:      # 对每个图像文件提取正例中心并存储到相应文件

    plist = label_image(imageDir,image_file)

    name,suffix = image_file.split('.')

    write_list(plist,r'./resultPoints/'+name+r'.txt')

注意:在上述代码中,net.predict是Caffe自带函数,用来对未知样本进行分类。

2.?batch_remove_scattered_point.py的代码及说明

import os

import numpy as np

import scipy.ndimage as ndi

from skimage import measure,color

import matplotlib.pyplot as plt

def read_file(filename):      # 该函数用来读取正例中心坐标

    points = [];

    with open(filename,'r') as f:

        for line in f.readlines():

            linestr = line.strip() # strip()方法用于移除字符串头尾指定的字符(默

  认为空格)

            x,y = linestr.split(' ')

            print x,y

            points.append([int(x),int(y)])

    return points

def write_list(mylist,filename):      # 该函数用来存储正例中心坐标

    f = open(filename,'w')

    for p in mylist:

        f.writelines(str(p[0])+' '+str(p[1])+'\n')

    f.close()

def get_adj(points,wdSize):      # 该函数用来生成一个关于正例中心的二值矩阵

    adj = np.zeros(wdSize)

    for p in points:

        adj[p[0],p[1]]=1      # 将矩阵adj位置为p[0],p[1]的值置为1

    return adj

def adj_to_coordinate(adj):      # 把二值矩阵的正例中心点转换成坐标数据

    points = []

    row,col = adj.shape    # 把adj的宽和高赋给row,col

    for i in range(row):

        for j in range(col):

            if adj[i,j]==1:

                points.append([i,j]) # 把[i,j]添加到points串中

return points

def get_labels(adj):      # 标记8连通区域,并按顺序编号

    labels=measure.label(adj,connectivity=2)  # 标记8连通区域

    print('regions number:',labels.max()+1)  # 显示连通区域的个数

    return labels

def zeros_elements(idx,image):    # 该函数用来将图像的某个区域置为0

    for i in idx:

        image[i[0],i[1]]=0

def remove_unroad(image,label_image,threshold): # 该函数用来去除小连通区域

    rp = measure.regionprops(label_image)  # 检测连通区域的属性

    for i in range(len(rp)):

        flag = False

        if rp[i].area<threshold: # 标注面积小于阈值的区域

            flag = True

        if flag == True:

            idx = np.argwhere(label_image == i+1) # 统计小区域的点集

            zeros_elements(idx,image) # 将小区域的点集置零

def max_points(plist):    # 该函数用来计算点列的横纵坐标的最大值

    maxp = [0,0];

    for p in plist:

        if p[0]>maxp[0]:

            maxp[0] = p[0]

        if p[1]>maxp[1]:

            maxp[1] = p[1]

    return maxp

points_dir = './resultPoints/'      # 初始正例中心坐标目录

save_dir = './resultPointsFinal/'    # 优化后的正例中心坐标目录

pfiles = os.listdir(points_dir)

wdSize = [0,0]

for pf in pfiles:

    points = read_file(points_dir+pf)

    if 0 == len(points):

        continue

    wdSize = max_points(points)

    wdSize[0] = wdSize[0] + 1

    wdSize[1] = wdSize[1] + 1

    adj = get_adj(points,wdSize)    # 生成正例中心的二值矩阵

    labels = get_labels(adj)      # 标记连通区域并按顺序编号

    remove_unroad(adj,labels,200)    # 去除少于200个点的连通区域

    new_points = adj_to_coordinate(adj) # 把剩余的正例中心点转换成坐标数据

write_list(new_points,save_dir+pf)

3.?draw_points.py的代码及说明

import os

import matplotlib.pyplot as plt

import matplotlib.image as mpimg

from PIL import Image,ImageDraw

def read_file(filename):        # 该函数用来按整数格式读取正例中心点的坐标数据

    points = [];

    with open(filename,'r') as f:

        for line in f.readlines():

            linestr = line.strip()

            x,y = linestr.split(' ')

            print x,y

            points.append([int(x),int(y)])

    return points

def write_labeled_img(points,image,filename): # 该函数用来将正例中心点标注为红色

    for i in range(len(points)):

        draw = ImageDraw.Draw(image)

        x,y=points[i]

        image.putpixel([x,y],(255,0,0))

    image.save(filename)

points_dir = r'./ resultPointsFinal/'  # 优化后的正例中心坐标目录

image_dir = r'./images/testImages/'  # 测试图像目录

save_dir = r'./resultImgesFinal/'    # 图像路网提取结果的保存目录

pfile = os.listdir(points_dir)

for pf in pfile:

    points = read_file(points_dir+pf)

    name,suffix = pf.split('.')

    image = Image.open(image_dir+name+'.bmp')

    write_labeled_img(points,image,save_dir+name+'.bmp')


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。