Page MenuHomec4science

mpi_routines.py
No OneTemporary

File Metadata

Created
Sat, May 10, 06:44

mpi_routines.py

from mpi4py import MPI
import horovod.keras as hvd
import data_loaders
import json
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorflow as tf
import itertools
from keras import backend as K
def get_cuda_devices(rank):
    """Return the number of CUDA devices visible on this node.

    Parameters
    ----------
    rank : int
        MPI rank used to select which device's context is briefly created.
        NOTE(review): if rank >= device count this raises inside pycuda —
        callers appear to pass inter_comm.rank; confirm it stays in range.

    Returns
    -------
    int
        Total number of CUDA devices (``Device.count()`` is a static
        count of all devices, not a property of the chosen device).
    """
    # cuda.init() is redundant when pycuda.autoinit is imported at module
    # level, but it is harmless and kept for safety if that import changes.
    cuda.init()
    ctx = cuda.Device(rank).make_context()
    try:
        ngpus = ctx.get_device().count()
    finally:
        # Fix: always unwind the context, even if the query raises, so the
        # CUDA context stack is not left pushed/leaked on error.
        ctx.pop()
        ctx.detach()
    return ngpus
def distribute_parameters(inter_comm, intra_comm, nreplica, restart_file=None):
    """Build the hyper-parameter grid and distribute it across MPI replicas.

    Rank 0 of the intra-communicator builds the list of hyper-parameter
    combinations (either the full Cartesian product of the JSON config, or
    the combinations stored in ``restart_file``), optionally splits it into
    ``nreplica`` chunks scattered over ``inter_comm``, then broadcasts the
    result inside each replica over ``intra_comm``.

    Parameters
    ----------
    inter_comm, intra_comm : mpi4py communicators
        Inter-replica and intra-replica communicators.
    nreplica : int
        Number of replicas the grid is partitioned into.
    restart_file : str, optional
        Path of a JSON file with previously saved combinations to resume.

    Returns
    -------
    dict
        ``{'hyperparam': [...], 'dir_path': ..., 'hyper_index': ...}``.
    """
    json_data = data_loaders.read_json()
    hyperparam_keys = ['epochs', 'dropout', 'learning_rate', 'loss', 'optimizer']
    print('AAAAAA', intra_comm.rank, inter_comm.rank, MPI.COMM_WORLD.rank)
    # Fix: hyper_index must always be bound.  The original left it undefined
    # on intra-rank 0 when nreplica == 1, raising NameError at the bcast
    # calls below.
    hyper_index = None
    if intra_comm.rank == 0:
        if restart_file is not None:
            # Resume: reload combinations previously dumped as a list of dicts.
            with open(restart_file) as f:
                hyperparam_dict = json.load(f)
            hyperparam_values = [list(h.values()) for h in hyperparam_dict]
        else:
            # Fresh run: full Cartesian product of the configured value lists.
            hyperparam_values = list(itertools.product(*[json_data[d]
                                                         for d in hyperparam_keys]))
        if nreplica > 1:
            # One chunk per replica; hyper_index is the cumulative chunk ends,
            # i.e. global offsets of each replica's slice.
            hyperparam_values = list(np.array_split(hyperparam_values, nreplica))
            hyper_index = np.cumsum(np.array([len(chunk)
                                              for chunk in hyperparam_values]))
    else:
        hyperparam_values = None
    if nreplica > 1:
        # Each replica leader receives its own chunk.
        hyperparam_values = inter_comm.scatter(hyperparam_values, root=0)
    # Fan the (chunked) values and the index out inside each replica.
    hyperparam_values = intra_comm.bcast(hyperparam_values, root=0)
    hyper_index = inter_comm.bcast(hyper_index, root=0)
    hyper_index = intra_comm.bcast(hyper_index, root=0)
    hyperparam = [dict(zip(hyperparam_keys, h)) for h in hyperparam_values]
    param_dict = {'hyperparam': hyperparam,
                  'dir_path': json_data['directory_path'],
                  'hyper_index': hyper_index}
    return param_dict
def set_horovod(inter_comm, intra_comm):
    """Initialise Horovod over the intra-communicator and pin one GPU per rank.

    Each process is assigned a single visible GPU, chosen round-robin by its
    COMM_WORLD rank modulo the number of GPUs on the node, and the resulting
    TensorFlow session is installed as the Keras backend session.
    """
    ngpus = get_cuda_devices(inter_comm.rank)
    hvd.init(intra_comm)
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    # Round-robin device assignment: world rank modulo GPU count.
    gpu_id = MPI.COMM_WORLD.rank % ngpus
    session_config.gpu_options.visible_device_list = str(gpu_id)
    print('CCCCCCC', session_config.gpu_options.visible_device_list)
    K.set_session(tf.Session(config=session_config))
def set_intra_comm(nreplica):
    """Split COMM_WORLD into ``nreplica`` contiguous intra-replica communicators.

    Ranks are partitioned into ``nreplica`` contiguous chunks; the index of
    the chunk containing this rank becomes the split color, so each replica
    gets its own communicator.
    """
    world_rank = MPI.COMM_WORLD.rank
    world_size = MPI.COMM_WORLD.size
    chunks = np.array_split(np.arange(world_size), nreplica)
    # Every rank belongs to exactly one chunk, so next() always succeeds.
    color = next(i for i, chunk in enumerate(chunks) if world_rank in chunk)
    return MPI.COMM_WORLD.Split(color, world_rank)
def set_inter_comm(size):
    """Split COMM_WORLD into leader (color 1) and follower (color 0) groups.

    Ranks whose integer quotient by ``size`` is zero — i.e. the first
    ``size`` ranks of COMM_WORLD — receive color 1; all remaining ranks
    receive color 0.  Returns the communicator of this rank's group.
    """
    comm = MPI.COMM_WORLD
    world_rank = comm.rank
    # divmod(rank, size)[0] == 0  <=>  rank // size == 0.
    leader_ranks = np.where(np.arange(comm.size) // size == 0)[0]
    color = 1 if world_rank in leader_ranks else 0
    return MPI.COMM_WORLD.Split(color, world_rank)

Event Timeline