import json, os, signal, socket, sys, time
from subprocess import Popen, PIPE
def log(s, *args):
    """
    Write one line to stderr and flush it immediately.

    When positional ``args`` are supplied, ``s`` is treated as a %-format
    string.  If that formatting fails, nothing is raised: the exception text
    followed by the raw ``s`` is written instead.
    """
    if not args:
        text = s
    else:
        try:
            text = str(s % args)
        except Exception as problem:
            text = str(problem) + str(s)
    out = sys.stderr
    out.write(text + '\n')
    out.flush()
def cmd(s,
        ignore_errors=False,
        verbose=2,
        timeout=None,
        stdout=True,
        stderr=True,
        system=False):
    """
    Run a command and return its combined, stripped output.

    :param s: the command.  A list is run directly (no shell) after every
        element is converted with ``str``; a string is run through the shell.
    :param ignore_errors: if true, failures do not raise.  A nonzero exit
        returns the output with "ERROR" appended; a start failure or timeout
        returns the error/timeout message.
    :param verbose: 0 = silent; >= 1 logs the command and elapsed time to
        stderr; >= 2 also logs the first 500 characters of the output.
    :param timeout: optional number of seconds (rounded up) after which the
        child is killed.  Implemented with SIGALRM, so it only works in the
        main thread.  On timeout a KeyboardInterrupt carrying a TIMEOUT message
        is raised, or that message is returned when ``ignore_errors`` is true.
    :param stdout: unused; kept for backward compatibility.
    :param stderr: unused; kept for backward compatibility.
    :param system: run the command with ``os.system`` (output is not
        captured) and return None.
    :returns: stdout followed by stderr, as text (decoded as UTF-8 with
        replacement on Python 3), stripped.
    :raises RuntimeError: on a nonzero exit status (unless ignore_errors).
    :raises OSError: if the program cannot be started (unless ignore_errors).
    """
    if isinstance(s, list):
        s = [str(x) for x in s]
        # Printable form for logging/os.system: quote args containing spaces.
        c = ' '.join([x if len(x.split()) <= 1 else "'%s'" % x for x in s])
    else:
        c = s
    if verbose >= 1:
        log(c)
    t = time.time()

    if system:
        if os.system(c):
            if verbose >= 1:
                log("(%s seconds)", time.time() - t)
            if ignore_errors:
                return
            raise RuntimeError('error executing %s' % c)
        return

    mesg = "ERROR"
    proc = None
    timed_out = []  # appended to by the alarm handler
    previous_handler = None
    if timeout:
        import math
        mesg = "TIMEOUT: running '%s' took more than %s seconds, so killed" % (
            s, timeout)

        def handle(*a):
            timed_out.append(True)
            # Really kill the child; otherwise it keeps running (and keeps
            # our pipes open) after we give up on it.
            if proc is not None:
                try:
                    proc.kill()
                except OSError:
                    pass  # it already exited
            if not ignore_errors:
                raise KeyboardInterrupt(mesg)
            # With ignore_errors, communicate() now hits EOF and returns.

        previous_handler = signal.signal(signal.SIGALRM, handle)
        # signal.alarm() only accepts whole seconds.
        signal.alarm(int(math.ceil(timeout)))

    try:
        proc = Popen(
            s,
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
            shell=not isinstance(s, list))
        # communicate() closes stdin and drains stdout and stderr together,
        # so the child can neither block waiting for input nor deadlock on a
        # full stderr pipe while we are still reading stdout.
        out, err = proc.communicate()
    except (IOError, OSError):
        # e.g. the executable does not exist
        if ignore_errors:
            return mesg
        raise
    finally:
        if timeout:
            signal.alarm(0)  # cancel the alarm if it has not fired yet
            signal.signal(
                signal.SIGALRM, signal.SIG_DFL
                if previous_handler is None else previous_handler)

    if timed_out:
        # Only reachable with ignore_errors; otherwise the handler raised.
        return mesg

    x = out + err
    if not isinstance(x, str):  # bytes on Python 3
        x = x.decode('utf-8', 'replace')
    if proc.returncode:
        if ignore_errors:
            return (x + "ERROR").strip()
        raise RuntimeError(x)
    if verbose >= 2:
        log("(%s seconds): %s", time.time() - t, x[:500])
    elif verbose >= 1:
        log("(%s seconds)", time.time() - t)
    return x.strip()
class Firewall(object):
    """
    Manage this host's iptables rules (and tc bandwidth throttling).

    A rule is a list of iptables arguments starting with the chain name,
    e.g. ``['INPUT', '-i', 'lo', '-j', 'ACCEPT']``.  The JSON-facing
    commands (``clear``, ``outgoing``, ``incoming``) return
    ``{'status': 'success'}``.
    """

    def iptables(self, args, **kwds):
        """Run ``iptables -v`` with ``args``; ``kwds`` are passed to ``cmd``."""
        return cmd(['iptables', '-v'] + args, **kwds)

    def insert_rule(self, rule, force=False):
        """
        Insert ``rule`` at the top of its chain if it is not present yet.

        With ``force``, an existing copy is deleted and re-inserted, which
        moves it to the top of the chain.
        """
        if not self.exists(rule):
            log("insert_rule: %s", rule)
            self.iptables(['-I'] + rule)
        elif force:
            self.delete_rule(rule)
            self.iptables(['-I'] + rule)

    def append_rule(self, rule, force=False):
        """
        Append ``rule`` to the end of its chain if it is not present yet.

        With ``force``, an existing copy is deleted and re-appended, which
        moves it to the end of the chain.
        """
        if not self.exists(rule):
            log("append_rule: %s", rule)
            self.iptables(['-A'] + rule)
        elif force:
            self.delete_rule(rule, force=True)
            self.iptables(['-A'] + rule)

    def delete_rule(self, rule, force=False):
        """
        Delete ``rule``.  iptables errors are logged or ignored, never raised.

        Normally a delete is only attempted for a rule that ``exists``; with
        ``force`` it is attempted regardless and any failure is ignored.
        """
        if self.exists(rule):
            log("delete_rule: %s", rule)
            try:
                self.iptables(['-D'] + rule)
            except Exception as mesg:
                log("delete_rule error -- %s", mesg)
        elif force:
            try:
                self.iptables(['-D'] + rule)
            except Exception:
                pass  # best effort: most likely the rule is simply absent

    def exists(self, rule):
        """
        Return True if the given rule exists already (``iptables -C``).

        Any error from the check counts as "does not exist".  Only Exception
        subclasses are caught, so a KeyboardInterrupt still propagates.
        """
        try:
            self.iptables(['-C'] + rule, verbose=0)
        except Exception:
            return False
        return True

    def clear(self):
        """
        Flush all rules in the filter and mangle tables.

        Chain policies and any tc configuration are not touched here.
        """
        self.iptables(['-F'])
        self.iptables(['-t', 'mangle', '-F'])
        return {'status': 'success'}

    def show(self, names=False):
        """
        Print all firewall rules to the terminal. (NON-JSON interface!)

        Unless ``names`` is true, ``-n`` is passed so that iptables prints
        numeric addresses instead of doing DNS lookups.
        """
        if names:
            os.system("iptables -v -L")
        else:
            os.system("iptables -v -n -L")

    def outgoing(self,
                 whitelist_hosts='',
                 whitelist_hosts_file='',
                 whitelist_users='',
                 blacklist_users='',
                 bandwidth_Kbps=1000):
        """
        Block all outgoing traffic, except what is given
        in a specific whitelist and DNS. Also throttle
        bandwidth of outgoing SMC *user* traffic.

        :param whitelist_hosts: comma separated hosts/networks to allow.
        :param whitelist_hosts_file: file with one host per line; text after
            '#' and blank lines are ignored.
        :param whitelist_users: comma separated users whose traffic is allowed.
        :param blacklist_users: comma separated users to remove from that list.
        :param bandwidth_Kbps: passed to ``configure_tc`` when truthy.
        """
        if whitelist_users or blacklist_users:
            self.outgoing_user(whitelist_users, blacklist_users)
        if whitelist_hosts_file:
            hosts = []
            with open(whitelist_hosts_file) as fh:
                for line in fh:
                    host = line.split('#', 1)[0].strip()
                    if host:
                        hosts.append(host)
            self.outgoing_whitelist_hosts(','.join(hosts))
        self.outgoing_whitelist_hosts(whitelist_hosts)
        # Reject loopback traffic addressed to our own hostname; force=True
        # keeps this rule at the very top of OUTPUT.
        self.insert_rule(
            ['OUTPUT', '-o', 'lo', '-d',
             socket.gethostname(), '-j', 'REJECT'],
            force=True)
        if bandwidth_Kbps:
            self.configure_tc(bandwidth_Kbps)
        return {'status': 'success'}

    def configure_tc(self, bandwidth_Kbps, dev='eth0'):
        """
        Set up an HTB qdisc on ``dev`` that limits fwmark-1 traffic to
        ``bandwidth_Kbps`` Kbit/s.

        Best effort: failures are logged, not raised.
        """
        try:
            cmd(("tc qdisc del dev %s root" % dev).split())
        except Exception:
            pass  # there was no root qdisc to remove
        try:
            cmd(("tc qdisc add dev %s root handle 1:0 htb default 99" %
                 dev).split())
            cmd((
                "tc class add dev %s parent 1:0 classid 1:10 htb rate %sKbit ceil %sKbit prio 2"
                % (dev, bandwidth_Kbps, bandwidth_Kbps)).split())
            cmd(("tc qdisc add dev %s parent 1:10 handle 10: sfq perturb 10" %
                 dev).split())
            cmd((
                "tc filter add dev %s parent 1:0 protocol ip prio 1 handle 1 fw classid 1:10"
                % dev).split())
        except Exception as mesg:
            log("configure_tc error -- %s", mesg)

    def outgoing_whitelist_hosts(self, whitelist):
        """
        Allow NEW outgoing connections only to ``whitelist`` (comma separated),
        the nameservers listed in /etc/resolv.conf, and loopback; all other
        NEW outgoing connections are rejected.
        """
        hosts = [x.strip() for x in whitelist.split(',')]
        with open("/etc/resolv.conf") as fh:
            for line in fh:
                v = line.split()
                # Skip blank lines and malformed entries instead of crashing.
                if len(v) >= 2 and v[0] == 'nameserver':
                    log("adding nameserver %s to whitelist", v[1])
                    hosts.append(v[1])
        whitelist = ','.join([x for x in hosts if x])
        log("whitelist: %s", whitelist)
        if whitelist:
            self.insert_rule(['OUTPUT', '-d', whitelist, '-j', 'ACCEPT'])
        self.insert_rule(['OUTPUT', '-o', 'lo', '-j', 'ACCEPT'])
        self.append_rule(
            ['OUTPUT', '-m', 'state', '--state', 'NEW', '-j', 'REJECT'])

    def outgoing_user(self, add='', remove=''):
        """
        Allow all outgoing traffic owned by the users in ``add`` (comma
        separated) and remove that permission for the users in ``remove``.
        """

        def rules(user):
            v = [[
                'OUTPUT', '-m', 'owner', '--uid-owner', user, '-j', 'ACCEPT'
            ]]
            # NOTE(review): this mangle/MARK rule is disabled by `False and`,
            # so nothing gets fwmark 1 and the tc filter from configure_tc
            # presumably never matches -- confirm whether throttling should
            # be active.
            if False and user != 'salvus' and user != 'root':
                v.append([
                    'OUTPUT', '-t', 'mangle', '-p', 'all', '!', '-d',
                    '10.240.0.0/8', '-m', 'owner', '--uid-owner', user, '-j',
                    'MARK', '--set-mark', '0x1'
                ])
            return v

        for user in remove.split(','):
            if user:
                for x in rules(user):
                    self.delete_rule(x, force=True)
        for user in add.split(','):
            if user:
                try:
                    for x in rules(user):
                        self.insert_rule(x, force=True)
                except Exception as mesg:
                    log("\nWARNING whitelisting user: %s\n",
                        str(mesg).splitlines()[:-1])

    def incoming(self, whitelist_hosts='', whitelist_ports=''):
        """
        Deny all other incoming traffic, except from the
        explicitly given whitelist of machines.

        :param whitelist_hosts: comma separated sources to allow.  If blank,
            the hosts are fetched from the GCE metadata attributes
            ``smc-servers``, ``storage-servers`` and ``admin-servers``.
        :param whitelist_ports: comma separated TCP ports open to everyone.
        """
        for p in whitelist_ports.split(','):
            p = p.strip()
            if p:  # an empty entry would make iptables fail
                self.insert_rule(
                    ['INPUT', '-p', 'tcp', '--dport', p, '-j', 'ACCEPT'])
        if not whitelist_hosts.strip():
            v = []
            for t in ['smc', 'storage', 'admin']:
                s = cmd(
                    "curl -s http://metadata.google.internal/computeMetadata/v1/project/attributes/%s-servers -H 'Metadata-Flavor: Google'"
                    % t)
                if not isinstance(s, str):  # bytes from older cmd() on Py3
                    s = s.decode('utf-8', 'replace')
                # presumably a whitespace separated host list; split() also
                # drops empty entries that would produce ",," for iptables
                v.extend(s.split())
            whitelist_hosts = ','.join(v)
        if whitelist_hosts.strip():
            self.insert_rule(['INPUT', '-s', whitelist_hosts, '-j', 'ACCEPT'])
        else:
            log("incoming: no hosts to whitelist")
        self.append_rule(['INPUT', '-i', 'lo', '-j', 'ACCEPT'])
        self.append_rule(
            ['INPUT', '-m', 'state', '--state', 'NEW', '-j', 'DROP'])
        return {'status': 'success'}
if __name__ == "__main__":
    import argparse
    import traceback

    hostname = socket.gethostname()
    log("hostname=%s", hostname)
    if not hostname.startswith(('compute', 'web')):
        log("skipping firewall since this is not a production SMC machine")
        sys.exit(0)

    parser = argparse.ArgumentParser(
        description="CoCalc firewall control script")
    subparsers = parser.add_subparsers(help='sub-command help')

    def f(subparser):
        """
        Make ``subparser`` call the Firewall method with the same name as the
        sub-command, passing the parsed options as keyword arguments, and
        print the result as JSON.  On failure, print ``{"error": ...}`` and
        exit with status 1; the traceback goes to stderr.
        """
        function = subparser.prog.split()[-1]

        def g(args):
            kwds = {k: v for k, v in vars(args).items() if k != 'func'}
            try:
                result = getattr(Firewall(), function)(**kwds)
            except Exception as mesg:
                sys.stderr.write(traceback.format_exc())
                print(json.dumps({'error': str(mesg)}))
                sys.exit(1)
            print(json.dumps(result))

        subparser.set_defaults(func=g)

    parser_outgoing = subparsers.add_parser(
        'outgoing',
        help=
        'create firewall to block all outgoing traffic, except explicit whitelist'
    )
    parser_outgoing.add_argument(
        '--whitelist_hosts',
        help="comma separated list of sites to whitelist (not run if empty)",
        default='')
    parser_outgoing.add_argument(
        '--whitelist_hosts_file',
        help=
        "filename of file with one line for each host (comments and blank lines are ignored)",
        default='')
    parser_outgoing.add_argument(
        '--whitelist_users',
        help="comma separated list of users to whitelist",
        default='')
    parser_outgoing.add_argument(
        '--blacklist_users',
        help="comma separated list of users to remove from whitelist",
        default='')
    # type=int so that "--bandwidth_Kbps 0" really disables throttling
    # (the string "0" would be truthy).
    parser_outgoing.add_argument(
        '--bandwidth_Kbps',
        help="throttle user bandwidth",
        type=int,
        default=1000)
    f(parser_outgoing)

    parser_incoming = subparsers.add_parser(
        'incoming',
        help=
        'create firewall to block all incoming traffic except ssh, nfs, http[s], except explicit whitelist'
    )
    parser_incoming.add_argument(
        '--whitelist_hosts',
        help=
        "comma separated list of sites to whitelist (default: use metadata server to get smc vms)",
        default='')
    parser_incoming.add_argument(
        '--whitelist_ports',
        help="comma separated list of ports to whitelist",
        default='22,80,111,443')
    f(parser_incoming)

    f(subparsers.add_parser('clear', help='clear all rules'))

    parser_show = subparsers.add_parser('show', help='show all rules')
    parser_show.add_argument(
        '--names',
        help="show hostnames (potentially expensive DNS lookup)",
        default=False,
        action="store_const",
        const=True)
    f(parser_show)

    args = parser.parse_args()
    if not hasattr(args, 'func'):
        # On Python 3 sub-commands are optional, so no command parses fine.
        parser.print_help()
        sys.exit(1)
    args.func(args)