diff --git a/src/sausage/appargs.py b/src/sausage/appargs.py index 344d91b..74a8610 100644 --- a/src/sausage/appargs.py +++ b/src/sausage/appargs.py @@ -1,133 +1,170 @@ """ Arguments parsing module """ import argparse import getpass import sys +import grp +from os import getgroups from sausage.functions import valid_date, valid_period +from sausage.readconf import ReadConf + +ReadConf() class AppArgs(object): - def __init__(self): + def __init__(self, cost=False): self.response = {} self.parser = argparse.ArgumentParser( prog="Sausage", description="SCITAS Account Usage.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) self.add_default_args() - self.add_args() + if cost: + self.add_args_cost() + else: + self.add_args() AppArgs.verbose = self.args.verbose def add_default_args(self): self.parser.add_argument( "-v", "--verbose", help="Verbose", action="store_true", ) return def add_args(self): self.parser.add_argument( "-u", "--user", help="If not provided whoami is considered" ) self.parser.add_argument( "-a", "--all", help="all users from an account are printed", action="store_true", ) self.parser.add_argument( "-A", "--account", help="Prints account consumption per cluster" ) self.parser.add_argument( "-s", "--start", help="Start date - format YYYY-MM-DD", type=valid_date ) self.parser.add_argument( "-e", "--end", help="End date - format YYYY-MM-DD", type=valid_date ) self.parser.add_argument( "-c", "--carbon", help="Prints the carbon footprint per cluster", action="store_true", ) self.parser.add_argument( "-b", "--billing", help="Displays the billing period - format YYYY-MM or YYYY", type=valid_period, ) self.parser.add_argument( "-x", "--csv", help="Print result in csv style", action="store_true", default=False, ) self.args = self.parser.parse_args() - # active = [k for k, v in vars(self.args..items() if v not in (None, False)] - if self.args.billing: - # listofgroups = [grp.getgrgid(g).gr_name for g in os.getgroups()] - # if self.billinggrp not in listofgroups: - # self.parser.error( - # "--billing is only available for users in " - # + self.billinggrp - # + " group" - # ) + listofgroups = [grp.getgrgid(g).gr_name for g in getgroups()] + if ReadConf.billinggrp not in listofgroups: + self.parser.error( + "--billing is only available for users in " + + ReadConf.billinggrp + + " group" + ) if ( self.args.user or self.args.account or self.args.all or self.args.start or self.args.end or self.args.carbon ): self.parser.error( "--billing is not compatible with any other option") if self.args.start and self.args.end is None: self.parser.error("range requires both dates (--start and --end)") if self.args.end: if self.args.start is None: self.parser.error( "range requires both dates (--start and --end)") if self.args.end < self.args.start: self.parser.error("start date must be earlier than end date") if self.args.all and self.args.account is None: self.parser.error( "the option --all requires a valid account (--all and --account)" ) if self.args.all and self.args.user: self.parser.error( "--all option is not compatible with --user option") if len(sys.argv) <= 1 or ( ( ( all(v is not None for v in [ self.args.start, self.args.end]) or any(v is not None for v in [self.args.carbon, self.args.verbose]) ) and all(v is None for v in [self.args.account, self.args.user]) ) ): self.args.user = getpass.getuser() AppArgs.csv = self.args.csv AppArgs.response = { "user": self.args.user, "account": self.args.account, "all": self.args.all, "start": self.args.start, "end": self.args.end, "carbon": self.args.carbon, "billing": self.args.billing, } + + def add_args_cost(self): + self.parser.add_argument( + "-N", "--nodes", help="number of (min) nodes on which to run", default=1 + ) + self.parser.add_argument( + "-n", "--nbtasks", help="number of tasks to run", default=1 + ) + self.parser.add_argument( + "-t", "--time", help="time limit in minutes", default=1 + ) + self.parser.add_argument( + "-p", "--partition", help="partition", default="parallel" + ) + self.parser.add_argument( + "-g", "--gres", help="required generic resources per node", default="gpu:0" + ) + self.parser.add_argument( + "-a", "--array", help="job array index values") + self.parser.add_argument( + "--ntasks-per-node", + help="number of tasks to invoke on each node", + ) + self.parser.add_argument( + "-c", "--cpus-per-task", help="number of cpus required per task", default=1 + ) + + self.args = self.parser.parse_args() + print(self.args) + AppArgs.response = self.args + return diff --git a/src/sausage/readconf.py b/src/sausage/readconf.py index 6cb70ff..a2eb048 100644 --- a/src/sausage/readconf.py +++ b/src/sausage/readconf.py @@ -1,25 +1,27 @@ ''' Configuration file module ''' import configparser class ReadConf(object): def __init__(self): self.options = {} self.cfg_parser = configparser.ConfigParser() self.cfg_parser.read("/etc/sausage/sausage.cfg") ReadConf.cluster = self.cfg_parser["default"].get("cluster") ReadConf.hosts = self.cfg_parser["server"].get("urls").split(",") ReadConf.clusters = [ i.replace(" ", "") for i in self.cfg_parser["default"].get("clusters").split(",") ] ReadConf.indexes = self.cfg_parser["server"].get("index") ReadConf.fields = dict(self.cfg_parser["fields"]) ReadConf.debug = self.cfg_parser["default"].getboolean("log_devel") ReadConf.apikey = tuple( map(str, self.cfg_parser["server"].get("apikey").split(",")) ) if len(ReadConf.apikey) != 2: ReadConf.apikey = tuple((None, None)) + ReadConf.billinggrp = self.cfg_parser["default"].get( + "billing_group", "root")