"""
Created in 2025
@author: Simon Delarue <simon.delarue@telecom-paris.fr>
"""
from collections import defaultdict
import re
from tqdm import tqdm
import torch
[docs]class ContextualEmbedding():
"""Contextual embedding.
Parameters
----------
transformer
Class objet for transformer model.
model_name
Model name to upload.
tokenizer
Class object for tokenizer.
sentence_tokenizer
Class object for sentence tokenizer.
"""
def __init__(self, transformer, model_name, tokenizer, sentence_tokenizer):
self.model_name = model_name
self.transformer = transformer.from_pretrained(self.model_name)
self.tokenizer = tokenizer.from_pretrained(self.model_name)
self.sentence_tokenizer = sentence_tokenizer
self.transformer.eval()
[docs] def clean_text(self, text):
"""
Text cleaning for contextual embedding use.
"""
text = re.sub(r'(RR\\\d+EN\.docx|PE\d+\.\d+v\d+-\d+)', '', text)
text = re.sub(r'?\d4/\d+?[A−Z]∗?\d4/\d+\(?[A−Z]∗??[A−Z]∗?\d{4}/\d+\(?[A-Z]*??', '', text)
text = text.replace('\x0c', ' ')
text = re.sub(r'\n+', '\n', text)
text = re.sub(r'\s{2,}', ' ', text)
text = re.sub(r'\d+/\d+', '', text)
text = re.sub(r'https?://\S+|www\.\S+', '', text)
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'\.{4,}', '', text)
text = re.sub(r'-', '', text)
text = text.strip()
return text
[docs] def tokenize_and_chunk(self, text, tokenizer, max_length=512, stride=256):
"""Tokenize a text.
When number of tokens is greater than max_length, result is split into chunks.
Parameters
----------
text: str
Original text.
tokenizer:
Tokenizer.
max_length: int
Maximal number of tokens.
stride: int
Size of stride between token chunks. Useful in order not to lose contextual information when splitting tokens.
Returns
-------
list
List of :math:`n` dictionaries of tokens, with :math:`n` the number of chunks.
Each dictionary contains:
- input ids: tensor of ids of tokens in model vocabulary
- attention_mask: tensor of values indicating if a token is real (1) or comes from padding (0)
- offset_mapping: tensor indicating the start and end positions of each token within the original text.
"""
tokens = tokenizer(
text,
return_offsets_mapping=True,
return_attention_mask=True,
return_tensors="pt",
truncation=False,
add_special_tokens=True
)
input_ids = tokens["input_ids"][0]
attention_mask = tokens["attention_mask"][0]
offset_mapping = tokens["offset_mapping"][0]
chunks = []
for i in range(0, len(input_ids), stride):
input_chunk = input_ids[i:i + max_length]
attention_chunk = attention_mask[i:i + max_length]
offset_chunk = offset_mapping[i:i + max_length]
if input_chunk.size(0) < max_length:
pad_len = max_length - input_chunk.size(0)
input_chunk = torch.cat([input_chunk, torch.full((pad_len,), tokenizer.pad_token_id)])
attention_chunk = torch.cat([attention_chunk, torch.zeros(pad_len, dtype=torch.long)])
offset_chunk = torch.cat([offset_chunk, torch.tensor([[0, 0]] * pad_len)])
chunks.append({
"input_ids": input_chunk.unsqueeze(0),
"attention_mask": attention_chunk.unsqueeze(0),
"offset_mapping": offset_chunk
})
if i + max_length >= len(input_ids):
break
return chunks
[docs] def group_subword_embeddings(self, offset_mapping, embeddings, text, doc_id) -> list:
"""Average embeddings for splitted words.
Parameters
----------
offset_mapping: tensor
Indicates the start and end positions of each token within the original text.
embeddings
Matrix of word embeddings
text: str
Original textual context.
doc_id: int
Id of the document in which word embeddings were computed.
Returns
-------
list
List of dictionaries containing:
- doc id: id of the original document
- word: word
- position: position of word within original context
- embedding: tensor representing the embedding of the word
- context: textual context used for word embedding
"""
cached_start = None
cached_end = None
word_embeddings = []
subword_embeddings = []
for (start, end), embedding in zip(offset_mapping, embeddings):
if start == end:
continue
if cached_start is None:
cached_start = start.item()
cached_end = end.item()
subword_embeddings = [embedding]
elif start == cached_end:
cached_end = end.item()
subword_embeddings.append(embedding)
else:
word_text = text[cached_start:cached_end]
avg_word_embedding = torch.stack(subword_embeddings).mean(dim=0)
word_embeddings.append({
"doc_id": doc_id,
"word": word_text,
"position": (cached_start, cached_end),
"embedding": avg_word_embedding,
'context': text
})
cached_start = start.item()
cached_end = end.item()
subword_embeddings = [embedding]
# Save last word embedding
if subword_embeddings:
word_text = text[cached_start:cached_end]
avg_word_embedding = torch.stack(subword_embeddings).mean(dim=0)
word_embeddings.append({
"doc_id": doc_id,
"word": word_text,
"position": (cached_start, cached_end),
"embedding": avg_word_embedding,
'context': text
})
return word_embeddings
[docs] def clean_word(self, word):
word = re.sub(r'[^\w\s]', '', word)
word = word.lower()
return word
[docs] def encode_chunk(self, model, chunk, text, doc_id):
"""Compute word embeddings for a chunk of text.
Parameters
----------
model
Embedding model.
chunk
Chunk of tokenized text.
text
Original textual context.
doc_id
Id of the document.
Returns
-------
list
List of dictionaries containing:
- doc id: id of the original document
- word: word
- position: position of word within original context
- embedding: tensor representing the embedding of the word
- context: textual context used for word embedding
"""
with torch.no_grad():
outputs = model(
input_ids=chunk["input_ids"],
attention_mask=chunk["attention_mask"]
)
embeddings = outputs.last_hidden_state.squeeze(0)
offset_mapping = chunk["offset_mapping"]
# Merge subword embeddings
grouped_embeddings = self.group_subword_embeddings(offset_mapping, embeddings, text, doc_id)
return grouped_embeddings
[docs] def transform(self, corpus, keywords) -> list:
"""Performing contextual embedding of keywods within a corpus of documents.
Parameters
----------
corpus: list
List of documents.
keywords: list
List of keywords to search and embed.
Returns
-------
list
A list of :math:`n` embedding dictionaries, with :math:`n` the number of documents in corpus.
"""
all_embeddings = []
for doc_id, document in tqdm(enumerate(corpus), total=len(corpus), desc='corpus'):
embeddings = self.process_doc(document, doc_id, keywords)
all_embeddings.append(embeddings)
return all_embeddings
[docs] def process_doc(self, document, doc_id, keywords) -> dict:
"""Processing a document consists in:
- cleaning document
- splitting in chunks (if necessary)
- extracting context for each keyword
- embedding keyword according to each specific context
Parameters
----------
document: str
Textual document.
keywords: list of str
List of keywords to search and embed.
doc_id: int
Document id.
keywords: list
List of keywords to search and embed.
Returns
-------
dict
Dictionary of embeddings with the following structure:
keyword_0 (STR):
context_0 (STR):
word_0 (STR): embedding (tensor)
word_1 (STR): embedding (tensor)
...
context_1 (STR):
word_0 (STR): embedding (tensor)
word_1 (STR): embedding (tensor)
...
...
keyword_1 (STR):
context_0 (STR):
word_0 (STR): embedding (tensor)
word_1 (STR): embedding (tensor)
...
context_1 (STR):
word_0 (STR): embedding (tensor)
word_1 (STR): embedding (tensor)
...
...
...
"""
cleaned_text = self.clean_text(document)
contexts = self.extract_keyword_sentences(cleaned_text, keywords, self.sentence_tokenizer)
all_embeddings = {}
for keyword, context_list in contexts.items():
all_embeddings[keyword] = {}
for i, context in enumerate(context_list):
all_embeddings[keyword][context] = defaultdict(list)
chunks = self.tokenize_and_chunk(context, self.tokenizer)
for chunk in chunks:
word_embeddings = self.encode_chunk(self.transformer, chunk, context, doc_id)
# Case-insensitive
for word_embedding in word_embeddings:
all_embeddings[keyword][context][self.clean_word(word_embedding.get('word'))].append(word_embedding.get('embedding'))
# Average word embeddings for all chunks within a context
for k, v in all_embeddings[keyword][context].items():
all_embeddings[keyword][context][k] = torch.stack(v).mean(axis=0)
# Average word embeddings for multi-word keywords
subwords = keyword.split(' ')
if len(subwords) > 1:
subword_embeddings = []
for subword in subwords:
# Punctuation can occur in embedded subwords; this allows to take into account punctuated subwords
for embedded_word in all_embeddings[keyword][context].keys():
if subword in embedded_word:
if not isinstance(all_embeddings[keyword][context][embedded_word], list):
subword_embeddings.append(all_embeddings[keyword][context][embedded_word])
avg_subword_embedding = torch.stack(subword_embeddings, dim=0).mean(axis=0)
all_embeddings[keyword][context][keyword] = avg_subword_embedding
return all_embeddings