diff --git a/BlackDynamite/base.py b/BlackDynamite/base.py index f380953..8777806 100755 --- a/BlackDynamite/base.py +++ b/BlackDynamite/base.py @@ -1,328 +1,333 @@ #!/usr/bin/env python3 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . ################################################################ from abc import ABC, abstractmethod ################################################################ from . import job from . import bdlogging from . import jobselector ################################################################ import os import re import sys __all__ = ["Base"] print = bdlogging.invalidPrint logger = bdlogging.getLogger(__name__) ################################################################ class AbstractBase(ABC): """ """ def getRunFromID(self, run_id): myrun = run.Run(self) myrun["id"] = run_id myrun.id = run_id run_list = myrun.getMatchedObjectList() if len(run_list) != 1: raise Exception('Unknown run {0}'.format(run_id)) return run_list[0] def getJobFromID(self, job_id): myjob = job.Job(self) myjob["id"] = job_id myjob.id = job_id job_list = myjob.getMatchedObjectList() if len(job_list) != 1: raise Exception('Unknown run {0}'.format(job_id)) return job_list[0] def getObject(self, sqlobject): curs = self.connection.cursor() curs.execute("SELECT * FROM {0}.{1} WHERE id = {2}".format( self.schema, sqlobject.table_name, sqlobject.id)) col_info = self.getColumnProperties(sqlobject) line = curs.fetchone() for i in range(0, len(col_info)): col_name = col_info[i][0] sqlobject[col_name] = line[i] def createGenericTables(self,): sql_script_name = os.path.join(os.path.dirname(__file__), "build_tables.sql") curs = self.connection.cursor() # create generic tables query_list = list() with open(sql_script_name, "r") as fh: for line in fh: query_list.append(re.sub("SCHEMAS_IDENTIFIER", self.schema, line)) curs.execute("\n".join(query_list)) def createParameterSpace(self, myjob, entry_nb=0, tmp_job=None, nb_inserted=0): """ This function is a recursive call to generate the points in the parametric space The entries of the jobs are treated one by one in a recursive manner """ # keys() gives a non-indexable view keys = list(myjob.entries.keys()) nparam = len(keys) # if this is the case I have done all the # entries of the job # it is time to insert it (after some checks) if entry_nb == nparam: if tmp_job is None: raise RuntimeError("internal error") # check if already inserted jselect = jobselector.JobSelector(self) jobs = jselect.selectJobs(tmp_job, quiet=True) if len(jobs) > 0: return nb_inserted # insert it nb_inserted += 1 logger.info("insert job #{0}".format(nb_inserted) + ': ' + str(tmp_job.entries)) self.insert(tmp_job) return nb_inserted if tmp_job is None: tmp_job = self.Job(self) # the key that I am currently treating key = keys[entry_nb] e = myjob[key] # if this is a list I have to create several parametric points if not isinstance(e, list): e = [e] for value in e: tmp_job[key.lower()] = value nb_inserted = self.createParameterSpace( myjob, entry_nb+1, tmp_job, nb_inserted) if self.truerun: self.commit() return nb_inserted def pushQuantity(self, name, type_code, description=None): """ implemented type_codes: "int" "float" "int.vector" "float.vector" """ if ((type_code == "int") or (type_code == int)): is_integer = True is_vector = False elif (type_code == "int.vector"): is_integer = True is_vector = True elif ((type_code == "float") or (type_code == float)): is_integer = False is_vector = False elif (type_code == "float.vector"): is_integer = False is_vector = True else: raise Exception( "invalid type '{0}' for a quantity".format(type_code)) curs = self.connection.cursor() curs.execute(""" INSERT INTO {0}.quantities (name, is_integer, is_vector, description) VALUES (%s , %s , %s, %s) RETURNING id """.format(self.schema), (name, is_integer, is_vector, description)) item = curs.fetchone() if (item is None): raise Exception("Counld not create quantity \"" + name + "\"") return item[0] def getUserList(self): curs = self.connection.cursor() curs.execute(""" select tableowner from pg_tables where tablename = 'runs'; """) users = [desc[0] for desc in curs] users = list(set(users)) return users @abstractmethod def getStudySize(self, study): raise RuntimeError("abstract method") def grantAccess(self, study, user): curs = self.connection.cursor() curs.execute(""" grant SELECT on ALL tables in schema {0} to {1}; grant USAGE on SCHEMA {0} to {1}; """.format(study, user)) self.commit() def revokeAccess(self, study, user): curs = self.connection.cursor() curs.execute(""" revoke SELECT on ALL tables in schema {0} from {1}; revoke USAGE on SCHEMA {0} from {1}; """.format(study, user)) self.commit() def getStudyOwner(self, schema): curs = self.connection.cursor() curs.execute(""" select grantor from information_schema.table_privileges where (table_name,table_schema,privilege_type) = ('runs','{0}','SELECT'); """.format(schema)) owners = [desc[0] for desc in curs] return owners[0] def getGrantedUsers(self, schema): curs = self.connection.cursor() curs.execute(""" select grantee from information_schema.table_privileges where (table_name,table_schema,privilege_type) = ('runs','{0}','SELECT'); """.format(schema)) granted_users = [desc[0] for desc in curs] return granted_users def getSchemaList(self, filter_names=True): curs = self.connection.cursor() curs.execute(""" SELECT distinct(table_schema) from information_schema.tables where table_name='runs' """) schemas = [desc[0] for desc in curs] filtered_schemas = [] if filter_names is True: for s in schemas: m = re.match('{0}_(.+)'.format(self.user), s) if m: s = m.group(1) filtered_schemas.append(s) else: filtered_schemas = schemas return filtered_schemas def checkStudy(self, dico): if "study" not in dico: + schemas = self.getSchemaList() + if len(schemas) == 1: + dico['study'] = schemas[0] + return message = "\n" + "*"*30 + "\n" message += "Parameter 'study' must be provided at command line\n" message += "possibilities are:\n" - schemas = self.getSchemaList() for s in schemas: message += "\t" + s + "\n" message += "\n" message += "FATAL => ABORT\n" message += "*"*30 + "\n" logger.error(message) sys.exit(-1) def __init__(self, connection=None, truerun=False, creation=False, **kwargs): self.connection = connection if self.connection is None: raise RuntimeError("This class must be derived") if 'user' in kwargs: self.user = kwargs["user"] else: self.user = os.getlogin() if ("should_not_check_study" not in kwargs): self.checkStudy(kwargs) if 'study' in kwargs: # Need this because getSchemaList strips prefix match = re.match('(.+)_(.+)', kwargs["study"]) if match: self.schema = kwargs["study"] study_name = match.group(2) else: self.schema = kwargs["user"] + '_' + kwargs["study"] study_name = kwargs["study"] # logger.error(self.schema) if ((creation is not True) and (study_name not in self.getSchemaList())): raise RuntimeError( f"Study name '{study_name}' invalid") self.truerun = truerun if("list_parameters" in kwargs and kwargs["list_parameters"] is True): message = self.getPossibleParameters() logger.info("\n{0}".format(message)) sys.exit(0) def getPossibleParameters(self): myjob = self.Job(self) message = "" message += ("*"*65 + "\n") message += ("Job parameters:\n") message += ("*"*65 + "\n") params = [str(j[0]) + ": " + str(j[1]) for j in myjob.types.items()] message += ("\n".join(params)+"\n") myrun = self.Run(self) message += ("*"*65 + "\n") message += ("Run parameters:\n") message += ("*"*65 + "\n") params = [str(j[0]) + ": " + str(j[1]) for j in myrun.types.items()] message += ("\n".join(params)) return message ################################################################ def Base(**params): if 'host' in params: host = params['host'] host_split = host.split('://') if host_split[0] == 'file': from . import base_sqlite params['host'] = host_split[1] return base_sqlite.BaseSQLite(**params) elif host_split[0] == 'zeo': from . import base_zeo return base_zeo.BaseZEO(**params) else: from . import base_psql return base_psql.BasePSQL(**params) if 'host' not in params: from . import base_zeo if ('creation' in params and params['creation'] and not os.path.exists('./.bd')): os.mkdir('./.bd') if not os.path.exists('./.bd'): raise RuntimeError("this is not a black dynamite directory") params['host'] = 'zeo://' + os.path.realpath('.bd/bd.zeo') - return base_zeo.BaseZEO(**params) - + base = base_zeo.BaseZEO(**params) + if base.schema is not None: + params['study'] = base.schema + return base raise RuntimeError("Should not happen") diff --git a/scripts/enterRun.py b/scripts/enterRun.py index bdd255e..c1a9988 100755 --- a/scripts/enterRun.py +++ b/scripts/enterRun.py @@ -1,88 +1,88 @@ #!/usr/bin/env python3 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . ################################################################ import BlackDynamite as BD import subprocess import os import sys import socket ################################################################ parser = BD.BDParser() parser.register_params(group="getRunInfo", params={"run_id": int, "order": str}, help={"run_id": "Select a run_id for switching to it"}) params = parser.parseBDParameters() mybase = BD.Base(**params) if 'run_id' in params: params['run_constraints'] = ['id = {0}'.format(params['run_id'])] try: del params['job_constraints'] except: pass runSelector = BD.RunSelector(mybase) run_list = runSelector.selectRuns(params, quiet=True) mybase.close() if (len(run_list) == 0): print("no run found") sys.exit(1) run, job = run_list[0] run_id = run['id'] separator = '-'*30 print(separator) print("JOB INFO") print(separator) print(job) print(separator) print("RUN INFO") print(separator) print(run) print(separator) print("LOGGING TO '{0}'".format(run['machine_name'])) print(separator) if run['state'] == 'CREATED': print("Cannot enter run: not yet started") sys.exit(-1) bashrc_filename = os.path.join( '/tmp', 'bashrc.user{0}.study{1}.run{2}'.format(params['user'], - params['study'], + mybase.schema, run_id)) bashrc = open(bashrc_filename, 'w') bashrc.write('export PS1="\\u@\\h:<{0}|RUN-{1}> $ "\n'.format( - params['study'], run_id)) + mybase.schema, run_id)) bashrc.write('cd {0}\n'.format(run['run_path'])) bashrc.write('echo ' + separator) bashrc.close() command_login = 'bash --rcfile {0} -i'.format(bashrc_filename) if ((not run['machine_name'] == socket.gethostname()) and (not run['machine_name'] == 'localhost')): command1 = 'scp -q {0} {1}:{0}'.format(bashrc_filename, run['machine_name']) subprocess.call(command1, shell=True) command_login = 'ssh -X -A -t {0} "{1}"'.format( run['machine_name'], command_login) # print command_login subprocess.call(command_login, shell=True)