gen_master.py

import json
import os
import random
import copy
import io
from abc import abstractmethod
from transformers import TFGPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
class GenData:
    def __init__(self, home_path):
        # expected layout under home_path: ./transformers/examples/language-modeling,
        # ./models and ./crawl_data
        self.home_path = home_path
        os.chdir(self.home_path)
        os.chdir("./transformers/examples/")
        os.chdir("./language-modeling")
        self.finetune_path = os.getcwd()
        self.nbEpochs = 5
        self.outModelName = "genFakeData5"
        os.chdir(self.home_path)
        os.chdir("./models")
        self.model_path = os.getcwd()
        os.chdir(self.home_path)
        os.chdir("./crawl_data")
        self.crawl_path = os.getcwd()
        self.data = {"first name": [], "last name": [], "city": []}
        self.fake_data = {"first name": [], "last name": [], "city": []}
        self.model = None
        self.tokenizer = None
        # keys, schema and abrev are set by the country-specific child classes
        self.keys = None
        self.keys_all = None
        self.schema = None
        self.abrev = None

    # Protected helper methods
    def _strnum(self, i, min_len=2):
        # zero-pad a number so it has at least min_len digits
        i = str(i)
        while len(i) < min_len:
            i = "0" + i
        return i

    @abstractmethod
    def _get_firstName(self):
        raise NotImplementedError("abstract method")

    @abstractmethod
    def _get_lastName(self):
        raise NotImplementedError("abstract method")

    @abstractmethod
    def _get_city(self):
        raise NotImplementedError("abstract method")

    @abstractmethod
    def _get_answer(self, key):
        raise NotImplementedError("abstract method")

    def _fill_schema(self, question, nb_keys="all"):
        # build a "label: value; " context over a subset of the keys and return
        # the value that answers the requested question key
        tmp = copy.deepcopy(self.keys)
        if nb_keys != "all":
            random.shuffle(tmp)
            tmp = tmp[:nb_keys]
        if question not in tmp:
            # replace a random entry so the question key is always present
            tmp[random.randint(0, len(tmp) - 1)] = question
        context = ""
        answer = ""
        for key in tmp:
            ans = self._get_answer(key)
            ans = ans[random.randint(0, len(ans) - 1)]
            context += self.schema[key] + ": " + str(ans) + "; "
            if key == question:
                answer = ans
        return context, question, answer

    def __doubleCapital(self, word):
        # True if the word contains two consecutive uppercase ASCII letters
        for i, l in enumerate(word[:-1]):
            if l.isupper() and word[i + 1].isupper() and l.lower() in "abcdefghijklmnopqrstuvwxyz" \
                    and word[i + 1].lower() in "abcdefghijklmnopqrstuvwxyz":
                return True
        return False

    def _clean_name(self, dirty, countries=[]):
        # collect all city names from the given countries so they can be filtered out
        city_names = []
        for c in countries:
            data, _ = c._get_data()
            city_names += data["city"]
        clean = []
        for d in dirty:
            # name contains spaces and is not "van ..."
            if " " in d and "van" not in d:
                pass
            # first letter is lowercase, contains "." or parentheses, or is at most two letters
            elif d[0].upper() != d[0] or "." in d or len(d) <= 2 or "(" in d or ")" in d:
                pass
            # contains digits
            elif any(char.isdigit() for char in d):
                pass
            # everything is in upper case
            elif all(char.isupper() for char in d):
                pass
            # specific to the Finnish crawl
            elif "A-Z" in d or "oy" == d.lower() or "studio" in d.lower():
                pass
            # leftover markup
            elif ">" in d or "<" in d or "/" in d:
                pass
            # two consecutive capital letters
            elif self.__doubleCapital(d):
                pass
            # the candidate name is actually a known city name
            elif d in city_names:
                pass
            else:
                clean.append(d)
        return clean

    def _clean_city(self, dirty, countries=[]):
        # collect all city names of the other countries
        other_city_names = []
        for c in countries:
            data, _ = c._get_data()
            other_city_names += data["city"]
        clean = []
        for d in dirty:
            # first letter is lowercase, contains parentheses, or is at most two letters
            if d[0].upper() != d[0] or len(d) <= 2 or "(" in d or ")" in d:
                pass
            # contains digits
            elif any(char.isdigit() for char in d):
                pass
            # everything is in upper case
            elif all(char.isupper() for char in d):
                pass
            # leftover markup
            elif ">" in d or "<" in d or "/" in d:
                pass
            # two consecutive capital letters
            elif self.__doubleCapital(d):
                pass
            # the city belongs to another country
            elif d in other_city_names:
                pass
            else:
                clean.append(d)
        return clean

    @abstractmethod
    def _gen_date_(self, d, m, y):
        raise NotImplementedError("abstract method")

    @abstractmethod
    def _gen_identityCard(self):
        raise NotImplementedError("abstract method")

    def _gen_date(self, from_y, to_y):
        # years should be given in the format 2012 (4 digits)
        dates = []
        for d in range(1, 32):
            for m in range(1, 13):
                for y in range(from_y + 1, to_y + 1):
                    dates.append(self._gen_date_(d, m, y))
        return dates

    def _gpt_train_entry(self, ident, nat, entries):
        # build one GPT-2 training line per entry: the prompt lists ten example
        # entries plus the label and nationality tags, followed by the entry itself
        keywords = ["<input>", "<answer>", "<find>", "<|endoftext|>"]
        out = []
        for entry in entries:
            out.append(keywords[0] + str(entries[:10]) + "<{}><{}>".format(ident, nat) + entry + keywords[3])
        return out
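
    # Example of a training line produced by _gpt_train_entry. The names and the
    # "ch" nationality tag are hypothetical; only the tag layout comes from the code above:
    # <input>['Anna', 'Ben', ...]<first name><ch>Anna<|endoftext|>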

    def _gpt_save(self, entries):
        # concatenate all training lines and write them to the fine-tuning input file
        long_str = ""
        for x in entries:
            try:
                long_str += x + "\n"
            except TypeError:
                # skip anything that is not a string
                pass
        with io.open("gen_data_train.txt", 'w', encoding="utf-8") as f:
            f.write(long_str)

    def _get_data(self):
        return self.data, self.fake_data

    def _set_data(self, data, fake_data):
        if data is not None:
            self.data = data
        if fake_data is not None:
            self.fake_data = fake_data

    def _get_country(self):
        return self.abrev

    def _fake_gen(self, ident, nat, entries):
        # prompt the fine-tuned GPT-2 with the same tag layout used during training
        # (but without the answer and <|endoftext|> parts) and sample a continuation
        keywords = ["<input>", "<answer>", "<find>", "<|endoftext|>"]
        prompt = keywords[0] + str(entries[:10]) + "<{}><{}>".format(ident, nat)
        input_ids = self.tokenizer.encode(prompt, return_tensors='tf')
        generated_text_samples = self.model.generate(
            input_ids,
            max_length=len(input_ids[0]) + 50,
            num_return_sequences=1,
            no_repeat_ngram_size=0,
            repetition_penalty=1.0,
            top_p=1.0,
            temperature=1.0,
            do_sample=True,
            top_k=0,
            early_stopping=True
        )
        answer = self.tokenizer.decode(generated_text_samples[0])
        answer = answer.replace(prompt, "")
        answer = answer.replace("<|endoftext|>", "")
        return answer

    # Public methods
    def rm_dub(self, dirty):
        # remove duplicates
        return list(set(dirty))

    def load_real_data(self):
        # read the crawled data and keep only the entries that survive the cleaning rules
        os.chdir(self.crawl_path)
        self.data["first name"] = self._clean_name(self.rm_dub(self._get_firstName()))
        self.data["last name"] = self._clean_name(self.rm_dub(self._get_lastName()))
        self.data["city"] = self._clean_city(self.rm_dub(self._get_city()))

    def train_fakeGen(self, countries):
        # countries is a list of child class instances, one per nationality
        train_data = []
        for ident in self.data.keys():
            for country in countries:
                data, _ = country._get_data()
                # take the crawled entries for this label and cap them at 5000
                entries = data[ident]
                random.shuffle(entries)
                if len(entries) > 5000:
                    entries = entries[:5000]
                else:
                    print("Warning, label <{}> in nationality <{}> only has {} entries".format(
                        ident, country._get_country(), len(entries)))
                train_data += self._gpt_train_entry(ident, country._get_country(), entries)
        random.shuffle(train_data)
        # write the training file where run_clm.py will look for it
        os.chdir(self.finetune_path)
        self._gpt_save(train_data)
        with open("real_id_data" + ".json", "w") as f:
            json.dump(data, f)
        cmd = "python run_clm.py \
            --model_type {} \
            --train_file \"{}\" \
            --do_train \
            --per_gpu_train_batch_size 1 \
            --save_steps -1 \
            --num_train_epochs {} \
            --fp16 \
            --tokenizer_name gpt2 \
            --model_name_or_path gpt2 \
            --output_dir=\"{}\"".format(
                "gpt2",
                "gen_data_train" + ".txt",
                self.nbEpochs,
                self.model_path + "/" + self.outModelName)
        print(cmd)
        os.system(cmd)

    def run_fakeGen(self, nb, countries):
        # generate fake entries with the fine-tuned model until each label has nb of them
        os.chdir(self.home_path)
        self.model = TFGPT2LMHeadModel.from_pretrained(self.model_path + "/" + self.outModelName, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path + "/" + self.outModelName)
        os.chdir(self.crawl_path)
        # try:
        f = 'fake_id_data_' + self.abrev + '.json'
        print(f)
        with open(f, 'r') as fp:
            fake_data = json.load(fp)
        # apply the new cleaning rules to the previously generated entries
        for key in fake_data.keys():
            if key in ["first name", "last name"]:
                fake_data[key] = self._clean_name(fake_data[key], countries)
            if key in ["city"]:
                fake_data[key] = self._clean_city(fake_data[key], [c for c in countries if c._get_country() != self.abrev])
        self.fake_data = fake_data
        print("{}:\n# first names: {}, # last names: {}, # cities: {}".format(self.abrev,
                                                                              len(fake_data["first name"]),
                                                                              len(fake_data["last name"]),
                                                                              len(fake_data["city"])))
        # except:
        #     print("couldn't find existing fake data entries")
        for ident in self.data.keys():
            print("{}, {}".format(ident, self.abrev))
            while len(self.fake_data[ident]) < nb:
                data, fake_data = self._get_data()
                tmp = self._fake_gen(ident, self.abrev, data[ident])
                # clean the generated entry; the clean functions return an empty list
                # when the entry is rejected, in which case [0] raises IndexError
                try:
                    if ident in ["first name", "last name"]:
                        tmp = self._clean_name([tmp], countries)[0]
                    if ident in ["city"]:
                        tmp = self._clean_city([tmp], [c for c in countries if c._get_country() != self.abrev])[0]
                    fake_data[ident].append(tmp)
                except IndexError:
                    pass
                fake_data[ident] = self.rm_dub(fake_data[ident])
                self.fake_data = fake_data
                new = [tmp for tmp in fake_data[ident] if tmp not in data[ident]]
                print("There have been {} new generations, and {} identical to the test set".format(
                    len(new), len(fake_data[ident]) - len(new)))
            data, fake_data = self._get_data()
            new = [tmp for tmp in fake_data[ident] if tmp not in data[ident]]
            print(new)
            print("There have been {} new generations, and {} identical to the test set".format(
                len(new), len(fake_data[ident]) - len(new)))
            print("{}, {} -- ended search".format(ident, self.abrev))
        os.chdir(self.crawl_path)
        with open("fake_id_data_" + self.abrev + ".json", "w") as f:
            json.dump(fake_data, f)

    def fill_schema(self, keys="all", nb_keys="all"):
        if keys == "all":
            keys = self.keys
        for k in keys:
            context, question, answer = self._fill_schema(k, nb_keys=nb_keys)
            print(context)
            print(question)
            print(answer)
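

# --------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal child class
# showing how the abstract interface above is meant to be filled in. The class
# name, the hard-coded placeholder values and the schema labels are assumptions
# for illustration only; real subclasses read their values from the crawled
# files under crawl_path.
# --------------------------------------------------------------------------
class GenDataExample(GenData):
    def __init__(self, home_path):
        super().__init__(home_path)
        self.abrev = "xx"  # hypothetical country code
        self.keys = ["first name", "last name", "city"]
        self.schema = {"first name": "First name",
                       "last name": "Last name",
                       "city": "City"}

    def _get_firstName(self):
        return ["Anna", "Ben"]  # placeholder crawl result

    def _get_lastName(self):
        return ["Muster", "Keller"]  # placeholder crawl result

    def _get_city(self):
        return ["Bern", "Basel"]  # placeholder crawl result

    def _get_answer(self, key):
        return self.data[key] + self.fake_data[key]

    def _gen_date_(self, d, m, y):
        return "{}.{}.{}".format(self._strnum(d), self._strnum(m), y)

    def _gen_identityCard(self):
        return "X" + self._strnum(random.randint(0, 9999999), min_len=7)


if __name__ == "__main__":
    # hypothetical usage; requires the directory layout expected by GenData.__init__
    gen = GenDataExample("/path/to/project")
    gen.load_real_data()
    gen.fill_schema(nb_keys=2)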
