Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
sagemath
GitHub Repository: sagemath/sagecell
Path: blob/master/kernel_provider.py
447 views
1
#! /usr/bin/env python
2
3
r"""
4
Kernel Provider starts compute kernels and sends connection info to Dealer.
5
"""
6
7
8
import argparse
9
import errno
10
from multiprocessing import Process
11
import os
12
import resource
13
import signal
14
import sys
15
import time
16
import uuid
17
18
from ipykernel.kernelapp import IPKernelApp
19
import zmq
20
21
import kernel_init
22
import log
23
logger = log.provider_logger.getChild(str(os.getpid()))
24
25
26
class KernelProcess(Process):
    """
    Kernel from the provider point of view.

    Configures a kernel process and does its best at cleaning up.
    """

    def __init__(self, id, rlimits, dir, waiter_port):
        """
        Store the kernel settings; the process itself is not started here.

        INPUT:

        - ``id`` - string identifying the kernel, also used as the name of
          its working directory

        - ``rlimits`` - dictionary mapping names of ``resource`` module
          attributes to the value used for both soft and hard limits

        - ``dir`` - parent directory in which the working directory of the
          kernel is created

        - ``waiter_port`` - local TCP port to which the connection
          information is pushed once the kernel is ready
        """
        super(KernelProcess, self).__init__()
        self.waiter_port = waiter_port
        self.dir = dir
        self.rlimits = rlimits
        self.id = id

    def run(self):
        """
        Set up an IPython kernel in this process and run it.

        The process switches to its own logger, becomes a process group
        leader, changes to its own working directory, initializes the
        kernel, applies resource limits, reports connection information to
        ``localhost`` on ``self.waiter_port`` and then runs the kernel.
        """
        global logger
        logger = log.kernel_logger.getChild(str(os.getpid()))
        logger.debug("forked kernel is running")
        log.std_redirect(logger)
        # Become a group leader for cleaner exit.
        os.setpgrp()
        workdir = os.path.join(self.dir, self.id)
        try:
            os.mkdir(workdir)
        except FileExistsError:
            # An already existing directory is fine, any other error is not.
            pass
        os.chdir(workdir)
        #config = traitlets.config.loader.Config({"ip": self.ip})
        #config.HistoryManager.enabled = False
        app = IPKernelApp.instance(log=logger)
        from namespace import InstrumentedNamespace
        app.user_ns = InstrumentedNamespace()
        app.initialize([])  # Redirects stdout/stderr
        #log.std_redirect(logger)   # Uncomment for debugging
        # This function should be called via atexit, but it isn't, perhaps due
        # to forking. Stale connection files do cause problems.
        app.cleanup_connection_file()
        kernel_init.initialize(app.kernel)
        # The same value is used for the soft and the hard limit.
        for name, limit in self.rlimits.items():
            resource.setrlimit(getattr(resource, name), (limit, limit))
        logger.debug("kernel ready")
        connection = {
            "key": app.session.key.decode(),
            "ip": app.ip,
            "hb": app.hb_port,
            "iopub": app.iopub_port,
            "shell": app.shell_port,
        }
        # Report to the provider how this kernel can be reached.
        push = zmq.Context.instance().socket(zmq.PUSH)
        push.connect("tcp://localhost:{}".format(self.waiter_port))
        push.send_json({
            "id": self.id,
            "connection": connection,
            "rlimits": self.rlimits,
        })

        def on_sigterm(signum, frame):
            logger.info("received %s, shutting down", signum)
            # TODO: this may not be the best way to do it.
            app.kernel.do_shutdown(False)

        signal.signal(signal.SIGTERM, on_sigterm)
        app.start()
        logger.debug("Kernel.run finished")
91
92
93
class KernelProvider(object):
    r"""
    Kernel Provider handles compute kernels on the worker side.

    It talks to the dealer over a ZMQ ``DEALER`` socket, forks
    :class:`KernelProcess` instances on request (and a few in advance), and
    collects their connection information on a ZMQ ``PULL`` socket.
    """

    def __init__(self, dealer_address, dir):
        r"""
        Connect to the dealer, obtain settings, and prepare for forking.

        INPUT:

        - ``dealer_address`` - ZMQ endpoint of the dealer to connect to

        - ``dir`` - parent directory for working directories of kernels, it
          is created if it does not exist yet

        The process exits with status 1 if the dealer does not answer the
        request for settings within 5 seconds.
        """
        self.is_active = False
        self.dir = dir
        try:
            os.mkdir(dir)
            logger.warning("created parent directory for kernels, "
                "consider doing it yourself with appropriate attributes")
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        context = zmq.Context()
        context.IPV6 = 1
        self.dealer = context.socket(zmq.DEALER)
        # Use the argument rather than a module-level global, so that the
        # class also works when this module is not run as a script.
        logger.debug("connecting to %s", dealer_address)
        self.dealer.connect(dealer_address)
        self.dealer.send_json("get settings")
        if not self.dealer.poll(5000):
            logger.debug("dealer does not answer, terminating")
            # ``sys.exit`` is always available, unlike the ``exit`` helper
            # that is added by the ``site`` module.
            sys.exit(1)
        reply = self.dealer.recv_json()
        logger.debug("received %s", reply)
        assert reply[0] == "settings"
        self.preforked_rlimits = reply[1].pop("preforked_rlimits")
        self.max_kernels = reply[1].pop("max_kernels")
        self.max_preforked = reply[1].pop("max_preforked")
        self.waiter = context.socket(zmq.PULL)
        self.waiter_port = self.waiter.bind_to_random_port("tcp://*")
        self.kernels = dict()   # id: KernelProcess
        # ID of the kernel being forked for a pending "get" request.
        self.forking = None
        # ID of the kernel being forked in advance with standard rlimits.
        self.preforking = None
        # Connection messages of kernels that are ready but not yet given out.
        self.preforked = []
        # Whether "ready" was sent and the matching "get" is still expected.
        self.ready_sent = False
        # Kernels that were asked to stop and wait for their final kill.
        self.to_kill = []
        setup_sage()

    def fork(self, rlimits):
        r"""
        Start a new kernel by forking.

        INPUT:

        - ``rlimits`` - dictionary with keys ``resource.RLIMIT_*``

        OUTPUT:

        - ID of the forked kernel
        """
        logger.debug("fork with rlimits %s", rlimits)
        id = str(uuid.uuid4())
        kernel = KernelProcess(id, rlimits, self.dir, self.waiter_port)
        kernel.start()
        self.kernels[id] = kernel
        return id

    def kill_check(self):
        """
        Kill old kernels.

        Kernels from ``self.to_kill`` which are still alive and have not
        reached their deadline yet are kept for the next check. For all
        other ones ``SIGKILL`` is sent to their process group.
        """
        to_kill = []
        for kernel in self.to_kill:
            if kernel.is_alive():
                if time.time() < kernel.deadline:
                    to_kill.append(kernel)
                    continue
                else:
                    logger.warning(
                        "kernel process %d did not stop by deadline",
                        kernel.pid)
            try:
                # Kernel PGID is the same as PID
                os.killpg(kernel.pid, signal.SIGKILL)
            except OSError as e:
                # A process group that is already gone is not an error.
                if e.errno != errno.ESRCH:
                    raise
            logger.debug("killed kernel process group %d", kernel.pid)
        self.to_kill = to_kill

    def send_kernel(self, msg):
        """
        Send the connection message ``msg`` of a kernel to the dealer.
        """
        self.dealer.send_json(["kernel", msg])

    def start(self):
        """
        Serve the dealer until :meth:`stop` is called, then stop all kernels.

        Messages handled from the dealer are ``"disconnect"``,
        ``["get", rlimits]`` and ``["stop", id]``. Messages from kernels
        carry their connection information and are matched by ``"id"``.
        """
        self.is_active = True
        poller = zmq.Poller()
        poller.register(self.dealer, zmq.POLLIN)
        poller.register(self.waiter, zmq.POLLIN)
        while self.is_active:
            # For pretty red lines in the log
            #logger.error("%s %s %s",
            #    self.forking, self.preforking, self.to_kill)

            # Tell the dealer if we are ready.
            if (not self.ready_sent
                and self.forking is None
                and (self.preforked or len(self.kernels) < self.max_kernels)):
                self.dealer.send_json("ready")
                self.ready_sent = True
            # Kill old kernel process groups.
            self.kill_check()
            # Process requests from the dealer ...
            events = dict(poller.poll(100))
            if self.dealer in events:
                msg = self.dealer.recv_json()
                logger.debug("received %s", msg)
                if msg == "disconnect":
                    self.stop()
                if msg[0] == "get":
                    # We expect a single "get" for every "ready" sent.
                    self.ready_sent = False
                    if msg[1] == self.preforked_rlimits and self.preforked:
                        self.send_kernel(self.preforked.pop(0))
                        logger.debug("%d preforked kernels left",
                                     len(self.preforked))
                    elif msg[1] == self.preforked_rlimits and self.preforking:
                        self.forking = self.preforking
                        self.preforking = None
                    else:
                        if len(self.kernels) == self.max_kernels:
                            # Free a slot for the special kernel. The last
                            # slot may have been taken by a kernel that is
                            # still preforking after "ready" was sent, in
                            # which case there is nothing in the list of
                            # preforked kernels to pop.
                            if self.preforked:
                                logger.warning("killing a preforked kernel to "
                                               "provide a special one")
                                self.stop_kernel(self.preforked.pop(0)["id"])
                            elif self.preforking:
                                logger.warning("killing a preforking kernel "
                                               "to provide a special one")
                                self.stop_kernel(self.preforking)
                                self.preforking = None
                        self.forking = self.fork(msg[1])
                if msg[0] == "stop":
                    self.stop_kernel(msg[1])
            # ... and connection info from kernels.
            if self.waiter in events:
                msg = self.waiter.recv_json()
                if self.forking == msg["id"]:
                    self.send_kernel(msg)
                    self.forking = None
                if self.preforking == msg["id"]:
                    self.preforked.append(msg)
                    self.preforking = None
            # Prefork more standard kernels.
            if (not (self.forking or self.preforking)
                and len(self.preforked) < self.max_preforked
                and len(self.kernels) < self.max_kernels):
                self.preforking = self.fork(self.preforked_rlimits)
        # Shutting down: stop everything and wait until all is killed.
        for id in list(self.kernels):
            self.stop_kernel(id)
        while self.to_kill:
            self.kill_check()
            time.sleep(0.1)

    def stop(self):
        """
        Ask the main loop of :meth:`start` to finish.
        """
        self.is_active = False

    def stop_kernel(self, id):
        """
        Stop the kernel with the given ``id`` and schedule its final kill.

        A kernel that is still alive gets ``SIGTERM``. In any case the kernel
        is queued for :meth:`kill_check` with a deadline of one second.

        INPUT:

        - ``id`` - ID of a kernel in ``self.kernels``, ``KeyError`` is raised
          for unknown ones
        """
        kernel = self.kernels.pop(id)
        if kernel.is_alive():
            logger.debug("killing kernel process %d", kernel.pid)
            os.kill(kernel.pid, signal.SIGTERM)
        kernel.deadline = time.time() + 1
        self.to_kill.append(kernel)
251
252
253
def setup_sage():
    r"""
    Import Sage and replace ``show`` of ``pylab`` and ``matplotlib.pyplot``.

    The replacements save the figure to a PNG file with a random name in
    the current directory and announce it via ``sys._sage_``. A first plot
    is made right away so that worker processes do not have this overhead.
    """
    # Non-existing startup file that users cannot create.
    os.environ["SAGE_STARTUP_FILE"] = "/init.sage"
    import sage
    import sage.all
    from functools import partial

    # override matplotlib and pylab show functions
    # TODO: use something like IPython's inline backend
    def show_as_png(savefig):
        png_name = "%s.png" % uuid.uuid4()
        savefig(png_name)
        announcement = {"text/image-filename": png_name}
        sys._sage_.sent_files[png_name] = os.path.getmtime(png_name)
        sys._sage_.display_message(announcement)

    import pylab
    import matplotlib.pyplot
    for plotting in (pylab, matplotlib.pyplot):
        plotting.show = partial(show_as_png, savefig=plotting.savefig)

    # The first plot takes about 2 seconds to generate (presumably
    # because lots of things, like matplotlib, are imported).  We plot
    # something here so that worker processes don't have this overhead
    try:
        sage.all.plot(1, (0, 1))
    except Exception:
        logger.exception("plotting exception")
281
282
283
# Script entry point: parse arguments, connect to the dealer, and serve
# kernels until SIGTERM or a "disconnect" from the dealer.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Launch a kernel provider for SageMathCell")
    parser.add_argument("--address",
        help="address of the kernel dealer (defaults to $SSH_CLIENT)")
    parser.add_argument("port", type=int,
        help="port of the kernel dealer")
    parser.add_argument("dir",
        help="directory name for user files saved by kernels")
    args = parser.parse_args()

    log.std_redirect(logger)
    address = args.address
    if not address:
        # Fall back to the first field of $SSH_CLIENT, reporting a clear
        # error instead of a bare KeyError/IndexError if it is unusable.
        ssh_client = os.environ.get("SSH_CLIENT", "").split()
        if not ssh_client:
            parser.error(
                "--address is not given and $SSH_CLIENT is not set")
        address = ssh_client[0]
    # IPv6 addresses have to be bracketed before appending the port.
    if ":" in address:
        address = "[{}]".format(address)
    address = "tcp://{}:{}".format(address, args.port)
    provider = KernelProvider(address, args.dir)

    def signal_handler(signum, frame):
        logger.info("received %s, shutting down", signum)
        provider.stop()

    signal.signal(signal.SIGTERM, signal_handler)
    provider.start()
307
308