diff --git a/BlackDynamite/base.py b/BlackDynamite/base.py index 0117d6a..5d28ced 100755 --- a/BlackDynamite/base.py +++ b/BlackDynamite/base.py @@ -1,448 +1,450 @@ #!/usr/bin/env python3 ################################################################ from __future__ import print_function ################################################################ from . import job from . import bdparser from . import run from . import bdlogging from . import jobselector import os import psycopg2 import re import sys import getpass import datetime import atexit ################################################################ __all__ = ["Base"] print = bdlogging.invalidPrint logger = bdlogging.getLogger(__name__) ################################################################ class Base(object): """ """ def getRunFromID(self, run_id): myrun = run.Run(self) myrun["id"] = run_id myrun.id = run_id run_list = myrun.getMatchedObjectList() if len(run_list) != 1: raise Exception('Unknown run {0}'.format(run_id)) return run_list[0] def getJobFromID(self, job_id): myjob = job.Job(self) myjob["id"] = job_id myjob.id = job_id job_list = myjob.getMatchedObjectList() if len(job_list) != 1: raise Exception('Unknown run {0}'.format(job_id)) return job_list[0] def createBase(self, job_desc, run_desc, quantities={}, **kwargs): # logger.debug (quantities) self.createSchema(kwargs) self.createTable(job_desc) self.createTable(run_desc) self.createGenericTables() for qname, type in quantities.items(): self.pushQuantity(qname, type) if self.truerun: self.commit() def getObject(self, sqlobject): curs = self.connection.cursor() curs.execute("SELECT * FROM {0}.{1} WHERE id = {2}".format( self.schema, sqlobject.table_name, sqlobject.id)) col_info = self.getColumnProperties(sqlobject) line = curs.fetchone() for i in range(0, len(col_info)): col_name = col_info[i][0] sqlobject[col_name] = line[i] def createSchema(self, params={"yes": False}): # create the schema of the simulation curs = self.connection.cursor() curs.execute(("SELECT schema_name FROM information_schema.schemata" " WHERE schema_name = '{0}'").format( self.schema).lower()) if curs.rowcount: validated = bdparser.validate_question( "Are you sure you want to drop the schema named '" + self.schema + "'", params, False) if validated is True: curs.execute("DROP SCHEMA {0} cascade".format(self.schema)) else: logger.debug("creation canceled: exit program") sys.exit(-1) curs.execute("CREATE SCHEMA {0}".format(self.schema)) def createTypeCodes(self): curs = self.connection.cursor() curs.execute("SELECT typname,oid from pg_type;") self.type_code = {} for i in curs: if i[0] == 'float8': self.type_code[i[1]] = float if i[0] == 'text': self.type_code[i[1]] = str if i[0] == 'int8': self.type_code[i[1]] = int if i[0] == 'int4': self.type_code[i[1]] = int if i[0] == 'bool': self.type_code[i[1]] = bool if i[0] == 'timestamp': self.type_code[i[1]] = datetime.datetime def createTable(self, obj): request = obj.createTableRequest() curs = self.connection.cursor() logger.debug(request) curs.execute(request) def createGenericTables(self,): sql_script_name = os.path.join(os.path.dirname(__file__), "build_tables.sql") curs = self.connection.cursor() # create generic tables query_list = list() with open(sql_script_name, "r") as fh: for line in fh: query_list.append(re.sub("SCHEMAS_IDENTIFIER", self.schema, line)) curs.execute("\n".join(query_list)) def getColumnProperties(self, sqlobject): curs = self.connection.cursor() curs.execute("SELECT * FROM {0}.{1} LIMIT 0".format( self.schema, sqlobject.table_name)) column_names = [desc[0] for desc in curs.description] column_type = [desc[1] for desc in curs.description] return list(zip(column_names, column_type)) def setObjectItemTypes(self, sqlobject): col_info = self.getColumnProperties(sqlobject) for i, j in col_info: sqlobject.types[i] = self.type_code[j] # logger.debug (str(i) + " " + str(self.type_code[j])) def insert(self, sqlobject): curs = self.performRequest(*(sqlobject.insert())) sqlobject.id = curs.fetchone()[0] def performRequest(self, request, params=[]): curs = self.connection.cursor() # logger.debug (request) # logger.debug (params) try: curs.execute(request, params) except psycopg2.ProgrammingError as err: raise psycopg2.ProgrammingError( ("While trying to execute the query '{0}' with parameters " + "'{1}', I caught this: '{2}'").format(request, params, err)) return curs def createParameterSpace(self, myjob, entry_nb=0, tmp_job=None, nb_inserted=0): """ This function is a recursive call to generate the points in the parametric space The entries of the jobs are treated one by one in a recursive manner """ # keys() gives a non-indexable view keys = list(myjob.entries.keys()) nparam = len(keys) # if this is the case I have done all the # entries of the job # it is time to insert it (after some checks) if entry_nb == nparam: if tmp_job is None: raise RuntimeError("internal error") # check if already inserted jselect = jobselector.JobSelector(self) jobs = jselect.selectJobs(tmp_job, quiet=True) if len(jobs) > 0: return nb_inserted # insert it nb_inserted += 1 logger.info("insert job #{0}".format(nb_inserted) + ': ' + str(tmp_job.entries)) self.insert(tmp_job) return nb_inserted if tmp_job is None: tmp_job = job.Job(self) # the key that I am currently treating key = keys[entry_nb] e = myjob[key] # if this is a list I have to create several parametric points if not isinstance(e, list): e = [e] for value in e: tmp_job[key.lower()] = value nb_inserted = self.createParameterSpace( myjob, entry_nb+1, tmp_job, nb_inserted) if self.truerun: self.commit() return nb_inserted def pushQuantity(self, name, type_code, description=None): """ implemented type_codes: "int" "float" "int.vector" "float.vector" """ if ((type_code == "int") or (type_code == int)): is_integer = True is_vector = False elif (type_code == "int.vector"): is_integer = True is_vector = True elif ((type_code == "float") or (type_code == float)): is_integer = False is_vector = False elif (type_code == "float.vector"): is_integer = False is_vector = True else: raise Exception( "invalid type '{0}' for a quantity".format(type_code)) curs = self.connection.cursor() curs.execute(""" INSERT INTO {0}.quantities (name, is_integer, is_vector, description) VALUES (%s , %s , %s, %s) RETURNING id """.format(self.schema), (name, is_integer, is_vector, description)) item = curs.fetchone() if (item is None): raise Exception("Counld not create quantity \"" + name + "\"") return item[0] def commit(self): logger.debug("commiting changes to base") self.connection.commit() def getUserList(self): curs = self.connection.cursor() curs.execute(""" select tableowner from pg_tables where tablename = 'runs'; """) users = [desc[0] for desc in curs] return users def getStudySize(self, study): curs = self.connection.cursor() try: logger.info(study) curs.execute(""" select sz from (SELECT SUM(pg_total_relation_size(quote_ident(schemaname) || '.' || quote_ident(tablename)))::BIGINT FROM pg_tables WHERE schemaname = '{0}') as sz """.format(study)) size = curs.fetchone()[0] curs.execute(""" select pg_size_pretty(cast({0} as bigint)) """.format(size)) size = curs.fetchone()[0] curs.execute(""" select count({0}.runs.id) from {0}.runs """.format(study)) nruns = curs.fetchone()[0] curs.execute(""" select count({0}.jobs.id) from {0}.jobs """.format(study)) njobs = curs.fetchone()[0] except psycopg2.ProgrammingError: self.connection.rollback() size = '????' return {'size': size, 'nruns': nruns, 'njobs': njobs} def grantAccess(self, study, user): curs = self.connection.cursor() curs.execute(""" grant SELECT on ALL tables in schema {0} to {1}; grant USAGE on SCHEMA {0} to {1}; """.format(study, user)) self.commit() def revokeAccess(self, study, user): curs = self.connection.cursor() curs.execute(""" revoke SELECT on ALL tables in schema {0} from {1}; revoke USAGE on SCHEMA {0} from {1}; """.format(study, user)) self.commit() def getStudyOwner(self, schema): curs = self.connection.cursor() curs.execute(""" select grantor from information_schema.table_privileges where (table_name,table_schema,privilege_type) = ('runs','{0}','SELECT'); """.format(schema)) owners = [desc[0] for desc in curs] return owners[0] def getGrantedUsers(self, schema): curs = self.connection.cursor() curs.execute(""" select grantee from information_schema.table_privileges where (table_name,table_schema,privilege_type) = ('runs','{0}','SELECT'); """.format(schema)) granted_users = [desc[0] for desc in curs] return granted_users def getSchemaList(self, filter_names=True): curs = self.connection.cursor() curs.execute(""" SELECT distinct(table_schema) from information_schema.tables where table_name='runs' """) schemas = [desc[0] for desc in curs] filtered_schemas = [] if filter_names is True: for s in schemas: m = re.match('{0}_(.+)'.format(self.user), s) if m: s = m.group(1) filtered_schemas.append(s) else: filtered_schemas = schemas return filtered_schemas def checkStudy(self, dico): if "study" not in dico: message = "\n" + "*"*30 + "\n" message += "Parameter 'study' must be provided at command line\n" message += "possibilities are:\n" schemas = self.getSchemaList() for s in schemas: message += "\t" + s + "\n" message += "\n" message += "FATAL => ABORT\n" message += "*"*30 + "\n" logger.error(message) sys.exit(-1) def close(self): if 'connection' in self.__dict__: logger.debug('closing database session') self.connection.close() del (self.__dict__['connection']) - def __init__(self, truerun=False, **kwargs): + def __init__(self, truerun=False, creation=False, **kwargs): psycopg2_params = ["host", "user", "port", "password"] connection_params = bdparser.filterParams(psycopg2_params, kwargs) connection_params['dbname'] = 'blackdynamite' if ("password" in connection_params and connection_params["password"] == 'ask'): connection_params["password"] = getpass.getpass() logger.debug('connection arguments: {0}'.format(connection_params)) try: self.connection = psycopg2.connect(**connection_params) logger.debug('connected to base') except Exception as e: logger.error( "Connection failed: check your connection settings:\n" + str(e)) sys.exit(-1) assert(isinstance(self.connection, psycopg2._psycopg.connection)) self.dbhost = (kwargs["host"] if "host" in kwargs.keys() else "localhost") if 'user' in kwargs: self.user = kwargs["user"] else: self.user = os.getlogin() if ("should_not_check_study" not in kwargs): self.checkStudy(kwargs) if 'study' in kwargs: # Need this because getSchemaList strips prefix match = re.match('(.+)_(.+)', kwargs["study"]) if match: self.schema = kwargs["study"] study_name = match.group(2) else: self.schema = kwargs["user"] + '_' + kwargs["study"] study_name = kwargs["study"] - if study_name not in self.getSchemaList(): + if ( + (creation is not True) and + (study_name not in self.getSchemaList())): raise RuntimeError("Study name '{}' invalid".format( - study_name)) + study_name)) self.createTypeCodes() self.truerun = truerun if("list_parameters" in kwargs and kwargs["list_parameters"] is True): message = self.getPossibleParameters() logger.info("\n{0}".format(message)) sys.exit(0) # We should avoid using __del__ to close DB def close_db(): self.close() atexit.register(close_db) def getPossibleParameters(self): myjob = job.Job(self) message = "" message += ("*"*65 + "\n") message += ("Job parameters:\n") message += ("*"*65 + "\n") params = [str(j[0]) + ": " + str(j[1]) for j in myjob.types.items()] message += ("\n".join(params)+"\n") myrun = run.Run(self) message += ("*"*65 + "\n") message += ("Run parameters:\n") message += ("*"*65 + "\n") params = [str(j[0]) + ": " + str(j[1]) for j in myrun.types.items()] message += ("\n".join(params)) return message ################################################################ if __name__ == "__main__": connection = psycopg2.connect(host="localhost") job_description = job.Job(dict(hono=int, lulu=float, toto=str)) base = Base("honoluluSchema", connection, job_description) base.create() connection.commit() base.pushJob(dict(hono=12, lulu=24.2, toto="toto")) base.pushQuantity("ekin", "float") connection.commit()