CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
sagemathinc

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: sagemathinc/cocalc
Path: blob/master/src/scripts/smc_firewall.py
Views: 687
1
#!/usr/bin/env python
2
3
###############################################################################
4
#
5
# CoCalc: Collaborative Calculation
6
#
7
# Copyright (C) 2016, Sagemath Inc.
8
#
9
# This program is free software: you can redistribute it and/or modify
10
# it under the terms of the GNU General Public License as published by
11
# the Free Software Foundation, either version 3 of the License, or
12
# (at your option) any later version.
13
#
14
# This program is distributed in the hope that it will be useful,
15
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
# GNU General Public License for more details.
18
#
19
# You should have received a copy of the GNU General Public License
20
# along with this program. If not, see <http://www.gnu.org/licenses/>.
21
#
22
###############################################################################
23
24
import json, os, signal, socket, sys, time
25
from subprocess import Popen, PIPE
26
27
28
def log(s, *args):
    """
    Write a single message line to stderr and flush it right away.

    When extra positional arguments are supplied, the message is built as
    ``s % args``. If that formatting raises, the error text is prepended to
    the unformatted message instead of propagating the exception. Without
    extra arguments, ``s`` is written unchanged (no %-formatting happens).
    """
    message = s
    if args:
        try:
            message = str(s % args)
        except Exception as err:
            message = str(err) + str(s)
    stream = sys.stderr
    stream.write(message + '\n')
    stream.flush()
36
37
38
def cmd(s,
        ignore_errors=False,
        verbose=2,
        timeout=None,
        stdout=True,
        stderr=True,
        system=False):
    """
    Run a command and return its output.

    Arguments:
    - s -- the command: a list of arguments (run directly, no shell) or a
      string (run through the shell).
    - ignore_errors -- if True, report failures through the return value
      instead of raising.
    - verbose -- 0: silent; >= 1: log the command and the elapsed time;
      >= 2: also log the first 500 characters of output.
    - timeout -- if set, whole number of seconds (minimum 1) after which the
      command is killed. Implemented with SIGALRM, so a timeout can only be
      used from the main thread.
    - stdout, stderr -- accepted for backward compatibility; unused.
    - system -- if True, run via os.system() (output is not captured) and
      return None.

    Returns the command's stdout followed by its stderr, decoded as UTF-8
    (undecodable bytes replaced) and stripped.

    On a nonzero exit status it raises RuntimeError carrying the output, or,
    with ignore_errors, returns the output with "ERROR" appended. On timeout
    it raises KeyboardInterrupt, or, with ignore_errors, returns the timeout
    message. If the command cannot be started (IOError/OSError), the error is
    re-raised, or, with ignore_errors, a message string is returned.
    """
    if isinstance(s, list):
        s = [str(x) for x in s]
        # Shell-like rendering of the argument list, used for logging and
        # as the command line when system=True.
        c = ' '.join([x if len(x.split()) <= 1 else "'%s'" % x for x in s])
    else:
        c = s
    if verbose >= 1:
        log(c)
    t = time.time()

    if system:
        if os.system(c):
            if verbose >= 1:
                log("(%s seconds)", time.time() - t)
            if ignore_errors:
                return
            raise RuntimeError('error executing %s' % c)
        return

    mesg = "ERROR"
    if timeout:
        mesg = "TIMEOUT: running '%s' took more than %s seconds, so killed" % (
            s, timeout)

    class _Timeout(Exception):
        """Raised from the SIGALRM handler to abort waiting on the child."""

    def handle(*a):
        raise _Timeout()

    proc = None
    old_handler = None
    if timeout:
        # Only touch SIGALRM when a timeout was requested: signal.alarm()
        # rejects None, and signal.signal() fails outside the main thread.
        old_handler = signal.signal(signal.SIGALRM, handle)
        signal.alarm(max(1, int(timeout)))
    try:
        proc = Popen(
            s,
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
            shell=not isinstance(s, list))
        # communicate() drains stdout and stderr concurrently and closes
        # stdin, so heavy output on either stream (or a child reading
        # stdin) cannot deadlock us.
        out, err = proc.communicate()
        x = (out + err).decode('utf-8', 'replace')
        if proc.returncode:
            if ignore_errors:
                return (x + "ERROR").strip()
            raise RuntimeError(x)
        if verbose >= 2:
            log("(%s seconds): %s", time.time() - t, x[:500])
        elif verbose >= 1:
            log("(%s seconds)", time.time() - t)
        return x.strip()
    except _Timeout:
        # Do not leave the timed-out child running.
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.communicate()
        if ignore_errors:
            return mesg
        raise KeyboardInterrupt(mesg)
    except (IOError, OSError):
        # e.g. the executable does not exist.
        if ignore_errors:
            return mesg
        raise
    finally:
        if timeout:
            signal.alarm(0)  # cancel any pending alarm
            signal.signal(
                signal.SIGALRM,
                old_handler if old_handler is not None else signal.SIG_DFL)
106
107
108
class Firewall(object):
    """
    Configure this machine's firewall by shelling out to ``iptables`` (and
    to ``tc`` for traffic shaping).

    Rules are passed around as lists of iptables arguments, e.g.
    ``['OUTPUT', '-o', 'lo', '-j', 'ACCEPT']``. The methods ``outgoing``,
    ``incoming`` and ``clear`` return ``{'status': 'success'}``; ``show``
    prints directly and returns None.
    """

    def iptables(self, args, **kwds):
        """Run ``iptables -v`` followed by ``args`` via ``cmd``."""
        return cmd(['iptables', '-v'] + args, **kwds)

    def insert_rule(self, rule, force=False):
        """
        Insert ``rule`` at the top of its chain unless it already exists.

        With ``force=True`` an existing copy is deleted first and the rule is
        re-inserted, which moves it to the top of the chain.
        """
        if not self.exists(rule):
            log("insert_rule: %s", rule)
            self.iptables(['-I'] + rule)
        elif force:
            self.delete_rule(rule, force=True)
            self.iptables(['-I'] + rule)

    def append_rule(self, rule, force=False):
        """
        Append ``rule`` to the end of its chain unless it already exists.

        With ``force=True`` an existing copy is deleted first and the rule is
        re-appended, which moves it to the bottom of the chain.
        """
        if not self.exists(rule):
            log("append_rule: %s", rule)
            self.iptables(['-A'] + rule)
        elif force:
            self.delete_rule(rule, force=True)
            self.iptables(['-A'] + rule)

    def delete_rule(self, rule, force=False):
        """
        Delete ``rule`` if it exists; errors from ``iptables -D`` are ignored.

        With ``force=True`` deletion is attempted even when ``exists`` says
        the rule is absent, since that check is not fully reliable for
        uid-owner rules.
        """
        if self.exists(rule):
            log("delete_rule: %s", rule)
            try:
                self.iptables(['-D'] + rule)
            except Exception as mesg:
                # checking for exists is not 100% for uid rules module
                log("delete_rule error -- %s", mesg)
        elif force:
            try:
                self.iptables(['-D'] + rule)
            except Exception:
                # Best effort: most likely the rule simply is not there.
                pass

    def exists(self, rule):
        """
        Return True if the given rule exists already (``iptables -C``
        succeeds), False if that check fails for any reason.
        """
        try:
            self.iptables(['-C'] + rule, verbose=0)
        except Exception:
            return False
        return True

    def clear(self):
        """
        Flush the filter-table rules and the mangle-table rules (the latter
        mark packets for tc shaping).

        Chain policies are not changed, so this opens everything up only if
        the default policies are ACCEPT.
        """
        self.iptables(['-F'])  # clear the normal rules
        self.iptables(
            ['-t', 'mangle',
             '-F'])  # clear the mangle rules used to shape traffic (using tc)
        return {'status': 'success'}

    def show(self, names=False):
        """
        Show all firewall rules. (NON-JSON interface!)

        names -- if True, let iptables resolve addresses to hostnames
            (omits ``-n``; may do slow DNS lookups).
        """
        if names:
            os.system("iptables -v -L")
        else:
            os.system("iptables -v -n -L")

    def outgoing(self,
                 whitelist_hosts='',
                 whitelist_hosts_file='',
                 whitelist_users='',
                 blacklist_users='',
                 bandwidth_Kbps=1000):
        """
        Block all outgoing traffic, except what is given
        in a specific whitelist and DNS. Also throttle
        bandwidth of outgoing SMC *user* traffic.

        whitelist_hosts -- comma-separated destinations always allowed.
        whitelist_hosts_file -- file with one destination per line; text
            after ``#`` and blank lines are ignored.
        whitelist_users / blacklist_users -- comma-separated users whose
            uid-owner ACCEPT rules are added / removed.
        bandwidth_Kbps -- rate passed to ``configure_tc``; falsy skips it.
        """
        if whitelist_users or blacklist_users:
            self.outgoing_user(whitelist_users, blacklist_users)

        if whitelist_hosts_file:
            v = []
            with open(whitelist_hosts_file) as fh:
                for line in fh:
                    host = line.split('#', 1)[0].strip()
                    if host:
                        v.append(host)
            self.outgoing_whitelist_hosts(','.join(v))
        self.outgoing_whitelist_hosts(whitelist_hosts)

        # Block absolutely all outgoing traffic *from* lo to not loopback on same
        # machine: this is to make it so a project
        # can serve a network service listening on eth0 safely without having to worry
        # about security at all, and still have it be secure, even from users on
        # the same machine. We insert and remove this every time we mess with the firewall
        # rules to ensure that it is at the very top.
        self.insert_rule(
            ['OUTPUT', '-o', 'lo', '-d',
             socket.gethostname(), '-j', 'REJECT'],
            force=True)

        if bandwidth_Kbps:
            self.configure_tc(bandwidth_Kbps)

        return {'status': 'success'}

    def configure_tc(self, bandwidth_Kbps):
        """
        Replace the root qdisc on eth0 with an HTB setup: packets carrying
        firewall mark 1 go to class 1:10, limited (rate and ceil) to
        ``bandwidth_Kbps`` Kbit/s; unclassified traffic uses the htb default
        class 99, which this method does not create.

        Failures are logged, not raised (best effort).
        """
        try:
            cmd("tc qdisc del dev eth0 root".split())
        except Exception:
            pass  # will fail if not already configured
        try:
            cmd("tc qdisc add dev eth0 root handle 1:0 htb default 99".split())
            cmd((
                "tc class add dev eth0 parent 1:0 classid 1:10 htb rate %sKbit ceil %sKbit prio 2"
                % (bandwidth_Kbps, bandwidth_Kbps)).split())
            cmd("tc qdisc add dev eth0 parent 1:10 handle 10: sfq perturb 10".
                split())
            cmd("tc filter add dev eth0 parent 1:0 protocol ip prio 1 handle 1 fw classid 1:10".
                split())
        except Exception as mesg:
            # More serious than the delete above, but kept best-effort so the
            # rest of the firewall still gets configured; at least report it.
            log("configure_tc error -- %s", mesg)

    def outgoing_whitelist_hosts(self, whitelist):
        """
        Allow new outgoing connections only to ``whitelist`` (comma-separated)
        plus this machine's DNS servers from /etc/resolv.conf, and over
        loopback; append a rule rejecting every other NEW outgoing connection.
        """
        whitelist = [x.strip() for x in whitelist.split(',')]
        # determine the ip addresses of our locally configured DNS servers
        with open("/etc/resolv.conf") as fh:
            for line in fh:
                v = line.split()
                # Skip blank lines and malformed "nameserver" lines.
                if len(v) >= 2 and v[0] == 'nameserver':
                    log("adding nameserver %s to whitelist", v[1])
                    whitelist.append(v[1])
        whitelist = ','.join([x for x in whitelist if x])
        log("whitelist: %s", whitelist)

        # Insert whitelist rule at the beginning of OUTPUT chain.
        # Anything that matches this will immediately be accepted to go out.
        if whitelist:
            self.insert_rule(['OUTPUT', '-d', whitelist, '-j', 'ACCEPT'])

        # Loopback traffic: allow all OUTGOING (so the rule below doesn't cause trouble);
        # needed, e.g., by Jupyter notebook and probably other services.
        self.insert_rule(['OUTPUT', '-o', 'lo', '-j', 'ACCEPT'])

        # Block all new outgoing connections that we didn't allow above.
        self.append_rule(
            ['OUTPUT', '-m', 'state', '--state', 'NEW', '-j', 'REJECT'])

    def outgoing_user(self, add='', remove=''):
        """
        Add (``add``) or remove (``remove``) per-user uid-owner ACCEPT rules
        on the OUTPUT chain; both arguments are comma-separated user lists.
        Failures while adding a user are logged as warnings, not raised.
        """

        def rules(user):
            # Return the list of rules to install for this user.
            v = [[
                'OUTPUT', '-m', 'owner', '--uid-owner', user, '-j', 'ACCEPT'
            ]]
            # NOTE: this throttling branch is disabled by the leading `False and`.
            if False and user != 'salvus' and user != 'root':
                # Make it so this user has their bandwidth throttled so DOS attacks are more difficult, and also spending
                # thousands in bandwidth is harder.
                # -t mangle mangles packets by adding a mark, which is needed by tc.
                # -p all -- match all protocols, including both tcp and udp
                # ! -d 10.240.0.0/8 ensures this rule does NOT apply to any destination inside GCE.;
                # CRITICAL -- I thought 10.240.0.0/16 was right because that's what it says in the google firewall rules; but with k8s
                # it's definitely wrong and this mistake frickin' kills everything!!!
                # -m owner --uid-owner [user] makes the rule apply only to this user
                # -j MARK --set-mark 0x1 marks packet so the throttling tc filter we created elsewhere gets applied
                v.append([
                    'OUTPUT', '-t', 'mangle', '-p', 'all', '!', '-d',
                    '10.240.0.0/8', '-m', 'owner', '--uid-owner', user, '-j',
                    'MARK', '--set-mark', '0x1'
                ])
            return v

        for user in remove.split(','):
            if user:
                for x in rules(user):
                    self.delete_rule(x, force=True)

        for user in add.split(','):
            if user:
                try:
                    for x in rules(user):
                        self.insert_rule(x, force=True)
                except Exception as mesg:
                    log("\nWARNING whitelisting user: %s\n",
                        str(mesg).splitlines()[:-1])

    def incoming(self, whitelist_hosts='', whitelist_ports=''):
        """
        Deny all other incoming traffic, except from the
        explicitly given whitelist of machines.

        whitelist_hosts -- comma-separated sources whose traffic is accepted;
            if blank, the ``smc``, ``storage`` and ``admin`` server lists are
            fetched from the GCE metadata server.
        whitelist_ports -- comma-separated TCP destination ports accepted from
            anywhere; blank entries are ignored.

        Raises RuntimeError (before touching iptables) if no hosts to
        whitelist could be determined.
        """
        # Determine the host whitelist first, so we fail before changing any
        # rules if there is nothing to whitelist.
        if not whitelist_hosts.strip():
            v = []
            for t in ['smc', 'storage', 'admin']:
                s = cmd(
                    "curl -s http://metadata.google.internal/computeMetadata/v1/project/attributes/%s-servers -H 'Metadata-Flavor: Google'"
                    % t)
                # Treat any whitespace in the metadata value as a separator;
                # empty values contribute nothing.
                v.extend(s.split())
            whitelist_hosts = ','.join(v)
        if not whitelist_hosts.strip():
            raise RuntimeError("incoming: no hosts to whitelist")

        # Allow some incoming packets from the whitelist of ports.
        for p in whitelist_ports.split(','):
            p = p.strip()
            if p:
                self.insert_rule(
                    ['INPUT', '-p', 'tcp', '--dport', p, '-j', 'ACCEPT'])

        # Allow incoming connections/packets from anything in the whitelist
        self.insert_rule(['INPUT', '-s', whitelist_hosts, '-j', 'ACCEPT'])

        # Loopback traffic: allow all INCOMING (so the rule below doesn't cause trouble);
        # needed, e.g., by Jupyter notebook and probably other services.
        self.append_rule(['INPUT', '-i', 'lo', '-j', 'ACCEPT'])

        # Block *new* packets arriving via a new connection from anywhere else. We
        # don't want to block all packets -- e.g., if something on this machine
        # connects to DNS, it should be allowed to receive the answer back.
        self.append_rule(
            ['INPUT', '-m', 'state', '--state', 'NEW', '-j', 'DROP'])

        return {'status': 'success'}
327
328
329
if __name__ == "__main__":

    import argparse
    import traceback

    hostname = socket.gethostname()
    log("hostname=%s", hostname)
    if not hostname.startswith('compute') and not hostname.startswith('web'):
        log("skipping firewall since this is not a production SMC machine")
        sys.exit(0)

    parser = argparse.ArgumentParser(
        description="CoCalc firewall control script")
    subparsers = parser.add_subparsers(help='sub-command help')

    def f(subparser):
        """
        Attach a handler to ``subparser``. The handler calls the ``Firewall``
        method named by the last word of the sub-command's prog string,
        passing every parsed option as a keyword argument, and prints the
        result as JSON on stdout.

        If that call raises, the traceback goes to stderr, a JSON
        ``{"error": ...}`` object is printed, and the process exits with
        status 1.
        """
        function = subparser.prog.split()[-1]

        def g(args):
            kwds = {k: v for k, v in vars(args).items() if k != 'func'}
            errors = False
            try:
                result = getattr(Firewall(), function)(**kwds)
            except Exception as mesg:
                # Keep the traceback for debugging, but still honor the JSON
                # output contract on stdout.
                traceback.print_exc()
                errors = True
                result = {'error': str(mesg)}
            print(json.dumps(result))
            if errors:
                sys.exit(1)

        subparser.set_defaults(func=g)

    parser_outgoing = subparsers.add_parser(
        'outgoing',
        help=
        'create firewall to block all outgoing traffic, except explicit whitelist'
    )
    parser_outgoing.add_argument(
        '--whitelist_hosts',
        help="comma separated list of sites to whitelist (not run if empty)",
        default='')
    parser_outgoing.add_argument(
        '--whitelist_hosts_file',
        help=
        "filename of file with one line for each host (comments and blank lines are ignored)",
        default='')
    parser_outgoing.add_argument(
        '--whitelist_users',
        help="comma separated list of users to whitelist",
        default='')
    parser_outgoing.add_argument(
        '--blacklist_users',
        help="comma separated list of users to remove from whitelist",
        default='')
    # type=int so that "--bandwidth_Kbps 0" is falsy and disables shaping.
    parser_outgoing.add_argument(
        '--bandwidth_Kbps',
        help="throttle user bandwidth to this many Kbit/s (0 disables)",
        default=1000,
        type=int)
    f(parser_outgoing)

    parser_incoming = subparsers.add_parser(
        'incoming',
        help=
        'create firewall to block all incoming traffic except ssh, nfs, http[s], except explicit whitelist'
    )
    parser_incoming.add_argument(
        '--whitelist_hosts',
        help=
        "comma separated list of sites to whitelist (default: use metadata server to get smc vms)",
        default='')
    parser_incoming.add_argument(
        '--whitelist_ports',
        help="comma separated list of ports to whitelist",
        default='22,80,111,443')
    f(parser_incoming)

    f(subparsers.add_parser('clear', help='clear all rules'))

    parser_show = subparsers.add_parser('show', help='show all rules')
    parser_show.add_argument(
        '--names',
        help="show hostnames (potentially expensive DNS lookup)",
        default=False,
        action="store_const",
        const=True)
    f(parser_show)

    args = parser.parse_args()
    if not hasattr(args, 'func'):
        # Python 3's argparse does not require a sub-command by default.
        parser.print_help()
        sys.exit(1)
    args.func(args)
418
419