Published 2023-08-02
Implementing a DataLoader with a Hand-Written Process Pool (Parent-Child Process Communication)

multiprocessing.Queue is a multiprocess-safe queue that can be used to pass data between processes; under the hood it is built on a pipe, synchronization semaphores, and mutex locks. The multiprocessing library also provides the Pool process pool, which implements inter-process communication directly; Pool covers many use cases of its own, which are not introduced here. But Pool does not give you communication between the parent process and its children. To get parent-child communication, you need to write a process pool yourself with Queue, creating a child queue and a parent queue. This pattern is much more widely useful: for example, you may want to start several child processes to help process some data or files and then collect all the results back in the main process. That is exactly what writing your own DataLoader amounts to, and many real projects need custom data-preprocessing code, so you end up rewriting the DataLoader yourself. This post first shows how to implement parent-child communication; after that, the DataLoader follows easily.

1. Using Queue to build a process pool and implement parent-child communication

The idea is simple. First create a child queue and a parent queue. The main process puts files into the child queue; several child processes each take files from the child queue and process them, then put the processed data into the parent queue; once all files have been handed out, the main process collects the processed data from the parent queue.

The detail that needs care is using None as the exit signal. Queue.get() blocks by default (block=True), so on an empty queue a worker would wait forever instead of exiting; get_nowait(), which is get(block=False), raises an exception (queue.Empty) when the queue is empty; and treating "queue is empty" as the exit condition is also unreasonable, because the queue can be momentarily empty while the parent is still feeding it. So we use None as the exit signal, sent by the main process after all files have been put into the child queue. The concrete data-processing function is omitted here. The code is easy to follow, so here it is directly.

```python
import os
from multiprocessing import Process, Queue


def data_preprocess(file, args):
    # preprocess one file; details omitted
    pass


def preprocess_module(queue_in, queue_res, args):
    while True:
        file = queue_in.get()
        # exit on receiving the None terminating signal
        if file is None:
            break
        instance = data_preprocess(file, args)
        queue_res.put(instance)
    # put None into queue_res as this worker's terminating signal
    queue_res.put(None)


if __name__ == '__main__':
    nthread = 2
    args = None  # placeholder for preprocessing options
    data_path = ' '
    files = os.listdir(data_path)
    queue_in = Queue(nthread)  # child queue; its max length equals nthread
    queue_res = Queue()        # parent queue
    # create the child processes
    processes = [Process(target=preprocess_module,
                         args=(queue_in, queue_res, args))
                 for _ in range(nthread)]
```
```python
    # start the child processes
    for each in processes:
        # daemon children are terminated automatically when the
        # parent process exits normally
        each.daemon = True
        each.start()

    # feed the files to the child queue
    for file in files:
        queue_in.put(file)
```
```python
    # put one None per worker into queue_in as the terminating signal
    for i in range(nthread):
        queue_in.put(None)

    # the parent process fetches the results
    res_list = []
    cnt_None = 0
    while True:
        if cnt_None == nthread:
            break
        t = queue_res.get()
        if t is None:
            cnt_None += 1
        else:
            res_list.append(t)

    # join the child processes
    try:
        for each in processes:
            each.join()
    except Exception as e:
        print(str(e))
```
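To see the whole sentinel scheme working end to end, here is a minimal runnable sketch of the same pattern with a trivial squaring worker; square_worker, run_pool, and nworkers are illustrative names of my own, not from the original code:

```python
from multiprocessing import Process, Queue


def square_worker(queue_in, queue_res):
    # same pattern as above: drain the child queue until the None sentinel
    while True:
        item = queue_in.get()
        if item is None:
            break
        queue_res.put(item * item)
    queue_res.put(None)  # tell the parent this worker is done


def run_pool(items, nworkers=2):
    queue_in = Queue()   # child queue
    queue_res = Queue()  # parent queue
    workers = [Process(target=square_worker, args=(queue_in, queue_res))
               for _ in range(nworkers)]
    for w in workers:
        w.daemon = True
        w.start()
    for item in items:
        queue_in.put(item)
    for _ in range(nworkers):
        queue_in.put(None)  # one sentinel per worker
    # collect results until every worker has echoed its own None
    results, done = [], 0
    while done < nworkers:
        t = queue_res.get()
        if t is None:
            done += 1
        else:
            results.append(t)
    for w in workers:
        w.join()
    return results


if __name__ == '__main__':
    print(sorted(run_pool([1, 2, 3, 4])))  # [1, 4, 9, 16]
```

Each worker consumes exactly one None and answers with its own None, so counting sentinels in the parent is enough to know when every worker has drained its share of the input.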
2. Writing Your Own DataLoader

To rewrite torch's DataLoader, you only need to run the multiprocessing flow above inside __init__ and return the processed results from __getitem__:

```python
# preprocess_module is the worker function from section 1
class Dataset(torch.utils.data.Dataset):
    def __init__(self, args, batch_size):
        data_dir = args.data_dir
        nthread = args.thread_num
        self.instance_list = []
        files = os.listdir(data_dir)
        queue_in = Queue(nthread)  # child queue; its max length equals nthread
        queue_res = Queue()        # parent queue
        # create the child processes
        processes = [Process(target=preprocess_module,
                             args=(queue_in, queue_res, args))
                     for _ in range(nthread)]
        # start the child processes
        for each in processes:
            # daemon children are terminated when the parent exits normally
            each.daemon = True
            each.start()
        # feed the files to the child queue
        for file in files:
            queue_in.put(file)
        # put one None per worker into queue_in as the terminating signal
        for i in range(nthread):
            queue_in.put(None)
```
```python
        # the parent process fetches the results
        cnt_None = 0
        while True:
            if cnt_None == nthread:
                break
            t = queue_res.get()
            if t is None:
                cnt_None += 1
            else:
                self.instance_list.append(t)
        # join the child processes
        try:
            for each in processes:
                each.join()
        except Exception as e:
            print(str(e))

    def __len__(self):
        return len(self.instance_list)

    def __getitem__(self, idx):
        instance = self.instance_list[idx]
        return instance
```

One thing deserves special attention: the target function that Process executes, i.e. preprocess_module in the code above, must not be defined inside the scope where Process is created. Either put it at the top level of the file, outside if __name__ == '__main__':, or write it as a method of the class; otherwise you will get AttributeError: Can't pickle local object 'get_dataset.…
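That pickling constraint can be reproduced with the pickle module alone, without multiprocessing; top_level_worker and make_local_worker below are hypothetical names used only for illustration:

```python
import pickle


def top_level_worker():
    pass


def make_local_worker():
    # defined inside another function, so it is a "local object"
    def local_worker():
        pass
    return local_worker


# a top-level function pickles fine, so Process can ship it to a child
assert isinstance(pickle.dumps(top_level_worker), bytes)

# a locally defined function cannot be pickled
try:
    pickle.dumps(make_local_worker())
except (AttributeError, pickle.PicklingError) as e:
    # e.g. "Can't pickle local object 'make_local_worker.<locals>.local_worker'"
    print(e)
```

multiprocessing serializes the target function with pickle when it spawns a child, and pickle stores functions by qualified name; a function buried inside another scope has no importable name, which is exactly what the AttributeError complains about.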
Publisher: admin. Please credit the source when reposting: http://www.yc00.com/news/1690957713a472831.html