templateNER
templateNER copied to clipboard
Hard coded numbers in template_entity function of inference.py
Hi,
would you mind explaining some hard-coded numbers in the template_entity function from inference.py?
def template_entity(words, input_TXT, start):
    """Score each candidate in ``words`` against every template; return the best.

    For every candidate string in ``words`` one target sentence is built per
    entry of ``template_list`` (candidate + template). Each target is scored
    as the product of the softmax probabilities the model assigns to its
    tokens, and the highest-scoring (candidate, template) pair wins.

    Relies on the module-level names ``tokenizer``, ``model``, ``device`` and
    ``torch``; the model is moved to ``device`` on every call.

    Args:
        words: sequence of candidate strings. ``words[k]`` is presumably the
            span that starts at ``start`` and covers ``k + 1`` tokens, since
            the returned end index is ``start + k`` -- verify against caller.
        input_TXT: the source sentence fed to the encoder.
        start: index of the first token of the candidates in the sentence.

    Returns:
        ``[start, end, label, score]`` where ``end = start + k`` for the
        winning candidate ``words[k]``, ``label`` is looked up in
        ``entity_dict`` and ``score`` is the winning probability product.

    Raises:
        ValueError: from ``max`` when ``words`` is empty (if the tokenizer
            has not already rejected the empty batch).
    """
    template_list = [" is a location entity .", " is a person entity .", " is an organization entity .",
                     " is an other entity .", " is not a named entity ."]
    entity_dict = {0: 'LOC', 1: 'PER', 2: 'ORG', 3: 'MISC', 4: 'O'}

    # Every "5" in the original code was the number of templates.
    n_templates = len(template_list)
    # Position of " is not a named entity ." in template_list; its scored
    # prefix is one token longer than the others (see below).
    not_entity_index = 4

    words_length = len(words)
    n_rows = n_templates * words_length

    # input text -> template
    # The encoder sees the same sentence once per (candidate, template) row.
    input_TXT = [input_TXT] * n_rows
    input_ids = tokenizer(input_TXT, return_tensors='pt')['input_ids']
    model.to(device)

    # Row layout: row = candidate_index * n_templates + template_index.
    temp_list = [word + template for word in words for template in template_list]

    output_ids = tokenizer(temp_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
    # Overwrite the first token id of every target with 2.
    # NOTE(review): presumably the decoder start token id (2 is the </s>/eos
    # id for BART tokenizers) -- confirm against the model config.
    output_ids[:, 0] = 2

    # Number of leading target tokens that contribute to each row's score.
    output_length_list = [0] * n_rows
    for word_idx in range(words_length):
        first_row = word_idx * n_templates
        tokenized = tokenizer(temp_list[first_row], return_tensors='pt', padding=True, truncation=True)['input_ids']
        # NOTE(review): assuming the tokenizer adds one start and one end
        # special token and splits " is a location entity ." into five
        # tokens, "- 4" leaves len(candidate tokens) + 3, i.e. scoring stops
        # at the label word and skips " entity ." -- TODO confirm.
        base_length = tokenized.shape[1] - 4
        output_length_list[first_row:first_row + n_templates] = [base_length] * n_templates
        # " is not a named" is presumably one token longer than " is a <label>".
        output_length_list[first_row + not_entity_index] += 1

    score = [1] * n_rows
    with torch.no_grad():
        # The decoder input drops the last two columns of the padded targets.
        decoder_input_ids = output_ids[:, :output_ids.shape[1] - 2]
        output = model(input_ids=input_ids.to(device), decoder_input_ids=decoder_input_ids.to(device))[0]
        # Decoder position i is scored against target token i + 1, so one
        # fewer position than the decoder input length is visited.
        for i in range(output_ids.shape[1] - 3):
            logits = output[:, i, :]
            logits = logits.softmax(dim=1)
            logits = logits.to('cpu').numpy()
            for j in range(n_rows):
                if i < output_length_list[j]:
                    score[j] = score[j] * logits[j][int(output_ids[j][i + 1])]

    best_score = max(score)
    best_row = score.index(best_score)
    end = start + (best_row // n_templates)
    return [start, end, entity_dict[best_row % n_templates], best_score]  # [start_index, end_index, label, score]
I learned from the open issues that the 5s are the length of the template_list, but what about the other numbers?
It would be a great help if you could respond to this. Thank you in advance!
Have you solved this problem?