Page MenuHomec4science

batches.py
No OneTemporary

File Metadata

Created
Wed, May 15, 09:50

batches.py

# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Email: theo.lemaire@epfl.ch
# @Date: 2017-08-22 14:33:04
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2019-06-26 13:51:06
''' Utility functions used in simulations '''
import logging
import multiprocess as mp
import numpy as np
from ..utils import logger
class Consumer(mp.Process):
    ''' Generic consumer process, taking tasks from a queue and outputting results in
        another queue.

        Each task taken from the input queue is called with no arguments and its return
        value is put on the output queue. A ``None`` task is the sentinel that makes the
        consumer exit. If a task raises an ``Exception``, the exception object itself is
        put on the output queue in place of a result, so readers of that queue should be
        prepared to receive exception instances.
    '''

    def __init__(self, queue_in, queue_out):
        ''' Consumer constructor.

            :param queue_in: joinable queue of tasks (zero-argument callables, or None to stop)
            :param queue_out: queue receiving the outcome of each task
        '''
        mp.Process.__init__(self)
        self.queue_in = queue_in
        self.queue_out = queue_out
        logger.debug('Starting %s', self.name)

    def run(self):
        ''' Process tasks from the input queue until the ``None`` sentinel is received. '''
        while True:
            nextTask = self.queue_in.get()
            if nextTask is None:
                logger.debug('Exiting %s', self.name)
                self.queue_in.task_done()
                break
            try:
                answer = nextTask()
            except Exception as exc:
                # Do not let a failing task kill the process: a dead consumer never
                # calls task_done(), which would block queue_in.join() forever.
                logger.exception('Task failed in %s', self.name)
                answer = exc
            self.queue_in.task_done()
            self.queue_out.put(answer)
        return
class Worker:
    ''' Callable wrapper binding a function to a fixed argument list and an identifier.

        Calling the instance first sets the module logger to the stored level, then
        returns a ``(worker ID, function output)`` tuple.
    '''

    def __init__(self, wid, func, params, loglevel):
        ''' Worker constructor.

            :param wid: worker ID, returned alongside the function output
            :param func: function object
            :param params: sequence of positional arguments passed to func
            :param loglevel: logging level applied to the logger before the call
        '''
        self.id, self.func, self.params, self.loglevel = wid, func, params, loglevel

    def __call__(self):
        ''' Invoke the wrapped function with its stored parameters.

            :return: tuple of (worker ID, function output)
        '''
        logger.setLevel(self.loglevel)
        output = self.func(*self.params)
        return (self.id, output)
class Batch:
    ''' Generic interface to run batches of function calls, either serially in the
        current process or distributed over a pool of consumer processes.
    '''

    def __init__(self, func, queue):
        ''' Batch constructor.

            :param func: function object
            :param queue: list of list of function parameters
        '''
        self.func = func
        self.queue = queue

    def __call__(self, *args, **kwargs):
        ''' Call the internal run method. '''
        return self.run(*args, **kwargs)

    def getNConsumers(self):
        ''' Determine number of consumers based on queue length and number of available CPUs. '''
        return min(mp.cpu_count(), len(self.queue))

    def start(self):
        ''' Create tasks and results queues, and start consumers. '''
        mp.freeze_support()
        self.tasks = mp.JoinableQueue()
        self.results = mp.Queue()
        self.consumers = [Consumer(self.tasks, self.results) for _ in range(self.getNConsumers())]
        for c in self.consumers:
            c.start()

    def assign(self, loglevel):
        ''' Wrap each parameter set in a Worker (whose ID is its queue index) and enqueue it.

            :param loglevel: logging level passed on to each worker
        '''
        for i, params in enumerate(self.queue):
            self.tasks.put(Worker(i, self.func, params, loglevel), block=False)

    def join(self):
        ''' Send one None sentinel per consumer and wait until every task has been processed. '''
        for _ in self.consumers:
            self.tasks.put(None, block=False)
        self.tasks.join()

    def get(self):
        ''' Collect one result per queued call and return them in queue order.

            Results arrive in completion order as (worker ID, output) tuples; since a
            worker ID is the index of its parameter set in the queue, each output is
            placed directly at that index.

            :return: list of outputs, ordered like self.queue
        '''
        outputs = [None] * len(self.queue)
        for _ in range(len(self.queue)):
            wid, out = self.results.get()
            outputs[wid] = out
        return outputs

    def stop(self):
        ''' Close tasks and results queues, and wait for consumer processes to exit.

            Intended to be called after join() and get(): at that point every consumer
            has received its sentinel and all results have been drained, so joining
            the processes returns promptly and reaps them instead of leaving zombies.
        '''
        self.tasks.close()
        self.results.close()
        for c in self.consumers:
            c.join()

    def run(self, mpi=False, loglevel=logging.INFO):
        ''' Run batch with or without multiprocessing.

            :param mpi: if True, distribute calls over consumer processes (Python
                multiprocessing, not MPI); otherwise call func serially in this process
            :param loglevel: logging level set in each worker before its call
                (only used when mpi is True)
            :return: list of function outputs, in the same order as self.queue
        '''
        if not mpi:
            return [self.func(*params) for params in self.queue]
        self.start()
        self.assign(loglevel)
        self.join()
        outputs = self.get()
        self.stop()
        return outputs
def createQueue(*dims):
    ''' Create a serialized 2D array of all parameter combinations for a series of individual
        parameter sweeps.

        Combinations are ordered as in ``itertools.product(*dims)``: the first sweep varies
        slowest and the last sweep varies fastest. All values are cast to a common numpy
        dtype before being converted back to Python scalars (e.g. ints become floats when
        mixed with floats).

        :param dims: list of lists (or 1D arrays) of input parameters
        :return: list of parameters (list) for each simulation
        :raises ValueError: if no parameter sweep is given
    '''
    ndims = len(dims)
    if ndims == 0:
        raise ValueError('at least one parameter sweep is required')
    # 'ij' indexing keeps each sweep on its own axis in the given order, so the C-order
    # reshape makes dims[0] the slowest-varying column without any axis swapping, and
    # a single sweep works as well.
    grids = np.meshgrid(*dims, indexing='ij')
    return np.stack(grids, -1).reshape(-1, ndims).tolist()

Event Timeline