# Page Menu Home c4science
#
# base.py
# No One Temporary
#
# File Metadata
#
# Created
# Sat, Oct 19, 20:14
#!/usr/bin/env python3
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
################################################################
from abc import ABC, abstractmethod
################################################################
from . import job
from . import bdlogging
from . import jobselector
################################################################
import os
import re
import sys
# Public API of this module: only the ``Base`` factory is exported.
__all__ = ["Base"]
# Shadow the builtin ``print`` inside this module with
# ``bdlogging.invalidPrint``; presumably this forbids bare print() calls so
# that all output goes through the logger below (confirm in bdlogging).
print = bdlogging.invalidPrint
# Module-level logger obtained from the project's logging helper.
logger = bdlogging.getLogger(__name__)
################################################################
class AbstractBase(ABC):
    """
    Backend-independent base class for a BlackDynamite study database.

    A concrete backend must pass a DB-API style ``connection`` to
    ``__init__`` and implement the abstract methods.  Several helpers below
    also call methods that this class does not define (``insert``,
    ``commit``, ``getColumnProperties``) and use a ``Job`` attribute; those
    are expected to be provided by the derived backend class.
    """

    def getRunFromID(self, run_id):
        """
        Return the unique run whose id is ``run_id``.

        Raises:
            Exception: if zero or several runs match ``run_id``.
        """
        # local import: ``run`` is not among this module's top-level imports
        from . import run

        myrun = run.Run(self)
        myrun["id"] = run_id
        myrun.id = run_id
        run_list = myrun.getMatchedObjectList()
        if len(run_list) != 1:
            raise Exception('Unknown run {0}'.format(run_id))
        return run_list[0]

    def getJobFromID(self, job_id):
        """
        Return the unique job whose id is ``job_id``.

        Raises:
            Exception: if zero or several jobs match ``job_id``.
        """
        myjob = job.Job(self)
        myjob["id"] = job_id
        myjob.id = job_id
        job_list = myjob.getMatchedObjectList()
        if len(job_list) != 1:
            raise Exception('Unknown job {0}'.format(job_id))
        return job_list[0]

    def getObject(self, sqlobject):
        """
        Fill ``sqlobject`` in place from the row of its table whose id is
        ``sqlobject.id``.

        Column names come from ``self.getColumnProperties(sqlobject)`` (first
        element of each entry) and are matched positionally with the row.

        Raises:
            RuntimeError: if no row with that id exists.
        """
        curs = self.connection.cursor()
        curs.execute("SELECT * FROM {0}.{1} WHERE id = {2}".format(
            self.schema, sqlobject.table_name, sqlobject.id))
        col_info = self.getColumnProperties(sqlobject)
        line = curs.fetchone()
        if line is None:
            raise RuntimeError("No row with id {0} in {1}.{2}".format(
                sqlobject.id, self.schema, sqlobject.table_name))
        for i, col in enumerate(col_info):
            sqlobject[col[0]] = line[i]

    def createGenericTables(self):
        """
        Execute ``build_tables.sql`` (located next to this module) after
        replacing every ``SCHEMAS_IDENTIFIER`` placeholder with
        ``self.schema``.
        """
        sql_script_name = os.path.join(os.path.dirname(__file__),
                                       "build_tables.sql")
        with open(sql_script_name, "r") as fh:
            script = fh.read()
        # literal replacement: re.sub would interpret backslashes and group
        # references appearing in the schema name
        script = script.replace("SCHEMAS_IDENTIFIER", self.schema)
        curs = self.connection.cursor()
        curs.execute(script)

    @abstractmethod
    def performRequest(self, request, params=None):
        """Execute ``request`` with ``params``; backend specific."""
        raise RuntimeError("abstract method")

    def createParameterSpace(self, myjob, entry_nb=0,
                             tmp_job=None, nb_inserted=0):
        """
        Recursively insert every point of the parametric space of ``myjob``.

        Entries of ``myjob`` are handled one at a time; ``entry_nb`` is the
        index of the entry being processed.  A list-valued entry spawns one
        branch per element, a scalar entry is treated as a one-element list.
        When all entries are set, the accumulated ``tmp_job`` is inserted
        unless a matching job already exists.

        Returns:
            The running count of inserted jobs.

        Raises:
            RuntimeError: if ``myjob`` has no entries (nothing to insert).
        """
        # keys() gives a non-indexable view
        keys = list(myjob.entries.keys())
        nparam = len(keys)

        # every entry has been assigned: insert the point (after some checks)
        if entry_nb == nparam:
            if tmp_job is None:
                raise RuntimeError("internal error")
            # skip points that are already in the database
            jselect = jobselector.JobSelector(self)
            jobs = jselect.selectJobs(tmp_job, quiet=True)
            if len(jobs) > 0:
                return nb_inserted
            nb_inserted += 1
            logger.info("insert job #{0}".format(nb_inserted) +
                        ': ' + str(tmp_job.entries))
            self.insert(tmp_job)
            return nb_inserted

        if tmp_job is None:
            # ``self.Job`` is not defined in this class; presumably the
            # backend provides it
            tmp_job = self.Job(self)

        key = keys[entry_nb]
        values = myjob[key]
        # a list entry means several parametric points
        if not isinstance(values, list):
            values = [values]
        for value in values:
            tmp_job[key.lower()] = value
            nb_inserted = self.createParameterSpace(
                myjob, entry_nb + 1, tmp_job, nb_inserted)

        if self.truerun:
            self.commit()
        return nb_inserted

    def pushQuantity(self, name, type_code, description=None):
        """
        Register a new quantity and return its id.

        Implemented type codes: "int" (or ``int``), "int.vector",
        "float" (or ``float``), "float.vector".

        Raises:
            Exception: for an unknown type code, or when the INSERT does not
                return an id.
        """
        # type code -> (is_integer, is_vector)
        flags = {
            "int": (True, False),
            int: (True, False),
            "int.vector": (True, True),
            "float": (False, False),
            float: (False, False),
            "float.vector": (False, True),
        }
        try:
            is_integer, is_vector = flags[type_code]
        except (KeyError, TypeError):
            # TypeError covers unhashable type codes
            raise Exception(
                "invalid type '{0}' for a quantity".format(type_code)) from None

        curs = self.connection.cursor()
        curs.execute("""
INSERT INTO {0}.quantities (name, is_integer, is_vector, description)
VALUES (%s , %s , %s, %s) RETURNING id
""".format(self.schema), (name, is_integer, is_vector, description))
        item = curs.fetchone()
        if item is None:
            raise Exception("Could not create quantity \"" + name + "\"")
        return item[0]

    def getUserList(self):
        """Return the distinct owners of the tables named 'runs'
        (order unspecified)."""
        curs = self.connection.cursor()
        curs.execute("""
select tableowner from pg_tables where tablename = 'runs';
""")
        return list({desc[0] for desc in curs})

    @abstractmethod
    def getStudySize(self, study):
        """Return the size of ``study``; backend specific."""
        raise RuntimeError("abstract method")

    def grantAccess(self, study, user):
        """Grant ``user`` SELECT on all tables and USAGE on the schema
        ``study``, then commit."""
        # NOTE(review): ``study`` and ``user`` are SQL identifiers pasted
        # unquoted into the statement; they must come from a trusted source.
        curs = self.connection.cursor()
        curs.execute("""
grant SELECT on ALL tables in schema {0} to {1};
grant USAGE on SCHEMA {0} to {1};
""".format(study, user))
        self.commit()

    def revokeAccess(self, study, user):
        """Revoke what ``grantAccess`` granted to ``user`` on schema
        ``study``, then commit."""
        # NOTE(review): identifiers are pasted unquoted, as in grantAccess.
        curs = self.connection.cursor()
        curs.execute("""
revoke SELECT on ALL tables in schema {0} from {1};
revoke USAGE on SCHEMA {0} from {1};
""".format(study, user))
        self.commit()

    def getStudyOwner(self, schema):
        """
        Return the grantor of SELECT on table 'runs' of ``schema``.

        Raises IndexError if no such privilege row exists.
        """
        curs = self.connection.cursor()
        # schema is a value here, so it is passed as a query parameter
        curs.execute("""
select grantor from information_schema.table_privileges
where (table_name,table_schema,privilege_type)
= ('runs',%s,'SELECT');
""", (schema,))
        owners = [desc[0] for desc in curs]
        return owners[0]

    def getGrantedUsers(self, schema):
        """Return every grantee holding SELECT on table 'runs' of
        ``schema``."""
        curs = self.connection.cursor()
        curs.execute("""
select grantee from information_schema.table_privileges
where (table_name,table_schema,privilege_type)
= ('runs',%s,'SELECT');
""", (schema,))
        return [desc[0] for desc in curs]

    def getSchemaList(self, filter_names=True):
        """
        List the schemas that contain a 'runs' table.

        With ``filter_names`` True, keep only schemas named
        ``<self.user>_<study>`` and return the ``<study>`` part; otherwise
        return the raw schema names.
        """
        curs = self.connection.cursor()
        curs.execute("""
SELECT distinct(table_schema) from information_schema.tables
where table_name='runs'
""")
        schemas = [desc[0] for desc in curs]
        if filter_names is not True:
            return schemas
        # escape the user name: it is literal text, not a regex
        pattern = re.compile('{0}_(.+)'.format(re.escape(self.user)))
        return [m.group(1) for m in map(pattern.match, schemas) if m]

    def checkStudy(self, dico):
        """Log the available studies and exit the process with status -1 if
        ``dico`` has no 'study' key."""
        if "study" not in dico:
            message = "\n" + "*"*30 + "\n"
            message += "Parameter 'study' must be provided at command line\n"
            message += "possibilities are:\n"
            schemas = self.getSchemaList()
            for s in schemas:
                message += "\t" + s + "\n"
            message += "\n"
            message += "FATAL => ABORT\n"
            message += "*"*30 + "\n"
            logger.error(message)
            sys.exit(-1)

    def close(self):
        """Close the connection once; later calls are no-ops."""
        if 'connection' in self.__dict__:
            logger.debug('closing database session')
            self.connection.close()
            del (self.__dict__['connection'])

    def __init__(self, connection=None, truerun=False, creation=False,
                 **kwargs):
        """
        Args:
            connection: open connection supplied by the derived backend.
            truerun: when True, createParameterSpace commits its inserts.
            creation: when True, skip checking that the study exists.
            **kwargs: 'user', 'study', 'should_not_check_study',
                'list_parameters' are recognised here.

        Raises:
            RuntimeError: without a connection, or for an unknown study.
        """
        self.connection = connection
        if self.connection is None:
            raise RuntimeError("This class must be derived")

        if 'user' in kwargs:
            self.user = kwargs["user"]
        else:
            try:
                self.user = os.getlogin()
            except OSError:
                # os.getlogin() fails without a controlling terminal
                # (batch jobs, cron, some containers)
                import getpass
                self.user = getpass.getuser()

        if "should_not_check_study" not in kwargs:
            self.checkStudy(kwargs)

        if 'study' in kwargs:
            study = kwargs["study"]
            prefix = self.user + '_'
            if study.startswith(prefix) and len(study) > len(prefix):
                # the current user's full schema name: strip exactly the
                # prefix, as getSchemaList() does
                self.schema = study
                study_name = study[len(prefix):]
            else:
                # Need this because getSchemaList strips prefix
                match = re.match('(.+)_(.+)', study)
                if match:
                    self.schema = study
                    study_name = match.group(2)
                else:
                    self.schema = self.user + '_' + study
                    study_name = study
            if ((creation is not True) and
                    (study_name not in self.getSchemaList())):
                raise RuntimeError("Study name '{}' invalid".format(
                    study_name))

        self.truerun = truerun

        if "list_parameters" in kwargs and kwargs["list_parameters"] is True:
            message = self.getPossibleParameters()
            logger.info("\n{0}".format(message))
            sys.exit(0)

    def getPossibleParameters(self):
        """Return a printable listing of the job and run parameter types."""
        # local import: ``run`` is not among this module's top-level imports
        from . import run

        banner = "*" * 65 + "\n"
        myjob = job.Job(self)
        job_params = [str(k) + ": " + str(v) for k, v in myjob.types.items()]
        message = banner + "Job parameters:\n" + banner
        message += "\n".join(job_params) + "\n"

        myrun = run.Run(self)
        run_params = [str(k) + ": " + str(v) for k, v in myrun.types.items()]
        message += banner + "Run parameters:\n" + banner
        message += "\n".join(run_params)
        return message
################################################################
def Base(**params):
    """
    Factory returning the database backend selected by ``params['host']``.

    ``file://<rest>`` selects ``base_sqlite.BaseSQLite`` and ``zeo://<rest>``
    selects ``base_zeo.BaseZEO``; for these two, ``params['host']`` is
    replaced by ``<rest>`` before being forwarded.  Any other host, or no
    host at all, selects ``base_psql.BasePSQL`` with ``params`` unchanged.
    """
    host = params.get('host')
    if host is not None:
        # split on the first '://' only, so the remainder (e.g. a file
        # path) is forwarded intact
        scheme, sep, address = host.partition('://')
        if sep and scheme == 'file':
            from . import base_sqlite
            params['host'] = address
            return base_sqlite.BaseSQLite(**params)
        if sep and scheme == 'zeo':
            from . import base_zeo
            params['host'] = address
            return base_zeo.BaseZEO(**params)
    from . import base_psql
    return base_psql.BasePSQL(**params)

# Event Timeline