seaborn 是基于 matplotlib 的高级绘图库，绘制 heatmap（热力图）非常方便。
首先验证了子图和共用色标（colorbar）功能。
然而，最初绘制的 z 切片图不太美观：色标范围不固定，子图间距过大，且图中目标的朝向与俯视图不符。
最终的改进是：先对图像进行归一化，再进行垂直镜像，调整子图间距，并添加文件名作为标题。理想的效果图如下。
让脚本运行一晚，明天就能收获 11 万张图片啦。代码如下。
import os
import sys
import datetime
import numpy as np
from tqdm import tqdm
import tkinter as tk
from tkinter import filedialog
from scipy.io import loadmat, savemat
import matplotlib.pyplot as plt
import seaborn as sns
def GetDataPath():
    """Ask the user to pick a data folder with a Tk directory dialog.

    Returns:
        The directory path returned by ``filedialog.askdirectory``.

    If the dialog returns an empty value (e.g. it was cancelled), a warning
    is printed and the process exits via ``sys.exit(0)``.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog is shown
    try:
        data_path = filedialog.askdirectory()
    finally:
        # Release the hidden Tk root. Otherwise it stays alive for the rest
        # of the (long-running) script.
        root.destroy()
    if not data_path:
        print("Warning: path error!")
        sys.exit(0)
    return data_path
def GetDataFileName(folder_path, file_type=".mat", data_list=None):
    """Collect the names of files in *folder_path* with extension *file_type*.

    Args:
        folder_path: directory to scan (non-recursive).
        file_type: extension to match, including the leading dot. It is
            compared case-sensitively with ``os.path.splitext(name)[1]``.
        data_list: optional list to append the matches to. When omitted, a
            new list is created on every call.

    Returns:
        The list of matching entry names (not full paths), sorted by name and
        appended after any entries already in *data_list*.
    """
    # The former ``data_list=[]`` default was a single shared list object, so
    # every call without an explicit list kept appending to earlier results.
    if data_list is None:
        data_list = []
    data_list.extend(
        name
        for name in sorted(os.listdir(folder_path))
        if os.path.splitext(name)[1] == file_type
    )
    return data_list
def GetDataInfor():
    """Prompt for a data folder and list the ``.mat`` files inside it.

    Returns:
        ``(data_path, data_list)``: the chosen directory and the names of
        the ``.mat`` files found in it by ``GetDataFileName``.
    """
    folder = GetDataPath()
    return folder, GetDataFileName(folder)
def GetDate():
    """Return today's local date as a ``YYYY_MM_DD`` string."""
    today = datetime.datetime.today().date()
    # isoformat() yields YYYY-MM-DD; swap the dashes for underscores.
    return today.isoformat().replace("-", "_")
def ReadMat(data_path):
    """Load a MATLAB ``.mat`` file and return its image array.

    The variable ``image`` is preferred. If it is absent, ``target_image``
    is used.

    Args:
        data_path: path of the ``.mat`` file, passed to ``scipy.io.loadmat``.

    Returns:
        ``(image, image_type)``, where ``image`` is the stored array and
        ``image_type`` is the name of the variable it was read from.

    Raises:
        KeyError: if the file contains neither ``image`` nor ``target_image``.
    """
    data = loadmat(data_path)
    for image_type in ("image", "target_image"):
        if image_type in data:
            return data[image_type], image_type
    # Report the file and the variables it does hold, not just a bare key name.
    available = sorted(k for k in data if not k.startswith("__"))
    raise KeyError(
        "{}: expected a variable named 'image' or 'target_image', found {}".format(
            data_path, available
        )
    )
def DrawPicture(
    data,
    path,
    file_name,
    *,
    dpi=800,
    cmap="YlGnBu",
):
    """Save a grid of heatmaps, one per slice ``data[:, :, i]``, as a PNG.

    Every slice is drawn with ``sns.heatmap`` on the fixed colour range
    ``vmin=0, vmax=1``, so values are expected to be scaled to [0, 1]
    beforehand. All panels share a single colour bar on the right.

    Args:
        data: array indexed as ``data[:, :, i]``; the third axis is the
            slice axis.
        path: output directory.
        file_name: source file name. Its extension is replaced by
            ``_z_sectional_view.png`` to form the output name, which is also
            used as the figure title.
        dpi: resolution passed to ``savefig`` (default 800, as before).
        cmap: colour map passed to ``sns.heatmap`` (default "YlGnBu").

    Returns:
        True once the image has been written (the caller checks this value).
    """
    depth = data.shape[2]
    # Choose the smallest grid that holds every slice. The old rule
    # (5x6 drawing 31 panels when depth <= 36, else 6x7 drawing 41) dropped
    # slices for depths 31-36 and 42, and raised IndexError for depths < 30
    # and 37-40. Depths 30 and 41 keep their previous layouts.
    if depth <= 30:
        nrows, ncols = 5, 6
    elif depth <= 42:
        nrows, ncols = 6, 7
    else:
        ncols = 7
        nrows = -(-depth // ncols)  # ceiling division
    fig, axn = plt.subplots(nrows, ncols, sharex=True, sharey=True)
    out_name = os.path.splitext(file_name)[0] + "_z_sectional_view.png"
    # One shared colour-bar axes in the strip to the right of the grid.
    cbar_ax = fig.add_axes([.90, .08, .01, .88])
    try:
        # zip stops at the last slice; any extra grid cells stay empty.
        for i, ax in zip(range(depth), axn.flat):
            sns.heatmap(
                data[:, :, i],
                ax=ax,
                vmin=0,
                vmax=1,
                cmap=cmap,
                # Only the first panel draws the shared colour bar.
                cbar=i == 0,
                cbar_ax=cbar_ax if i == 0 else None,
            )
        fig.tight_layout(rect=[0, 0, .9, 1])
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        # The previous plt.title() targeted the current axes, which was the
        # most recently added cbar_ax; title that axes explicitly instead.
        cbar_ax.set_title(out_name, fontsize=10, loc="right")
        fig.savefig(os.path.join(path, out_name), dpi=dpi)
    finally:
        # Always release the figure. pyplot keeps every open figure alive,
        # which would leak memory over a run of many files.
        plt.close(fig)
    return True
if __name__ == "__main__":
    ################################################
    #                                              #
    #          Please set the parameters           #
    #                                              #
    ################################################
    # Output PNGs go to the folder "<new_folder><YYYY_MM_DD>" in the CWD.
    new_folder = "pic"
    date = GetDate()
    new_path = new_folder + date
    sns.set_theme()
    if os.path.exists(new_path):
        print(
            "warning: The pictures folder already exists, please remember to check your data!"
        )
    else:
        os.mkdir(new_path)
    data_path, data_list = GetDataInfor()
    with tqdm(
        total=len(data_list),
        desc="Processing",
        leave=True,
        ncols=100,
        unit="file",  # progress counts files, not bytes
        unit_scale=False,
    ) as pbar:
        for file_name in data_list:
            data, _ = ReadMat(os.path.join(data_path, file_name))
            # Min-max normalise to [0, 1] to match DrawPicture's fixed
            # colour range (vmin=0, vmax=1).
            lo = data.min()
            span = data.max() - lo
            if span == 0:
                # Constant volume: the plain formula would divide by zero and
                # yield an all-NaN (blank) figure; render it as all zeros.
                data = np.zeros(data.shape)
            else:
                data = (data - lo) / span
            # Flip along axis 1 so the rendered slices have the intended
            # top-down orientation.
            data = np.flip(data, axis=1)
            if not DrawPicture(data, new_path, file_name):
                print("Draw error!")
            pbar.update(1)
    print("Program done! Get {} png files.".format(len(data_list)))