GitHub Repository: probml/pyprobml
Path: blob/master/scripts/tinker_train_gem.py
# Copyright 2025 AxonRL Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A basic RL implementation to train agents on GEM environments using Tinker backends.

Install the following libraries:
- https://github.com/thinking-machines-lab/tinker (needs an API key)
- https://github.com/axon-rl/gem

Then run `python tinker_train_gem.py max_steps=20`.

Note: we do not use tinker-cookbook.
"""

import asyncio
import json
import logging
import os
import pprint
import time
from datetime import datetime
from typing import Any, Literal

import chz
import numpy as np
import tinker
import torch
import wandb
from termcolor import colored
from tinker import types
from tinker.types.tensor_data import TensorData
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

import gem
from gem.wrappers.wrapper_factory import get_wrapper_fns

logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARNING)


@chz.chz
class Config:
    # https://tinker-docs.thinkingmachines.ai/model-lineup
    model_name: str = "Qwen/Qwen3-4B-Instruct-2507"  # "Qwen/Qwen3-8B-Base"
    batch_size: int = 128
    learning_rate: float = 4e-5
    lora_rank: int = 8  # 32
    max_tokens: int = 1024  # 2048
    seed: int = 0
    max_steps: int = 20
    save_every: int = 5  # -1

    env_id: str = "game:GuessTheNumber-v0-easy"  # GEM environment ID
    num_env: int = 32  # 4  # number of parallel environments
    env_wrappers: str = (
        "concat"  # wrappers are typically used to concat chat history, etc.
    )
    template: Literal["qwen3_general", "qwen3_game", "no"] = "qwen3_game"

    gamma: float = 0.9
    use_rebn: bool = True

    wandb_project: str | None = None
    wandb_name: str | None = None
    log_dir: str | None = None
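
# Illustrative usage (an assumption, extrapolating from the `max_steps=20` override shown in
# the module docstring: chz's key=value CLI syntax should apply to any Config field):
#   python tinker_train_gem.py num_env=16 max_steps=10 wandb_project=my_project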


# Define a lightweight renderer following tinker's renderer logic.
def apply_qwen3_game_template(observation: str) -> str:
    return (
        f"<|im_start|>user\nYou are playing language games. Make valid actions to win.\nObservation: {observation}"
        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def apply_qwen3_game_no_think_template(observation: str) -> str:
    return (
        f"<|im_start|>user\nYou are playing language games. Make valid actions to win.\nObservation: {observation}"
        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def apply_qwen3_general_template(question: str) -> str:
    return (
        f"<|im_start|>user\nQuestion: {question}"
        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def apply_no_template(observation: str) -> str:
    return observation


TEMPLATE_FACTORY = {
    "qwen3_game": apply_qwen3_game_template,
    "qwen3_general": apply_qwen3_general_template,
    "no": apply_no_template,
}
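
# Illustrative example (derived from the templates above, with a made-up observation):
# apply_qwen3_game_template("Guess a number between 1 and 16.") renders
#   <|im_start|>user
#   You are playing language games. Make valid actions to win.
#   Observation: Guess a number between 1 and 16.
#   Please reason step by step, and put your final answer within \boxed{}.<|im_end|>
#   <|im_start|>assistant
# and the policy's sampled continuation is then treated as the action.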


def get_tokenizer(model_name: str) -> PreTrainedTokenizer:
    # Avoid gating of Llama 3 models:
    if model_name.startswith("meta-llama/Llama-3"):
        model_name = "baseten/Meta-Llama-3-tokenizer"
    return AutoTokenizer.from_pretrained(model_name, use_fast=True)


async def save_checkpoint_async(
    training_client: tinker.TrainingClient,
    name: str,
    log_path: str,
    loop_state: dict[str, Any],
    kind: Literal["state", "sampler", "both"] = "state",
) -> dict[str, str]:
    """Save a model checkpoint.

    Args:
        training_client: Training client to save from.
        name: Name for the checkpoint.
        log_path: Path to the log directory, where the checkpoints.jsonl file is appended.
        loop_state: Training-loop state (e.g., the current policy iteration step) to record.
        kind: Which checkpoint(s) to save: training state, sampler weights, or both.

    Returns:
        Dict mapping "<kind>_path" to the path of each saved checkpoint.
    """
    futures = {}
    if kind in ["state", "both"]:
        futures["state"] = await training_client.save_state_async(name)
    if kind in ["sampler", "both"]:
        futures["sampler"] = await training_client.save_weights_for_sampler_async(name)

    results = {k: await v.result_async() for k, v in futures.items()}
    paths = {k + "_path": v.path for k, v in results.items()}
    logger.info(f"Saved checkpoints: {paths}")
    full_dict = {"name": name, **loop_state, **paths}
    with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
        f.write(json.dumps(full_dict) + "\n")

    return paths
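
# Illustrative example of a line appended to checkpoints.jsonl (the actual path string is
# whatever the Tinker service returns, shown here as a placeholder):
#   {"name": "000005", "policy_iteration_step": 5, "state_path": "<path returned by tinker>"}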


def augment_transitions_with_advantages(episodes_buffer, config):
    transitions = []
    for episode in episodes_buffer:
        # Augment each (s, a, r) transition with an MC estimate of the return to go.
        rewards = [transition["reward"] for transition in episode]
        cur = 0.0
        for i in reversed(range(len(rewards))):
            cur = rewards[i] + config.gamma * cur
            episode[i]["return"] = cur
        transitions.extend(episode)

    # Return batch normalization (ReBN): standardize returns across the batch.
    if config.use_rebn:
        returns = torch.tensor([transition["return"] for transition in transitions]).float()
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        for i, transition in enumerate(transitions):
            transition["return"] = returns[i].item()

    # Subsample to make a constant batch size.
    if len(transitions) > config.batch_size:
        transitions = np.random.choice(transitions, config.batch_size, replace=False)
    return transitions
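
# Worked example (illustrative): with gamma = 0.9 and per-step rewards [0, 0, 1], the
# discounted returns-to-go computed above are [0.81, 0.9, 1.0]; with use_rebn enabled these
# values are then standardized jointly with the rest of the batch.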


def make_training_data(transitions):
    # prepare training datums compatible with Tinker API
    training_datums = []
    for transition in transitions:
        ob_len_m1 = len(transition["obs_tokens"]) - 1  # -1 due to shifting
        tokens = transition["obs_tokens"] + transition["act_tokens"]

        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        all_logprobs = [0.0] * ob_len_m1 + transition["act_logprobs"]
        all_advantages = [0.0] * ob_len_m1 + [transition["return"]] * (len(input_tokens) - ob_len_m1)

        datum = types.Datum(
            model_input=types.ModelInput.from_ints(tokens=input_tokens),
            loss_fn_inputs={
                "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
                "logprobs": TensorData.from_torch(torch.tensor(all_logprobs)),
                "advantages": TensorData.from_torch(torch.tensor(all_advantages)),
            },
        )
        training_datums.append(datum)
    return training_datums
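
# Shape example (illustrative): with obs_tokens = [o1, o2, o3] and act_tokens = [a1, a2],
#   tokens         = [o1, o2, o3, a1, a2]
#   input_tokens   = [o1, o2, o3, a1]            (drop the last token)
#   target_tokens  = [o2, o3, a1, a2]            (shift left by one)
#   all_logprobs   = [0.0, 0.0, lp(a1), lp(a2)]  (ob_len_m1 = 2 observation positions are zeroed)
#   all_advantages = [0.0, 0.0, R, R]            (R is the transition's return-to-go)
# so only the action-token positions carry nonzero advantage.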


async def collect_episode_async(sampling_client, sampling_params, env, tokenizer, config):
    transitions = []
    obs, _ = env.reset()
    while True:
        # 1) prepare the observation
        obs = TEMPLATE_FACTORY[config.template](obs)  # wrap the observation in the chat template
        obs_tokens = tokenizer.encode(obs, add_special_tokens=False)

        # 2) sample an action from the policy
        try:
            sample_result = await sampling_client.sample_async(
                prompt=types.ModelInput.from_ints(tokens=obs_tokens),
                num_samples=1,
                sampling_params=sampling_params,
            )
        except Exception:
            # Discard the whole episode if the sampling request fails.
            transitions = []
            break
        sampled_tokens = sample_result.sequences[0].tokens
        sampled_logprobs = sample_result.sequences[0].logprobs
        action = tokenizer.decode(sampled_tokens)

        # 3) step the environment
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        obs = next_obs

        # 4) save into the buffer
        transitions.append(
            {"obs_tokens": obs_tokens, "act_tokens": sampled_tokens,
             "act_logprobs": sampled_logprobs, "reward": reward, "done": done})

        if done:
            break
    return transitions
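
# Each transition dict produced above has:
#   obs_tokens   - templated observation tokens fed to the sampler
#   act_tokens   - tokens sampled from the policy for this turn
#   act_logprobs - sampler log-probabilities of act_tokens
#   reward       - scalar reward returned by env.step for this turn
#   done         - whether the episode terminated or was truncated after this turn
# An episode is the list of such transitions from one env.reset() until done.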


async def collect_episodes_buffer_async(sampling_client, sampling_params, envs, tokenizer, config):
    episodes_buffer = []
    while True:
        batch_episodes = await asyncio.gather(
            *[collect_episode_async(sampling_client, sampling_params, env, tokenizer, config) for env in envs])
        batch_episodes = [x for x in batch_episodes if x != []]
        episodes_buffer.extend(batch_episodes)
        if sum([len(ep) for ep in episodes_buffer]) >= config.batch_size:
            break
    return episodes_buffer


def debug_print_episodes(episodes_buffer, tokenizer):
    # Print at most two episodes for debugging purposes.
    for n, episode in enumerate(episodes_buffer):
        print(f"----- episode {n} -----")
        for t, transition in enumerate(episode):
            obs = tokenizer.decode(transition["obs_tokens"])
            act = tokenizer.decode(transition["act_tokens"])
            # obs = obs[:196] + "\n...\n" + obs[-200:] if len(obs) > 396 else obs
            # act = act[:196] + "\n...\n" + act[-200:] if len(act) > 396 else act
            print(f"turn={t+1}")
            print(colored(obs, "blue"))
            print(colored(act, "light_red", attrs=["bold"]))
            print(
                colored(
                    "reward=" + str(transition["reward"]),
                    "light_magenta",
                    attrs=["bold"],
                )
            )
        if n > 0:
            break


def make_envs(config: Config, tokenizer: PreTrainedTokenizer):
    wrappers = get_wrapper_fns(config.env_wrappers, tokenizer=tokenizer)
    # Init one env first and check whether it has a dataset; if so, we avoid loading from HF
    # multiple times by passing that dataset directly when creating the remaining envs
    # (we could also use the gem.Env.spawn API).
    envs = [gem.make(config.env_id, seed=int(time.time_ns()), use_mp=False)]
    for i in range(config.num_env - 1):
        dataset = envs[0].dataset if hasattr(envs[0], "dataset") else None
        envs.append(
            gem.make(
                config.env_id,
                seed=int(time.time_ns()) * i,
                dataset=dataset,
                use_mp=False,
            )
        )
    for i in range(len(envs)):
        for wrapper in wrappers:
            envs[i] = wrapper(envs[i])
    return envs

def compute_policy_metrics(config, transitions, fwd_bwd_result):
    # compute policy entropy and sampler-learner difference
    act_token_logprobs = []
    act_token_diffs = []
    for i in range(config.batch_size):
        transition = transitions[i]
        train_output = fwd_bwd_result.loss_fn_outputs[i]
        nact = len(transition["act_logprobs"])
        act_token_logprobs.extend(transition["act_logprobs"])
        sampling_token_logprobs = torch.tensor(transition["act_logprobs"])
        policy_token_logprobs = train_output["logprobs"].to_torch()[-nact:]
        # k1 = E_{a~q_sample} [log q_sample(a) - log p_policy(a)]
        act_token_diffs.append(sampling_token_logprobs - policy_token_logprobs)

    act_token_diffs = torch.cat(act_token_diffs)
    kl_sample_train_v1 = act_token_diffs.mean().item()
    kl_sample_train_v2 = 0.5 * (act_token_diffs**2).mean().item()
    return {
        "token_entropy": -torch.tensor(act_token_logprobs).mean().item(),
        "kl_sample_train_v1": kl_sample_train_v1,
        "kl_sample_train_v2": kl_sample_train_v2,
    }
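
# Background note: with q the sampling policy and p the current learner policy,
# kl_sample_train_v1 is the plain k1 estimator E_q[log q - log p] of KL(q || p), and
# kl_sample_train_v2 is the lower-variance k2 estimator 0.5 * E_q[(log q - log p)^2].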


async def main(config: Config):
    # Setup logging
    wandb_name = (
        config.wandb_name or config.model_name.split("/")[-1] + f"_{config.env_id}"
    )
    wandb_name += "_" + datetime.now().strftime("%m%dT%H:%M:%S")
    save_path = os.path.join("./tinker_output", wandb_name)
    os.makedirs(save_path, exist_ok=True)

    wandb.init(
        project=config.wandb_project,
        config=chz.asdict(config),
        dir=str(config.log_dir) if config.log_dir else None,
        name=wandb_name,
    )

    tokenizer = get_tokenizer(config.model_name)
    envs = make_envs(config, tokenizer)

    service_client = tinker.ServiceClient()
    training_client = await service_client.create_lora_training_client_async(
        base_model=config.model_name, rank=config.lora_rank
    )
    sampling_params = tinker.types.SamplingParams(
        max_tokens=config.max_tokens,
    )
    adam_params = types.AdamParams(
        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8
    )

    # Start agent-environment loop (Algo: https://arxiv.org/pdf/2510.01051#page=15.10):
    for policy_iteration_step in range(config.max_steps):
        print("=" * 10 + f" Step {policy_iteration_step} " + "=" * 10)
        metrics = {"step": policy_iteration_step}

        # save model
        if (
            config.save_every > 0
            and policy_iteration_step > 0
            and policy_iteration_step % config.save_every == 0
        ):
            await save_checkpoint_async(
                training_client,
                f"{policy_iteration_step:06d}",
                log_path=save_path,
                kind="state",
                loop_state={"policy_iteration_step": policy_iteration_step},
            )

        sampling_path = (
            training_client.save_weights_for_sampler(
                name=f"{policy_iteration_step:06d}"
            )
            .result()
            .path
        )
        sampling_client = service_client.create_sampling_client(
            model_path=sampling_path
        )

        # collect episodes with parallel environments
        print(f"🎲 Start collecting episodes at step {policy_iteration_step}")
        st = time.time()

        episodes_buffer = await collect_episodes_buffer_async(
            sampling_client, sampling_params, envs, tokenizer, config)
        debug_print_episodes(episodes_buffer, tokenizer)
        transitions = augment_transitions_with_advantages(episodes_buffer, config)
        training_datums = make_training_data(transitions)

        metrics["time/sample"] = time.time() - st
        metrics["sampler/episode_return"] = np.mean(
            [sum(transition["reward"] for transition in episode) for episode in episodes_buffer])
        metrics["sampler/num_turns_per_episode"] = np.mean(
            [len(episode) for episode in episodes_buffer])
        metrics["sampler/action_num_tokens"] = np.mean(
            [sum(len(transition["act_tokens"]) for transition in episode) for episode in episodes_buffer])
        metrics["sampler/num_episodes"] = len(episodes_buffer)

        print(f"🎈 Start training at step {policy_iteration_step}")
        st = time.time()

        fwd_bwd_future = training_client.forward_backward(
            training_datums, loss_fn="importance_sampling"
        )
        optim_step_future = training_client.optim_step(adam_params)
        fwd_bwd_result = fwd_bwd_future.result()
        _ = optim_step_future.result()
        res = compute_policy_metrics(config, transitions, fwd_bwd_result)

        metrics["time/train"] = time.time() - st
        metrics["sampler/token_entropy"] = res["token_entropy"]
        metrics["train/kl_sample_train_v1"] = res["kl_sample_train_v1"]
        metrics["train/kl_sample_train_v2"] = res["kl_sample_train_v2"]
        metrics.update(**{f"train/{k}": v for k, v in fwd_bwd_result.metrics.items()})
        pprint.pprint(metrics)
        wandb.log(metrics)

    await save_checkpoint_async(training_client, f"{policy_iteration_step:06d}",
                                log_path=save_path, kind="state",
                                loop_state={"policy_iteration_step": policy_iteration_step})
    wandb.finish()


async def main_no_metrics(config: Config):
    # shorter version, for the book
    wandb_name = (config.wandb_name or config.model_name.split("/")[-1] + f"_{config.env_id}")
    wandb_name += "_" + datetime.now().strftime("%m%dT%H:%M:%S")
    save_path = os.path.join("./tinker_output", wandb_name)
    os.makedirs(save_path, exist_ok=True)

    tokenizer = get_tokenizer(config.model_name)
    envs = make_envs(config, tokenizer)
    service_client = tinker.ServiceClient()
    training_client = await service_client.create_lora_training_client_async(
        base_model=config.model_name, rank=config.lora_rank)
    sampling_params = tinker.types.SamplingParams(max_tokens=config.max_tokens)
    adam_params = types.AdamParams(
        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8)

    for policy_iteration_step in range(config.max_steps):
        sampling_path = (training_client.save_weights_for_sampler(
            name=f"{policy_iteration_step:06d}").result().path)
        sampling_client = service_client.create_sampling_client(model_path=sampling_path)

        episodes_buffer = await collect_episodes_buffer_async(
            sampling_client, sampling_params, envs, tokenizer, config)
        transitions = augment_transitions_with_advantages(episodes_buffer, config)
        training_datums = make_training_data(transitions)

        # fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn="ppo")
        fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn="importance_sampling")
        optim_step_future = training_client.optim_step(adam_params)
        fwd_bwd_result = fwd_bwd_future.result()
        _ = optim_step_future.result()

        await save_checkpoint_async(training_client, f"{policy_iteration_step:06d}",
                                    log_path=save_path, kind="state",
                                    loop_state={"policy_iteration_step": policy_iteration_step})


if __name__ == "__main__":
    asyncio.run(main(chz.entrypoint(Config)))
    # asyncio.run(main_no_metrics(chz.entrypoint(Config)))