Page MenuHomec4science

base.py
No OneTemporary

File Metadata

Created
Sun, Apr 28, 10:02
#!/usr/bin/env python3
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
################################################################
from __future__ import print_function
################################################################
from . import job
from . import bdparser
from . import run
from . import bdlogging
from . import jobselector
import os
import psycopg2
import re
import sys
import getpass
import datetime
import atexit
################################################################
# Names exported by `from <module> import *`: only the Base class.
__all__ = ["Base"]
# Rebind the module-level name `print` to bdlogging.invalidPrint.
# NOTE(review): presumably this is meant to discourage bare print() calls
# in favour of the logger below -- confirm against bdlogging.
print = bdlogging.invalidPrint
# Module-level logger, obtained from bdlogging and named after this module.
logger = bdlogging.getLogger(__name__)
################################################################
class Base(object):
    """
    Handle to one study (a PostgreSQL schema) of the 'blackdynamite' database.

    The constructor opens a psycopg2 connection and, when a ``study`` keyword
    is given, resolves the schema name of that study. The methods create the
    schema and its tables, insert jobs/quantities and query bookkeeping
    information (users, sizes, access rights).

    Attributes set by ``__init__``:
        connection: the psycopg2 connection.
        dbhost: the ``host`` keyword, or "localhost" when it is absent.
        user: the ``user`` keyword, or the current login name.
        schema: schema name of the study (only when ``study`` is given).
        type_code: mapping from PostgreSQL type OID to Python type.
        truerun: when true, modifying methods commit their changes.
    """

    def getRunFromID(self, run_id):
        """
        Return the unique run whose id is ``run_id``.

        Raises:
            Exception: if zero or several runs match.
        """
        myrun = run.Run(self)
        myrun["id"] = run_id
        myrun.id = run_id
        run_list = myrun.getMatchedObjectList()
        if len(run_list) != 1:
            raise Exception('Unknown run {0}'.format(run_id))
        return run_list[0]

    def getJobFromID(self, job_id):
        """
        Return the unique job whose id is ``job_id``.

        Raises:
            Exception: if zero or several jobs match.
        """
        myjob = job.Job(self)
        myjob["id"] = job_id
        myjob.id = job_id
        job_list = myjob.getMatchedObjectList()
        if len(job_list) != 1:
            # the message used to say 'run' although a job is looked up
            raise Exception('Unknown job {0}'.format(job_id))
        return job_list[0]

    def createBase(self, job_desc, run_desc, quantities=None, **kwargs):
        """
        Create the schema, the job/run tables, the generic tables and the
        requested quantities.

        Args:
            job_desc: object providing ``createTableRequest()`` for jobs.
            run_desc: object providing ``createTableRequest()`` for runs.
            quantities: optional mapping ``name -> type code`` forwarded to
                ``pushQuantity``.
            **kwargs: forwarded to ``createSchema`` as its ``params``.
        """
        self.createSchema(kwargs)
        self.createTable(job_desc)
        self.createTable(run_desc)
        self.createGenericTables()
        if quantities is not None:
            for qname, qtype in quantities.items():
                self.pushQuantity(qname, qtype)
        if self.truerun:
            self.commit()

    def getObject(self, sqlobject):
        """
        Fill ``sqlobject`` with the columns of the row having ``sqlobject.id``.

        Raises:
            RuntimeError: if the table holds no row with that id.
        """
        curs = self.connection.cursor()
        # the table is an identifier (cannot be bound); the id is a value
        curs.execute("SELECT * FROM {0}.{1} WHERE id = %s".format(
            self.schema, sqlobject.table_name), (sqlobject.id,))
        col_info = self.getColumnProperties(sqlobject)
        line = curs.fetchone()
        if line is None:
            raise RuntimeError(
                "No entry with id '{0}' in table {1}.{2}".format(
                    sqlobject.id, self.schema, sqlobject.table_name))
        for index, column in enumerate(col_info):
            sqlobject[column[0]] = line[index]

    def createSchema(self, params=None):
        """
        Create the schema of the study, dropping an existing one first.

        If the schema already exists the user is asked for confirmation
        through ``bdparser.validate_question``; a refusal exits the program.

        Args:
            params: options handed to ``bdparser.validate_question``;
                defaults to ``{"yes": False}``.
        """
        if params is None:
            params = {"yes": False}
        curs = self.connection.cursor()
        # only the schema name is lower-cased now (the whole statement used
        # to be), and it is passed as a bound value
        curs.execute("SELECT schema_name FROM information_schema.schemata"
                     " WHERE schema_name = %s", (self.schema.lower(),))
        if curs.rowcount:
            validated = bdparser.validate_question(
                "Are you sure you want to drop the schema named '" +
                self.schema + "'", params, False)
            if validated is True:
                curs.execute("DROP SCHEMA {0} cascade".format(self.schema))
            else:
                logger.debug("creation canceled: exit program")
                sys.exit(-1)
        curs.execute("CREATE SCHEMA {0}".format(self.schema))

    def createTypeCodes(self):
        """
        Build ``self.type_code``: PostgreSQL type OID -> Python type.

        Only float8, text, int8, int4, bool and timestamp are mapped.
        """
        python_types = {
            'float8': float,
            'text': str,
            'int8': int,
            'int4': int,
            'bool': bool,
            'timestamp': datetime.datetime,
        }
        curs = self.connection.cursor()
        curs.execute("SELECT typname,oid from pg_type;")
        self.type_code = {}
        for row in curs:
            if row[0] in python_types:
                self.type_code[row[1]] = python_types[row[0]]

    def createTable(self, obj):
        """Execute the request returned by ``obj.createTableRequest()``."""
        request = obj.createTableRequest()
        curs = self.connection.cursor()
        logger.debug(request)
        curs.execute(request)

    def createGenericTables(self,):
        """
        Execute ``build_tables.sql`` (located next to this module) after
        substituting SCHEMAS_IDENTIFIER with the schema name.
        """
        sql_script_name = os.path.join(os.path.dirname(__file__),
                                       "build_tables.sql")
        curs = self.connection.cursor()
        with open(sql_script_name, "r") as fh:
            script = fh.read()
        # plain text substitution: the schema name must not be interpreted
        # as a regular-expression replacement template
        curs.execute(script.replace("SCHEMAS_IDENTIFIER", self.schema))

    def getColumnProperties(self, sqlobject):
        """
        Return ``[(column name, type code), ...]`` for the table of
        ``sqlobject``, as read from the cursor description.
        """
        curs = self.connection.cursor()
        # LIMIT 0: no row is fetched, only the description is wanted
        curs.execute("SELECT * FROM {0}.{1} LIMIT 0".format(
            self.schema, sqlobject.table_name))
        return [(desc[0], desc[1]) for desc in curs.description]

    def setObjectItemTypes(self, sqlobject):
        """Fill ``sqlobject.types`` with the Python type of every column."""
        for col_name, type_oid in self.getColumnProperties(sqlobject):
            sqlobject.types[col_name] = self.type_code[type_oid]

    def insert(self, sqlobject):
        """
        Execute the insertion request produced by ``sqlobject.insert()`` and
        store the first returned column in ``sqlobject.id``.
        """
        curs = self.performRequest(*(sqlobject.insert()))
        sqlobject.id = curs.fetchone()[0]

    def performRequest(self, request, params=None):
        """
        Execute ``request`` with ``params`` and return the cursor.

        Args:
            request: SQL statement, possibly with psycopg2 placeholders.
            params: values bound to the placeholders (empty list if omitted).

        Raises:
            psycopg2.ProgrammingError: re-raised with the failing request and
                its parameters in the message.
        """
        if params is None:
            # keep passing an (now unshared) empty list, as before
            params = []
        curs = self.connection.cursor()
        try:
            curs.execute(request, params)
        except psycopg2.ProgrammingError as err:
            raise psycopg2.ProgrammingError(
                ("While trying to execute the query '{0}' with parameters " +
                 "'{1}', I caught this: '{2}'").format(request, params, err))
        return curs

    def createParameterSpace(self, myjob, entry_nb=0,
                             tmp_job=None, nb_inserted=0):
        """
        Recursively generate the points of the parametric space described by
        ``myjob`` and insert the jobs that are not yet in the base.

        The entries of ``myjob`` are treated one by one: an entry holding a
        list spawns one branch per value, any other value a single branch.

        Args:
            myjob: job whose ``entries`` describe the parametric space.
            entry_nb: index of the entry treated at this recursion level.
            tmp_job: job under construction (created at the first level).
            nb_inserted: number of jobs inserted so far.

        Returns:
            The updated number of inserted jobs.
        """
        # keys() gives a non-indexable view
        keys = list(myjob.entries.keys())

        # all the entries have been treated: time to insert the job
        if entry_nb == len(keys):
            if tmp_job is None:
                raise RuntimeError("internal error")
            # skip the jobs that are already present
            jselect = jobselector.JobSelector(self)
            if len(jselect.selectJobs(tmp_job, quiet=True)) > 0:
                return nb_inserted
            nb_inserted += 1
            logger.info("insert job #{0}".format(nb_inserted) +
                        ': ' + str(tmp_job.entries))
            self.insert(tmp_job)
            return nb_inserted

        if tmp_job is None:
            tmp_job = job.Job(self)

        # the key that is currently treated
        key = keys[entry_nb]
        values = myjob[key]
        # a list means several parametric points
        if not isinstance(values, list):
            values = [values]
        for value in values:
            tmp_job[key.lower()] = value
            nb_inserted = self.createParameterSpace(
                myjob, entry_nb+1, tmp_job, nb_inserted)

        if self.truerun:
            self.commit()
        return nb_inserted

    def pushQuantity(self, name, type_code, description=None):
        """
        Register a quantity in the ``quantities`` table and return its id.

        Args:
            name: name of the quantity.
            type_code: one of "int", int, "float", float, "int.vector",
                "float.vector".
            description: optional description stored with the quantity.

        Raises:
            Exception: for an unknown ``type_code`` or if nothing is returned
                by the insertion.
        """
        if type_code in ("int", int):
            is_integer, is_vector = True, False
        elif type_code == "int.vector":
            is_integer, is_vector = True, True
        elif type_code in ("float", float):
            is_integer, is_vector = False, False
        elif type_code == "float.vector":
            is_integer, is_vector = False, True
        else:
            raise Exception(
                "invalid type '{0}' for a quantity".format(type_code))

        curs = self.connection.cursor()
        curs.execute("""
INSERT INTO {0}.quantities (name, is_integer, is_vector, description)
VALUES (%s , %s , %s, %s) RETURNING id
""".format(self.schema), (name, is_integer, is_vector, description))
        item = curs.fetchone()
        if item is None:
            raise Exception("Could not create quantity \"" + name + "\"")
        return item[0]

    def commit(self):
        """Commit the pending changes of the connection."""
        logger.debug("commiting changes to base")
        self.connection.commit()

    def getUserList(self):
        """Return the owners of the tables named 'runs'."""
        curs = self.connection.cursor()
        curs.execute("""
select tableowner from pg_tables where tablename = 'runs';
""")
        return [row[0] for row in curs]

    def getStudySize(self, study):
        """
        Return ``{'size', 'nruns', 'njobs'}`` for the schema ``study``.

        When a query fails with a ``psycopg2.ProgrammingError`` the
        transaction is rolled back, 'size' is '????' and the counts that
        could not be computed are '????' as well.
        """
        curs = self.connection.cursor()
        # defaults: the counts used to be unbound when a query failed,
        # which made the return statement raise UnboundLocalError
        size = nruns = njobs = '????'
        try:
            logger.info(study)
            curs.execute("""
select sz from (SELECT SUM(pg_total_relation_size(quote_ident(schemaname)
|| '.' || quote_ident(tablename)))::BIGINT
FROM pg_tables WHERE schemaname = %s) as sz
""", (study,))
            size = curs.fetchone()[0]
            # NOTE: the fetched value is deliberately inlined as text here
            curs.execute("""
select pg_size_pretty(cast({0} as bigint))
""".format(size))
            size = curs.fetchone()[0]
            curs.execute("""
select count({0}.runs.id) from {0}.runs
""".format(study))
            nruns = curs.fetchone()[0]
            curs.execute("""
select count({0}.jobs.id) from {0}.jobs
""".format(study))
            njobs = curs.fetchone()[0]
        except psycopg2.ProgrammingError:
            self.connection.rollback()
            size = '????'
        return {'size': size, 'nruns': nruns, 'njobs': njobs}

    def grantAccess(self, study, user):
        """
        Grant ``user`` SELECT on all tables of, and USAGE on, the schema
        ``study``, then commit.

        Both names are inserted into the statement as identifiers.
        """
        curs = self.connection.cursor()
        curs.execute("""
grant SELECT on ALL tables in schema {0} to {1};
grant USAGE on SCHEMA {0} to {1};
""".format(study, user))
        self.commit()

    def revokeAccess(self, study, user):
        """
        Revoke from ``user`` SELECT on all tables of, and USAGE on, the
        schema ``study``, then commit.

        Both names are inserted into the statement as identifiers.
        """
        curs = self.connection.cursor()
        curs.execute("""
revoke SELECT on ALL tables in schema {0} from {1};
revoke USAGE on SCHEMA {0} from {1};
""".format(study, user))
        self.commit()

    def getStudyOwner(self, schema):
        """
        Return the first grantor of the SELECT privilege on ``schema``.runs.

        Raises:
            IndexError: if no such privilege is recorded.
        """
        curs = self.connection.cursor()
        curs.execute("""
select grantor from information_schema.table_privileges
where (table_name,table_schema,privilege_type)
= ('runs',%s,'SELECT');
""", (schema,))
        owners = [row[0] for row in curs]
        return owners[0]

    def getGrantedUsers(self, schema):
        """Return the grantees of the SELECT privilege on ``schema``.runs."""
        curs = self.connection.cursor()
        curs.execute("""
select grantee from information_schema.table_privileges
where (table_name,table_schema,privilege_type)
= ('runs',%s,'SELECT');
""", (schema,))
        return [row[0] for row in curs]

    def getSchemaList(self, filter_names=True):
        """
        Return the schemas that contain a table named 'runs'.

        Args:
            filter_names: when exactly ``True``, the leading ``<user>_`` is
                stripped from the schemas carrying the prefix of
                ``self.user``; the other names are returned unchanged.
        """
        curs = self.connection.cursor()
        curs.execute("""
SELECT distinct(table_schema) from information_schema.tables
where table_name='runs'
""")
        schemas = [row[0] for row in curs]
        if filter_names is not True:
            return schemas

        # the user name is literal text, not a pattern
        own_prefix = re.compile('{0}_(.+)'.format(re.escape(self.user)))
        filtered_schemas = []
        for name in schemas:
            found = own_prefix.match(name)
            filtered_schemas.append(found.group(1) if found else name)
        return filtered_schemas

    def checkStudy(self, dico):
        """
        Exit the program, after logging the available studies, when ``dico``
        has no 'study' key.
        """
        if "study" in dico:
            return
        message = "\n" + "*"*30 + "\n"
        message += "Parameter 'study' must be provided at command line\n"
        message += "possibilities are:\n"
        message += "".join(
            "\t" + s + "\n" for s in self.getSchemaList())
        message += "\n"
        message += "FATAL => ABORT\n"
        message += "*"*30 + "\n"
        logger.error(message)
        sys.exit(-1)

    def close(self):
        """Close the connection, if still open, and forget it."""
        if 'connection' in self.__dict__:
            logger.debug('closing database session')
            self.connection.close()
            del (self.__dict__['connection'])

    def __init__(self, truerun=False, creation=False, **kwargs):
        """
        Connect to the 'blackdynamite' database and select a study.

        Args:
            truerun: when true, modifying methods commit their changes.
            creation: when ``True`` the study is not required to exist yet.
            **kwargs: connection settings (``host``, ``user``, ``port``,
                ``password``; the password 'ask' triggers a prompt) and
                options (``study``, ``should_not_check_study``,
                ``list_parameters``).

        Raises:
            RuntimeError: if the requested study does not exist while
                ``creation`` is not ``True``.

        The program exits if the connection fails, if ``study`` is required
        but missing, or after listing the parameters on ``list_parameters``.
        """
        psycopg2_params = ["host", "user", "port", "password"]
        connection_params = bdparser.filterParams(psycopg2_params, kwargs)
        connection_params['dbname'] = 'blackdynamite'
        if ("password" in connection_params and
                connection_params["password"] == 'ask'):
            connection_params["password"] = getpass.getpass()

        # never write the clear-text password to the log
        loggable_params = dict(connection_params)
        if "password" in loggable_params:
            loggable_params["password"] = '********'
        logger.debug('connection arguments: {0}'.format(loggable_params))

        try:
            self.connection = psycopg2.connect(**connection_params)
            logger.debug('connected to base')
        except Exception as e:
            logger.error(
                "Connection failed: check your connection settings:\n" +
                str(e))
            sys.exit(-1)
        # sanity check through the public alias of the connection class
        assert isinstance(self.connection, psycopg2.extensions.connection)

        self.dbhost = kwargs["host"] if "host" in kwargs else "localhost"

        if 'user' in kwargs:
            self.user = kwargs["user"]
        else:
            try:
                self.user = os.getlogin()
            except OSError:
                # no controlling terminal (batch job, cron, ...)
                self.user = getpass.getuser()

        if "should_not_check_study" not in kwargs:
            self.checkStudy(kwargs)

        if 'study' in kwargs:
            # Need this because getSchemaList strips prefix
            match = re.match('(.+)_(.+)', kwargs["study"])
            if match:
                self.schema = kwargs["study"]
                study_name = match.group(2)
            else:
                # self.user also covers the case where 'user' was not given
                self.schema = self.user + '_' + kwargs["study"]
                study_name = kwargs["study"]

            if creation is not True:
                # the full schema name is accepted too: the stripped name
                # alone misses the studies owned by another user and the
                # study names containing an underscore
                known = (
                    study_name in self.getSchemaList() or
                    self.schema in self.getSchemaList(filter_names=False))
                if not known:
                    raise RuntimeError("Study name '{}' invalid".format(
                        study_name))

        self.createTypeCodes()
        self.truerun = truerun
        if "list_parameters" in kwargs and kwargs["list_parameters"] is True:
            message = self.getPossibleParameters()
            logger.info("\n{0}".format(message))
            sys.exit(0)

        # We should avoid using __del__ to close DB
        def close_db():
            self.close()

        atexit.register(close_db)

    def getPossibleParameters(self):
        """
        Return a text listing the names and types of the job parameters and
        of the run parameters.
        """
        banner = "*"*65 + "\n"

        myjob = job.Job(self)
        message = banner + "Job parameters:\n" + banner
        params = [str(name) + ": " + str(kind)
                  for name, kind in myjob.types.items()]
        message += "\n".join(params) + "\n"

        myrun = run.Run(self)
        message += banner + "Run parameters:\n" + banner
        params = [str(name) + ": " + str(kind)
                  for name, kind in myrun.types.items()]
        message += "\n".join(params)
        return message
################################################################
# Manual demo executed only when the module is run as a script.
# NOTE(review): this demo looks stale with respect to the Base class defined
# in this file -- Base.__init__ takes (truerun, creation, **kwargs), so the
# three positional arguments below do not fit, and Base defines neither
# create() nor pushJob(). The module also relies on relative imports
# (`from . import job`), which need a package context. Confirm whether this
# block should be updated or removed.
if __name__ == "__main__":
    connection = psycopg2.connect(host="localhost")
    job_description = job.Job(dict(hono=int, lulu=float, toto=str))
    base = Base("honoluluSchema", connection, job_description)
    base.create()
    connection.commit()
    base.pushJob(dict(hono=12, lulu=24.2, toto="toto"))
    base.pushQuantity("ekin", "float")
    connection.commit()

Event Timeline