GitHub Repository: elebumm/RedditVideoMakerBot
Path: blob/master/utils/ai_methods.py
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
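
# A hypothetical shape walk-through (not from the original source): for a batch
# of 2 sentences padded to 5 tokens, all-MiniLM-L6-v2 yields token embeddings of
# shape (2, 5, 384) and an attention mask of shape (2, 5); mean_pooling then
# returns one (2, 384) sentence embedding per batch, averaging only over real
# (non-pad) tokens, with the clamp guarding against division by zero.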


# This function sorts the given threads based on their total similarity with the given keywords
def sort_by_similarity(thread_objects, keywords):
    # Initialize tokenizer + model.
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    # Transform the generator to a list of Submission objects, so we can sort
    # later based on context similarity to the keywords
    thread_objects = list(thread_objects)

    # Build one sentence per thread from its title and body text
    threads_sentences = []
    for thread in thread_objects:
        threads_sentences.append(" ".join([thread.title, thread.selftext]))

    # Threads inference
    encoded_threads = tokenizer(
        threads_sentences, padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        threads_embeddings = model(**encoded_threads)
    threads_embeddings = mean_pooling(threads_embeddings, encoded_threads["attention_mask"])

    # Keywords inference
    encoded_keywords = tokenizer(keywords, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        keywords_embeddings = model(**encoded_keywords)
    keywords_embeddings = mean_pooling(keywords_embeddings, encoded_keywords["attention_mask"])
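
    # Note: `keywords` is expected to be a list of strings; the tokenizer batches
    # them together, so keywords_embeddings holds one embedding row per keyword.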

    # Compare every keyword w/ every thread embedding
    # (threads_embeddings is already a tensor; re-wrapping it with torch.tensor()
    # triggers a PyTorch UserWarning, so it is used directly)
    threads_embeddings_tensor = threads_embeddings
    total_scores = torch.zeros(threads_embeddings_tensor.shape[0])
    cosine_similarity = torch.nn.CosineSimilarity()
    for keyword_embedding in keywords_embeddings:
        keyword_embedding = keyword_embedding.repeat(
            threads_embeddings_tensor.shape[0], 1
        )
        similarity = cosine_similarity(keyword_embedding, threads_embeddings_tensor)
        total_scores += similarity
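
    # Each loop pass adds one cosine similarity per thread, so total_scores ends
    # up in [-K, K] for K keywords: higher means closer to the keyword set.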

    # Rank threads from most to least similar to the keyword set
    similarity_scores, indices = torch.sort(total_scores, descending=True)

    threads_sentences = np.array(threads_sentences)[indices.numpy()]

    thread_objects = np.array(thread_objects)[indices.numpy()].tolist()

    # print('Similarity Thread Ranking')
    # for i, thread in enumerate(thread_objects):
    #     print(f'{i}) {threads_sentences[i]} score {similarity_scores[i]}')

    return thread_objects, similarity_scores
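

# A minimal usage sketch (not part of the original module). It assumes a
# configured praw.Reddit client; the credentials, subreddit, and keywords below
# are placeholders. Any iterable of objects with `title` and `selftext`
# attributes would work in place of PRAW submissions.
if __name__ == "__main__":
    import praw

    reddit = praw.Reddit(
        client_id="YOUR_CLIENT_ID",
        client_secret="YOUR_CLIENT_SECRET",
        user_agent="RedditVideoMakerBot similarity demo",
    )
    hot_threads = reddit.subreddit("AskReddit").hot(limit=25)
    ranked_threads, scores = sort_by_similarity(hot_threads, ["funny", "wholesome"])
    for thread, score in zip(ranked_threads[:5], scores[:5]):
        print(f"{score:.3f}  {thread.title}")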