《深度学习:卷积神经网络从入门到精通》——3.7.4 交通路网的自动提取代码及说明
3.7.4 交通路网的自动提取代码及说明
在图像块分类LeNet训练好后,接下来就可以利用这个网络从遥感图像中自动提取交通路网。提取过程需要用到lenet.prototxt文件和3个Python程序。其中,lenet.prototxt文件用来根据表3.11设置LeNet的结构,通过结合已经训练好的权值和偏置,对图像块进行分类。
表3.11 图像块分类LeNet程序在lenet.prototxt文件中设置的参数值
参数 设置值
3个Python程序分别是batch_classify.py、batch_remove_scattered_point.py和draw_points.py。其中,batch_classify.py调用训练好的LeNet对待提路网遥感图像中的所有图像块进行分类,如果是正例,就把中心标注为红色,否则保持原色,得到初步的路网提取结果。batch_remove_scattered_point.py用来去除较小的红色连通区域。draw_points.py则用来显示路网提取的最终结果。
下面是batch_classify.py、batch_remove_scattered_point.py和draw_points.py的代码及说明。
1.?batch_classify.py的代码及说明
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image,ImageDraw
import caffe
def label_image(dir,image_file): # 该函数用来在图像中标记正例中心点
image = Image.open(dir+image_file) # 读取图像
h,w = image.size # 提取图像的高和宽
wdSize = [25,25] # 定义图像块大小
plist = [] # 用来存储正例中心点坐标
for i in range(0,h-wdSize[1],1): # 按列扫描图像
for j in range(0,w-wdSize[0],1): # 按行扫描图像
region = (i,j,i+wdSize[0],j+wdSize[1])
roi = image.crop(region)
roi = np.array(roi,dtype='float32') # 转化为矩阵
roi = roi/255 # 归一化
prediction=net.predict([roi],oversample = False) # 预测图像块的类别
if 0 != prediction[0].argmax(): # 如果不是反例,添加正例图像块中心坐标
到plist
plist.append([i+int((wdSize[0]+1)/2),j+int((wdSize[1]+1)/2)])
for i in range(len(plist)): # 该循环用来在遥感图像中显示正例中心点
draw = ImageDraw.Draw(image)
x,y=plist[i]
draw.arc((x-1, y-1, x+1, y+1), 0, 360, fill=255)
image.save(r'./resultImages/'+image_file) # 保存带正例中心点的结果图像
return plist
def write_list(my_list,file_name): # 该函数用来把列表数据存储到文件中
file = open(file_name,'w')
for e in my_list:
file.writelines(str(e[0])+' '+str(e[1])+'\n')
file.close()
model_file=r'./LeNet.prototxt' # 提取网络模型框架
trained_model=r'./snapshots/_iter_200000.caffemodel' # 提取训练200?000次的模型参数
imageDir = r'./images/testImages/' # 提取测试图像目录
net = caffe.Classifier(model_file,trained_model,channel_swap=(2,1,0))
caffe.set_mode_gpu() # 选择GPU运行模式
files = os.listdir(imageDir) # 提取测试图像目录下的所有文件
for image_file in files: # 对每个图像文件提取正例中心并存储到相应文件
plist = label_image(imageDir,image_file)
name,suffix = image_file.split('.')
write_list(plist,r'./resultPoints/'+name+r'.txt')
注意:在上述代码中,net.predict是Caffe自带函数,用来对未知样本进行分类。
2.?batch_remove_scattered_point.py的代码及说明
import os
import numpy as np
import scipy.ndimage as ndi
from skimage import measure,color
import matplotlib.pyplot as plt
def read_file(filename): # 该函数用来读取正例中心坐标
points = [];
with open(filename,'r') as f:
for line in f.readlines():
linestr = line.strip() # strip()方法用于移除字符串头尾指定的字符(默
认为空格)
x,y = linestr.split(' ')
print x,y
points.append([int(x),int(y)])
return points
def write_list(mylist,filename): # 该函数用来存储正例中心坐标
f = open(filename,'w')
for p in mylist:
f.writelines(str(p[0])+' '+str(p[1])+'\n')
f.close()
def get_adj(points,wdSize): # 该函数用来生成一个关于正例中心的二值矩阵
adj = np.zeros(wdSize)
for p in points:
adj[p[0],p[1]]=1 # 将矩阵adj位置为p[0],p[1]的值置为1
return adj
def adj_to_coordinate(adj): # 把二值矩阵的正例中心点转换成坐标数据
points = []
row,col = adj.shape # 把adj的宽和高赋给row,col
for i in range(row):
for j in range(col):
if adj[i,j]==1:
points.append([i,j]) # 把[i,j]添加到points串中
return points
def get_labels(adj): # 标记8连通区域,并按顺序编号
labels=measure.label(adj,connectivity=2) # 标记8连通区域
print('regions number:',labels.max()+1) # 显示连通区域的个数
return labels
def zeros_elements(idx,image): # 该函数用来将图像的某个区域置为0
for i in idx:
image[i[0],i[1]]=0
def remove_unroad(image,label_image,threshold): # 该函数用来去除小连通区域
rp = measure.regionprops(label_image) # 检测连通区域的属性
for i in range(len(rp)):
flag = False
if rp[i].area<threshold: # 标注面积小于阈值的区域
flag = True
if flag == True:
idx = np.argwhere(label_image == i+1) # 统计小区域的点集
zeros_elements(idx,image) # 将小区域的点集置零
def max_points(plist): # 该函数用来计算点列的横纵坐标的最大值
maxp = [0,0];
for p in plist:
if p[0]>maxp[0]:
maxp[0] = p[0]
if p[1]>maxp[1]:
maxp[1] = p[1]
return maxp
points_dir = './resultPoints/' # 初始正例中心坐标目录
save_dir = './resultPointsFinal/' # 优化后的正例中心坐标目录
pfiles = os.listdir(points_dir)
wdSize = [0,0]
for pf in pfiles:
points = read_file(points_dir+pf)
if 0 == len(points):
continue
wdSize = max_points(points)
wdSize[0] = wdSize[0] + 1
wdSize[1] = wdSize[1] + 1
adj = get_adj(points,wdSize) # 生成正例中心的二值矩阵
labels = get_labels(adj) # 标记连通区域并按顺序编号
remove_unroad(adj,labels,200) # 去除少于200个点的连通区域
new_points = adj_to_coordinate(adj) # 把剩余的正例中心点转换成坐标数据
write_list(new_points,save_dir+pf)
3.?draw_points.py的代码及说明
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image,ImageDraw
def read_file(filename): # 该函数用来按整数格式读取正例中心点的坐标数据
points = [];
with open(filename,'r') as f:
for line in f.readlines():
linestr = line.strip()
x,y = linestr.split(' ')
print x,y
points.append([int(x),int(y)])
return points
def write_labeled_img(points,image,filename): # 该函数用来将正例中心点标注为红色
for i in range(len(points)):
draw = ImageDraw.Draw(image)
x,y=points[i]
image.putpixel([x,y],(255,0,0))
image.save(filename)
points_dir = r'./ resultPointsFinal/' # 优化后的正例中心坐标目录
image_dir = r'./images/testImages/' # 测试图像目录
save_dir = r'./resultImgesFinal/' # 图像路网提取结果的保存目录
pfile = os.listdir(points_dir)
for pf in pfile:
points = read_file(points_dir+pf)
name,suffix = pf.split('.')
image = Image.open(image_dir+name+'.bmp')
write_labeled_img(points,image,save_dir+name+'.bmp')
- 点赞
- 收藏
- 关注作者
评论(0)