Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
aamini
GitHub Repository: aamini/introtodeeplearning
Path: blob/master/mitdeeplearning/lab3.py
547 views
1
import os
2
3
from openai import OpenAI
4
from datasets import load_dataset
5
from torch.utils.data import DataLoader
6
7
8
cwd = os.path.dirname(__file__)
9
10
def create_dataloader(style):
    """Build the train and test dataloaders for a given text style.

    The style file ``data/text_styles/{style}.txt`` (resolved relative to
    this module) holds one styled response per line, with literal
    backslash-n sequences standing in for real newlines. The first ``n``
    rows of the Dolly-15k ``train`` split are paired with those ``n``
    responses; the next ``n`` rows are returned unmodified as the test set.

    Args:
        style: Base name (without the ``.txt`` suffix) of the style file.

    Returns:
        A ``(dataloader, dataloader_test)`` tuple. Both are shuffled
        ``DataLoader`` objects with ``batch_size=1``. Rows of the first
        carry an additional ``response_style`` field; rows of the second
        do not.
    """
    ds = load_dataset("databricks/databricks-dolly-15k", split="train")

    style_path = os.path.join(cwd, f"data/text_styles/{style}.txt")
    # Decode explicitly instead of relying on the platform's locale default,
    # which is not UTF-8 everywhere (e.g. Windows).
    # NOTE(review): assumes the bundled style files are UTF-8 - confirm.
    with open(style_path, "r", encoding="utf-8") as f:
        # Each line is one response; un-escape the literal "\n" markers.
        new_responses = [line.strip().replace("\\n", "\n") for line in f]

    n = len(new_responses)

    # Attach the styled responses to the first n rows, matched by row index.
    ds_ = ds.select(range(n))
    ds_ = ds_.map(
        lambda x, idx: {"response_style": new_responses[idx]},
        with_indices=True,
        num_proc=1,
    )

    # The following n rows form the held-out test set.
    ds_test = ds.select(range(n, n + n))

    # Create the dataloaders
    dataloader = DataLoader(ds_, batch_size=1, shuffle=True)
    dataloader_test = DataLoader(ds_test, batch_size=1, shuffle=True)
    return dataloader, dataloader_test
30
31
32
33
class LLMClient:
    """Small chat client built on the ``OpenAI`` SDK client class."""

    def __init__(self, model: str, api_key: str, api_base: str = "https://openrouter.ai/api/v1"):
        """Create the underlying ``OpenAI`` client and remember the model.

        Args:
            model: Model identifier sent with every request.
            api_key: API key passed to the ``OpenAI`` client.
            api_base: Base URL passed to the ``OpenAI`` client; defaults to
                the OpenRouter endpoint.
        """
        self.llm_client = OpenAI(api_key=api_key, base_url=api_base)
        self.model = model

    def ask(self, user: str, system: str = None, **kwargs):
        """Send a single-turn chat request and return the raw response.

        Args:
            user: Content of the user message.
            system: Optional system prompt. When truthy it is placed before
                the user message; ``None`` or an empty string is skipped.
            **kwargs: Forwarded unchanged to ``chat.completions.create``.

        Returns:
            Whatever ``chat.completions.create`` returns, unmodified.
        """
        user_turn = {"role": "user", "content": user}
        if system:
            conversation = [{"role": "system", "content": system}, user_turn]
        else:
            conversation = [user_turn]

        return self.llm_client.chat.completions.create(
            model=self.model,
            messages=conversation,
            **kwargs
        )
48
49
50
# Sample passage written in Yoda-style phrasing. The sentence groups are
# joined with single spaces into one line of text.
yoda_test_text = " ".join(
    (
        "Wisdom, sought by many, found by few, it is. Haste not, patience have.",
        "For in stillness, answers come. Much to learn, still you have.",
        "Fear leads to anger; anger, to hate. Down the dark path, guide you it will.",
        "Trust the Force, you must. Powerful ally it is. Life it creates, surrounds, binds.",
        "Adventure, excitement, a Jedi craves not these things. Discipline, balance, seek you should.",
        "Hmm, clearer now is the path, yes? Help you more, I can, if needed it is.",
        "Endless, the journey of learning is. Stay true to your path, and clarity you will find.",
        "Remember, the Force flows through all, but your heart determines how it shapes your destiny.",
        "Much more to teach, I have. Ready, are you? Mmm.",
    )
)
61
62
63
64
# class Llama(LLMClient):
65
# def __init__(self, api_key: str):
66
# """
67
# Initialize the LlamaFree model client.
68
69
# LlamaFree is available from LlamaFree.
70
# Provide your LlamaFree API key (`api_key`) to access.
71
# """
72
# # super().__init__(model="meta-llama/llama-3.2-3b-instruct", api_key=api_key)
73
# super().__init__(model="meta-llama/llama-3.1-8b-instruct", api_key=api_key)
74
75
76
# class LFM40B(LLMClient):
77
# def __init__(self, api_key: str):
78
# """
79
# Initialize the LFM-40B model client.
80
81
# LFM-40B is available from Lambda Labs.
82
# Provide your Lambda Labs API key (`api_key`) to access.
83
# """
84
# api_base = "https://api.lambdalabs.com/v1"
85
# super().__init__(model="lfm-40b", api_base=api_base, api_key=api_key)
86
87