Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
sagemath
GitHub Repository: sagemath/sagecell
Path: blob/master/kernel_dealer.py
447 views
1
import asyncio
2
import time
3
4
import jupyter_client.session
5
import tornado.ioloop
6
import zmq
7
8
from log import logger
9
import misc
10
11
12
# Module-wide configuration object; KernelConnection.start_hb reads the
# heartbeat timing values "first_beat" and "beat_interval" (seconds) from it.
config = misc.Config()
13
14
15
class KernelConnection(object):
    """
    Kernel from the dealer point of view.

    Handles connections over ZMQ sockets to compute kernels: a ``shell``
    DEALER socket, an ``iopub`` SUB socket subscribed to every message and an
    ``hb`` REQ socket used for the heartbeat. The heartbeat stops the kernel
    when a ping goes unanswered, when ``hard_deadline`` has passed, or when
    the kernel is idle past ``deadline``.
    """

    def __init__(self, dealer, id, connection, lifespan, timeout):
        """
        Connect to a kernel and start its heartbeat.

        dealer -- the owning dealer; ``dealer.stop_kernel(id)`` is called
            when this connection stops
        id -- kernel identifier
        connection -- dict with the session ``"key"`` (str), the kernel
            ``"ip"`` and the ``"shell"``, ``"iopub"`` and ``"hb"`` ports
        lifespan -- seconds from now after which the kernel is stopped
            unconditionally
        timeout -- seconds from now after which an idle kernel is stopped;
            ``self.deadline`` is only set when this is positive
        """
        self._on_stop = None
        self._dealer = dealer
        self.id = id
        # Not modified in this class; presumably maintained by callers.
        self.executing = 0
        # The heartbeat only applies the idle timeout when this is "idle";
        # nothing in this class changes it, so presumably callers update it.
        self.status = "starting"
        now = time.time()
        self.hard_deadline = now + lifespan
        self.timeout = timeout
        if timeout > 0:
            self.deadline = now + self.timeout
        self.session = jupyter_client.session.Session(
            key=connection["key"].encode())
        self.channels = {}
        context = zmq.Context.instance()
        address = connection["ip"]
        if ":" in address:
            # IPv6 literals must be bracketed inside a tcp:// endpoint.
            address = "[{}]".format(address)
        for channel, socket_type in (
                ("shell", zmq.DEALER), ("iopub", zmq.SUB), ("hb", zmq.REQ)):
            socket = context.socket(socket_type)
            socket.connect("tcp://{}:{}".format(address, connection[channel]))
            stream = zmq.eventloop.zmqstream.ZMQStream(socket)
            stream.channel = channel
            self.channels[channel] = stream
        # An empty topic prefix subscribes to all iopub messages.
        self.channels["iopub"].socket.subscribe(b"")
        self.start_hb()
        logger.debug("KernelConnection initialized")

    def on_stop(self, callback):
        """
        Register a no-argument ``callback`` to run when ``stop`` is called.

        Replaces any previously registered callback.
        """
        self._on_stop = callback

    def start_hb(self):
        """
        Start the heartbeat and mark the connection alive.

        The first ping is scheduled ``config.get("first_beat")`` seconds from
        now; after that one runs every ``config.get("beat_interval")``
        seconds. Each ping stops the kernel instead of pinging if the previous
        ping got no reply, if ``hard_deadline`` has passed, or if the idle
        timeout is enabled, ``deadline`` has passed and the status is "idle".
        """
        logger.debug("start_hb for %s", self.id)
        hb = self.channels["hb"]
        ioloop = tornado.ioloop.IOLoop.current()

        def pong(message):
            self._expecting_pong = False

        hb.on_recv(pong)
        self._expecting_pong = False

        def ping():
            # The deadlines are computed from time.time() in __init__ (and
            # presumably by callers), so compare against that same clock.
            now = time.time()
            if self._expecting_pong:
                logger.warning("kernel %s died unexpectedly", self.id)
                self.stop()
            elif now > self.hard_deadline:
                logger.info("hard deadline reached for %s", self.id)
                self.stop()
            elif (self.timeout > 0
                  and now > self.deadline
                  and self.status == "idle"):
                logger.info("kernel %s timed out", self.id)
                self.stop()
            else:
                hb.send(b'ping')
                self._expecting_pong = True

        # PeriodicCallback takes its interval in milliseconds.
        self._hb_periodic_callback = tornado.ioloop.PeriodicCallback(
            ping, config.get("beat_interval") * 1000)

        def start_ping():
            logger.debug("start_ping for %s", self.id)
            # The kernel may have been stopped before the first beat.
            if self.alive:
                self._hb_periodic_callback.start()

        self._start_ping_handle = ioloop.call_later(
            config.get("first_beat"), start_ping)
        self.alive = True

    def stop(self):
        """
        Stop the kernel: halt the heartbeat, run the ``on_stop`` callback,
        close all channels and tell the dealer to stop the kernel.

        Calling this on an already-stopped connection only logs a warning.
        """
        logger.debug("stopping kernel %s", self.id)
        if not self.alive:
            logger.warning("not alive already")
            return
        self.stop_hb()
        try:
            if self._on_stop:
                self._on_stop()
        finally:
            # Release the sockets and notify the dealer even if the
            # registered callback raises; otherwise both would leak.
            for stream in self.channels.values():
                stream.close()
            self._dealer.stop_kernel(self.id)

    def stop_hb(self):
        """
        Stop the heartbeat and mark the connection as no longer alive.

        Cancels the periodic ping, the pending first-beat timeout and the
        pong handler on the hb channel.
        """
        logger.debug("stop_hb for %s", self.id)
        self.alive = False
        self._hb_periodic_callback.stop()
        tornado.ioloop.IOLoop.current().remove_timeout(self._start_ping_handle)
        self.channels["hb"].on_recv(None)
114
115
116
class KernelDealer(object):
    r"""
    Kernel Dealer handles compute kernels on the server side.

    Providers connect to a ROUTER socket bound to a random TCP port
    (``self.port``) and exchange JSON messages with the dealer:

    * provider ``"get settings"`` -> dealer ``["settings", provider_settings]``
    * provider ``"ready"`` -> the provider is queued as available
    * dealer ``["get", rlimits]`` -> provider ``["kernel", info]``, where
      ``info`` carries ``"id"``, ``"rlimits"`` and the remaining
      ``KernelConnection`` arguments
    * dealer ``["stop", id]`` asks a provider to stop one of its kernels
    * dealer ``"disconnect"`` is sent to every provider by ``stop``
    """

    def __init__(self, provider_settings):
        """
        Bind the provider-facing ROUTER socket and start receiving.

        provider_settings -- sent verbatim to providers that ask for settings
        """
        self.provider_settings = provider_settings
        self._available_providers = []
        self._connected_providers = {}  # provider address: last message time
        self._expected_kernels = []  # (rlimits, future) of pending get_kernel
        self._get_queue = []  # rlimits of get requests not yet sent
        self._kernel_origins = {}  # id: provider address
        self._kernels = {}  # id: KernelConnection
        context = zmq.Context.instance()
        # Enable IPv6 for sockets created from this context.
        context.IPV6 = 1
        socket = context.socket(zmq.ROUTER)
        self.port = socket.bind_to_random_port("tcp://*")
        # Can configure perhaps interface/IP/port
        self._stream = zmq.eventloop.zmqstream.ZMQStream(socket)
        self._stream.on_recv(self._recv)
        logger.debug("KernelDealer initialized")

    def _try_to_get(self):
        r"""
        Send a get request if possible AND needed.

        Pairs queued requests with idle providers in FIFO order and logs
        whichever side is left over.
        """
        while self._available_providers and self._get_queue:
            self._stream.send(self._available_providers.pop(0), zmq.SNDMORE)
            self._stream.send_json(["get", self._get_queue.pop(0)])
            logger.debug("sent get request to a provider")
        if self._available_providers:
            logger.debug("%s available providers are idling",
                         len(self._available_providers))
        if self._get_queue:
            logger.debug("%s get requests are waiting for providers",
                         len(self._get_queue))

    def _recv(self, msg):
        """
        Handle one multipart message ``[address, json_payload]`` from a
        provider.

        Every message refreshes the provider's entry in
        ``_connected_providers``.
        """
        logger.debug("received %s", msg)
        if len(msg) != 2:
            # Explicit check instead of ``assert``, which -O would strip.
            logger.warning("ignoring malformed message: %s", msg)
            return
        addr = msg[0]
        self._connected_providers[addr] = time.time()
        msg = zmq.utils.jsonapi.loads(msg[1])
        if msg == "get settings":
            self._stream.send(addr, zmq.SNDMORE)
            self._stream.send_json(["settings", self.provider_settings])
        elif msg == "ready":
            self._available_providers.append(addr)
            self._try_to_get()
        elif msg[0] == "kernel":
            self._deliver_kernel(addr, msg[1])

    def _deliver_kernel(self, addr, kernel_info):
        """
        Hand a kernel announced by provider ``addr`` to the oldest waiting
        ``get_kernel`` call with equal ``rlimits``, or stop it if none waits.
        """
        # A future can only be done here if its get_kernel() call was
        # cancelled; set_result() on it would raise InvalidStateError.
        self._expected_kernels = [
            (rlimits, f) for rlimits, f in self._expected_kernels
            if not f.done()]
        for i, (rlimits, f) in enumerate(self._expected_kernels):
            if rlimits == kernel_info["rlimits"]:
                self._kernel_origins[kernel_info["id"]] = addr
                self._expected_kernels.pop(i)
                f.set_result(kernel_info)
                return
        # Nobody can ever claim this kernel; ask its provider to stop it
        # rather than leaking it.
        logger.warning("no request is waiting for kernel %s, stopping it",
                       kernel_info["id"])
        self._stream.send(addr, zmq.SNDMORE)
        self._stream.send_json(["stop", kernel_info["id"]])

    async def get_kernel(self,
            rlimits=None, lifespan=float("inf"), timeout=float("inf")):
        """
        Request a kernel from a provider and return a connection to it.

        rlimits -- dict sent to the provider in the "get" request; the
            announced kernel must report an equal ``rlimits`` to be matched
            to this call. Defaults to a fresh empty dict.
        lifespan -- passed to ``KernelConnection`` (hard time limit, seconds)
        timeout -- passed to ``KernelConnection`` (idle limit, seconds)

        Returns the new ``KernelConnection``, which is tracked until it stops.
        """
        if rlimits is None:
            # Fresh dict per call instead of a shared mutable default.
            rlimits = {}
        f = asyncio.get_running_loop().create_future()
        self._expected_kernels.append((rlimits, f))
        self._get_queue.append(rlimits)
        self._try_to_get()
        d = await f
        d.pop("rlimits")
        d["lifespan"] = lifespan
        d["timeout"] = timeout
        kernel = KernelConnection(self, **d)
        self._kernels[kernel.id] = kernel
        logger.debug("tracking %d kernels", len(self._kernels))
        logger.info("dealing kernel %s", kernel.id)
        return kernel

    def kernel(self, id):
        """Return the tracked kernel ``id``; raises KeyError if unknown."""
        return self._kernels[id]

    def stop(self):
        r"""
        Stop all kernels and disconnect all providers.
        """
        self._stream.stop_on_recv()
        # Copy: each KernelConnection.stop() removes itself via stop_kernel.
        for k in list(self._kernels.values()):
            k.stop()
        for addr in self._connected_providers:
            logger.debug("stopping %r", addr)
            self._stream.send(addr, zmq.SNDMORE)
            self._stream.send_json("disconnect")
        self._stream.flush()

    def stop_kernel(self, id):
        """
        Ask the provider that started kernel ``id`` to stop it and stop
        tracking it. Raises KeyError if ``id`` is unknown.
        """
        addr = self._kernel_origins.pop(id)
        self._stream.send(addr, zmq.SNDMORE)
        self._stream.send_json(["stop", id])
        self._kernels.pop(id)
211
212