deeplearning-image-segmentation icon indicating copy to clipboard operation
deeplearning-image-segmentation copied to clipboard

图像预测的处理

Open jayboxyz opened this issue 6 years ago • 1 comments

1、重叠滑动窗口预测

jayboxyz avatar Dec 09 '19 02:12 jayboxyz

1、重叠滑动窗口预测2

参考:https://github.com/whjpython/Unet-Segmentation/blob/master/predict.py

import skimage.io
import numpy as np
import os
from osgeo.gdalconst import *
from osgeo import gdal
import tqdm
import time
from unet import Unet
import glob

back = [0,0,0]
stalk = [255,0,0]
twig = [0,255,0]
grain = [128,128,0]
COLOR_DICT = np.array([back,stalk, twig, grain])#上色代码只有4类

num_class = 3#网络识别类别

def predict_x(batch_x, model):
    """
    预测一个batch的数据
    """
    batch_y = model.predict(batch_x)#官方预测代码
    return batch_y#返回的是批次的预测(批次,图片长宽,种类)

def stretch(img):#%2線性拉伸
    n = img.shape[2]
    for i in range(n):
        c1 = img[:, :, i]
        c = np.percentile(c1[c1>0], 2)  # 只拉伸大于零的值
        d = np.percentile(c1[c1>0], 98)
        t = (img[:, :, i] - c) / (d - c)
        t *= 65535
        t[t < 0] = 0
        t[t > 65535] = 65535
        img[:, :, i] = t
    return img

def CreatTf(file_path_img,data,outpath):#原始文件,识别后的文件数组形式,新保存文件
    d,n = os.path.split(file_path_img)
    dataset = gdal.Open(file_path_img, GA_ReadOnly)#打开图片只读
    #data = gdal.Open(os.path.join(outpath,'gyey'+n))#打开标签图片
    #data_label = data.ReadAsArray(0, 0, data.RasterXSize, data.RasterYSize)#获取数据
    projinfo = dataset.GetProjection()#获取坐标系
    geotransform = dataset.GetGeoTransform()
    #band = dataset.RasterCount()
    format = "GTiff"
    driver = gdal.GetDriverByName(format)#数据格式
    name = n#输出文件名字
    dst_ds = driver.Create(os.path.join(outpath,name), dataset.RasterXSize, dataset.RasterYSize,
                              1, gdal.GDT_Byte )#创建一个新的文件
    dst_ds.SetGeoTransform(geotransform)#投影
    dst_ds.SetProjection(projinfo)#坐标
    dst_ds.GetRasterBand(1).WriteArray(data)
    dst_ds.FlushCache()


def make_prediction_img(x, target_size, batch_size, predict):  # 函数当做变量
    """
    滑动窗口预测图像。
    每次取target_size大小的图像预测,但只取中间的1/4,这样预测可以避免产生接缝。
    """
    # target window是正方形,target_size是边长
    quarter_target_size = target_size // 4
    half_target_size = target_size // 2

    pad_width = (
        (quarter_target_size, target_size),  # 32,128
        (quarter_target_size, target_size),  # 32,128
        (0, 0))

    # 只在前两维pad
    pad_x = np.pad(x, pad_width, 'constant', constant_values=0)  # 填充(x.shape[0]+160,x.shape[1]+160)
    pad_y = np.zeros(
        (pad_x.shape[0], pad_x.shape[1],num_class ),
        dtype=np.float32)  # 32位浮点型

    def update_prediction_center(one_batch):
        """根据预测结果更新原图中的一个小窗口,只取预测结果正中间的1/4的区域"""
        wins = []  # 窗口
        for row_begin, row_end, col_begin, col_end in one_batch:
            win = pad_x[row_begin:row_end, col_begin:col_end, :]  # 每次裁剪数组这里引入数据
            win = np.expand_dims(win, 0)  # 喂入数据的维度确定了喂入的数据要求是(n, 256,256,3)
            wins.append(win)
        x_window = np.concatenate(wins, 0)  # 一个批次的数据
        y_window = predict(x_window)  # 预测一个窗格,返回结果需要一个一个批次的取出来
        for k in range(len(wins)):  # 获取窗口编号
            row_begin, row_end, col_begin, col_end = one_batch[k]  # 取出来一个索引
            pred = y_window[k, ...]  # 裁剪出来一个数组,取出来一个批次数据
            y_window_center = pred[
                              quarter_target_size:target_size - quarter_target_size,
                              quarter_target_size:target_size - quarter_target_size,
                              :]  # 只取预测结果中间区域减去边界32[32:96,32:96]

            pad_y[
            row_begin + quarter_target_size:row_end - quarter_target_size,
            col_begin + quarter_target_size:col_end - quarter_target_size,
            :] = y_window_center  # 把预测的结果放到建立的空矩阵中[32:96,32:96]

    # 每次移动半个窗格
    batchs = []
    batch = []
    for row_begin in range(0, pad_x.shape[0], half_target_size):  # 行中每次移动半个[0,x+160,64]
        for col_begin in range(0, pad_x.shape[1], half_target_size):  # 列中每次移动半个[0,x+160,64]
            row_end = row_begin + target_size  # 0+128
            col_end = col_begin + target_size  # 0+128
            if row_end <= pad_x.shape[0] and col_end <= pad_x.shape[1]:  # 范围不能超出图像的shape
                batch.append((row_begin, row_end, col_begin, col_end))  # 取出来一部分列表[0,128,0,128]
                if len(batch) == batch_size:  # 够一个批次的数据
                    batchs.append(batch)
                    batch = []
    if len(batch) > 0:
        batchs.append(batch)
        batch = []
    for bat in tqdm.tqdm(batchs, desc='Batch pred'):  # 添加一个批次的数据
        update_prediction_center(bat)  # bat只是一个裁剪边界坐标
    y = pad_y[quarter_target_size:quarter_target_size + x.shape[0],
        quarter_target_size:quarter_target_size + x.shape[1],
        :]  # 收缩切割为原来的尺寸
    return y  # 原图像的预测结果

def main_p(model,allpath,sign='tif',changes=True):#读取图片函数
    print('执行预测...')
    img_p = glob.glob(os.path.join(allpath, "*.%s"%sign))
    for one_path in img_p:
        pic = skimage.io.imread(one_path)
        if changes:
            pic = stretch(pic)
        pic = pic.astype(np.float32)
        #pic = img_to_array(pic)
        y_probs = make_prediction_img(
            pic, 256, 8,
            lambda xx: predict_x(xx, model))  # 数据,目标大小,批次大小,返回每次识别的
        y_preds = np.argmax(y_probs, axis=2)
        d, n = os.path.split(one_path)
        t0 = time.time()
        change = y_preds.astype(np.uint8)
        outpath = os.path.join(d, 'result')
        if not os.path.exists(outpath):
            os.makedirs(outpath)
        CreatTf(one_path, change,outpath)  # 添加坐标系
        img_out = np.zeros(change.shape + (3,))
        for i in range(num_class):
            img_out[change == i, :] = COLOR_DICT[i]#对应上色
        change = img_out / 255
        save_file=os.path.join(outpath,n[:-4]+'_color'+'.png')
        skimage.io.imsave(save_file, change)
        print('预测耗费时间: %0.2f(min).' % ((time.time() - t0) / 60))
if __name__ == '__main__':
    model = Unet((256,256,3),num_class)
    p = r'E:\buildingone\output\YMDD.h5'  # 说明权重所在位置
    print("网络参数来自: '%s'." % p)
    model.load_weights(p)
    path = r'E:\mynet\end'
    main_p(model,path,changes=False)

jayboxyz avatar Dec 09 '19 02:12 jayboxyz