2.12 Speeding up a visualization script with multiprocessing

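I woke up this morning to find that the script I started last night was painfully slow: it would need about 130 hours to finish, and CPU usage was only 5%...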

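The files are all independent of one another and can be processed separately, so the work should parallelize well. Splitting it into chunks by hand was too much trouble, so I decided to learn the multiprocessing library and use it to speed up the plotting.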

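There were two main pitfalls. The first is that the function processed by Pool in multiprocessing (through map) can only receive a single argument taken from one iterable.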
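A minimal sketch of this limitation, using only the standard library. The functions square, power and run_examples are hypothetical names made up for illustration. Pool.map hands each item of one iterable to the worker as its only argument; Pool.starmap is one standard way around that when each item is a tuple of arguments.

```python
from multiprocessing import Pool


def square(x):
    # Worker for Pool.map: receives exactly one item per call.
    return x * x


def power(base, exponent):
    # Worker for Pool.starmap: each tuple in the iterable is unpacked
    # into positional arguments.
    return base ** exponent


def run_examples(num_processes=2):
    with Pool(processes=num_processes) as pool:
        squares = pool.map(square, [1, 2, 3])
        powers = pool.starmap(power, [(2, 3), (3, 2)])
    return squares, powers


if __name__ == "__main__":
    print(run_examples())
```

The workers are defined at module level so that they can be pickled and sent to the worker processes.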

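The second: at first I planned to pass the constant parameters in through global variables. After fiddling for a long time I found that the new processes do not inherit globals defined inside the if __name__ == "__main__" block (this is the behavior with the spawn start method, the default on Windows), so I gave up on that approach.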

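Later I found that functools can wrap a function: functools.partial fixes some of the function's arguments and returns a new function, which suits multiprocessing very well.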
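A minimal sketch of that idea, not the code from this post: scale, factor, offset and run are hypothetical names chosen for illustration. functools.partial freezes the constant arguments, so the resulting callable takes just the one per-item argument that Pool.map supplies.

```python
from functools import partial
from multiprocessing import Pool


def scale(value, factor, offset):
    # "value" changes per item; "factor" and "offset" stay constant.
    return value * factor + offset


def run(values, factor, offset, num_processes=2):
    # Freeze the constant arguments; the result accepts only "value".
    worker = partial(scale, factor=factor, offset=offset)
    with Pool(processes=num_processes) as pool:
        return pool.map(worker, values)


if __name__ == "__main__":
    print(run([1, 2, 3], factor=10, offset=1))
```

A partial object built from a module-level function can be pickled, which is why it can be sent to worker processes without relying on global variables.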

In the end I found code written by an NVIDIA engineer, an expert from China, and learned directly from it. Link: https://leimao.github.io/blog/Python-tqdm-Multiprocessing/

My own clumsy code is attached below...

import os
import sys
import datetime
import numpy as np
from tqdm import tqdm
import tkinter as tk
from tkinter import filedialog
from scipy.io import loadmat, savemat
import matplotlib.pyplot as plt
import seaborn as sns
from multiprocessing import Pool
from functools import partial


def GetDataPath():
    root = tk.Tk()
    root.withdraw()
    data_path = filedialog.askdirectory()
    if not data_path:
        print("Warning: path error!")
        sys.exit(0)
    else:
        return data_path


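def GetDataFileName(folder_path, file_type=".mat", data_list=None):
    # A mutable default ([]) would be shared across calls, so create the list here.
    if data_list is None:
        data_list = []
    file_dirs = os.listdir(folder_path)
    file_dirs.sort()
    for i in file_dirs:
        if os.path.splitext(i)[1] == file_type:
            data_list.append(i)
    return data_list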


def GetDataInfor():
    data_path = GetDataPath()
    data_list = GetDataFileName(data_path)
    return data_path, data_list


def GetDate():
    # Today's date formatted as YYYY_MM_DD, used as the output folder suffix.
    return datetime.datetime.today().strftime("%Y_%m_%d")


def ReadMat(data_path):
    data = loadmat(data_path)
    if "image" in data:
        image = data["image"]
        image_type = "image"
    else:
        image = data["target_image"]
        image_type = "target_image"
    return image, image_type


def DrawPicture(file_name, data_path, new_path):
    # Plot the z-slices of one .mat volume as a grid of heatmaps and save a PNG.
    sns.set_theme()
    data, image_type = ReadMat(os.path.join(data_path, file_name))
    # Min-max normalize to [0, 1] so every slice shares one color scale.
    data = (data - data.min()) / (data.max() - data.min())
    data = np.flip(data, axis=1)
    if data.shape[2] > 36:
        fig, axn = plt.subplots(6, 7, sharex=True, sharey=True)
        lim = 41
    else:
        fig, axn = plt.subplots(4, 8, sharex=True, sharey=True)
        lim = 31
    # Never index past the last slice actually present in the volume.
    lim = min(lim, data.shape[2])
    file = os.path.splitext(file_name)[0] + "_z_sectional_view.png"
    fig.tight_layout(rect=[0, 0, .9, 1])
    cbar_ax = fig.add_axes([.90, .09, .01, .86])
    for i, ax in enumerate(axn.flat):
        if i < lim:
            z_slice = data[:, :, i]  # avoid shadowing the built-in "slice"
            # Only the first heatmap draws the shared colorbar.
            sns.heatmap(z_slice, ax=ax,
                        cbar=i == 0,
                        vmin=0, vmax=1,
                        cbar_ax=None if i else cbar_ax,
                        cmap="YlGnBu")

    plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.05, hspace=0.05)
    plt.title(file[0:-4], fontsize=10, loc='right')
    file = os.path.join(new_path, file)
    plt.savefig(file, dpi=800)
    # Closing the figure releases its memory before the next file is processed.
    plt.close(fig)

    return True


def run_apply_async_multiprocessing(func, argument_list, num_processes):
    # Submit one job per argument; tuple arguments are unpacked into positional args.
    pool = Pool(processes=num_processes)
    jobs = [
        pool.apply_async(func=func, args=argument if isinstance(argument, tuple) else (argument,))
        for argument in argument_list
    ]
    pool.close()
    # job.get() blocks until that job is done, so the bar advances in submission order.
    result_list_tqdm = []
    for job in tqdm(jobs):
        result_list_tqdm.append(job.get())
    pool.join()

    return result_list_tqdm


if __name__ == "__main__":

    ################################################
    #                                              #
    #          Please set the parameters           #
    #                                              #
    ################################################

    new_folder = "pic"

    num_processes = os.cpu_count()
    date = GetDate()
    new_path = new_folder + date
    if os.path.exists(new_path):
        print(
            "warning: The pictures folder already exists, please check your data!"
        )
    else:
        os.mkdir(new_path)
    data_path, data_list = GetDataInfor()    


    partial_func = partial(DrawPicture, data_path=data_path, new_path=new_path)
    result_list = run_apply_async_multiprocessing(func=partial_func, argument_list=data_list, num_processes=num_processes)
    print("Program done! Get {} png files.".format(len(data_list)))