package cluster
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/prometheus/client_golang/prometheus"
"github.com/rfratto/ckit"
"github.com/rfratto/ckit/peer"
"github.com/rfratto/ckit/shard"
"golang.org/x/net/http2"
)
// Node is the cluster-membership abstraction used by this package. Both the
// gossip-backed node and the single-process local node satisfy it.
type Node interface {
	// Lookup selects replicationFactor peers responsible for key under the
	// given operation. Implementations may return an error when fewer than
	// replicationFactor peers are available.
	Lookup(key shard.Key, replicationFactor int, op shard.Op) ([]peer.Peer, error)

	// Observe registers an observer for changes to the peer set.
	Observe(ckit.Observer)

	// Peers returns the peers currently known to this node.
	Peers() []peer.Peer

	// Handler returns an HTTP handler together with the path prefix it is
	// meant to be mounted under.
	Handler() (string, http.Handler)
}
// NewLocalNode returns a Node describing a single standalone participant
// reachable at selfAddr. It is the stand-in used when clustering is
// disabled: the only peer it ever reports is itself.
func NewLocalNode(selfAddr string) Node {
	return &localNode{
		self: peer.Peer{
			Name:  "local",
			Addr:  selfAddr,
			Self:  true,
			State: peer.StateParticipant,
		},
	}
}
// localNode is the Node implementation used when clustering is disabled.
// It knows about exactly one peer: itself.
type localNode struct {
	// self describes this process as the sole cluster member.
	self peer.Peer
}
// Lookup returns the peers responsible for key. A local node can satisfy a
// replication factor of at most one:
//   - a factor of 0 yields no peers and no error;
//   - a factor above 1 is an error, since only one node exists;
//   - any other value yields the local node itself.
//
// key and op are ignored because there is only a single candidate.
// NOTE(review): negative factors currently also return the local node.
func (ln *localNode) Lookup(key shard.Key, replicationFactor int, op shard.Op) ([]peer.Peer, error) {
	switch {
	case replicationFactor == 0:
		return nil, nil
	case replicationFactor > 1:
		return nil, fmt.Errorf("need %d nodes; only 1 available", replicationFactor)
	default:
		return []peer.Peer{ln.self}, nil
	}
}
// Observe is a no-op. A local node's peer set never changes, so the
// supplied observer is discarded and never invoked.
func (ln *localNode) Observe(ckit.Observer) {}
// Peers returns a fresh one-element slice holding the local node, the only
// peer it knows about.
func (ln *localNode) Peers() []peer.Peer {
	peers := make([]peer.Peer, 1)
	peers[0] = ln.self
	return peers
}
// Handler returns the transport path prefix together with a handler that
// rejects every request. Clustering is disabled for a local node, so each
// request receives 400 Bad Request with the body "clustering is disabled".
func (ln *localNode) Handler() (string, http.Handler) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		// The status must be sent before the body: the first Write
		// implicitly commits 200 OK, after which WriteHeader has no effect.
		w.WriteHeader(http.StatusBadRequest)
		// Best effort: if the client has gone away there is nothing useful
		// left to do with the write error.
		_, _ = w.Write([]byte("clustering is disabled"))
	})
	return "/api/v1/ckit/transport/", mux
}
// Clusterer bundles the Node this process uses for cluster membership:
// a gossip-backed node, or a local stand-in when clustering is disabled.
type Clusterer struct {
	// Node is the active membership implementation.
	Node Node
}
// getJoinAddr interprets one user-supplied join address, appends the
// resulting address(es) to addrs, and returns the extended slice.
//
// Surrounding whitespace is ignored. The trimmed input is then handled as
// follows:
//   - an empty value is skipped;
//   - "host:port" (including "[ipv6]:port") is appended unchanged;
//   - a bare IP literal is appended in its canonical string form;
//   - anything else is treated as a DNS name. If an SRV lookup for it
//     returns records, each record's target is appended. Otherwise the name
//     itself is appended, so it can still be resolved as a plain hostname.
//
// NOTE(review): SRV record ports are discarded; the caller is presumably
// expected to supply a default port later — confirm.
func getJoinAddr(addrs []string, in string) []string {
	in = strings.TrimSpace(in)
	if in == "" {
		return addrs
	}

	if _, _, err := net.SplitHostPort(in); err == nil {
		return append(addrs, in)
	}

	if ip := net.ParseIP(in); ip != nil {
		return append(addrs, ip.String())
	}

	_, srvs, err := net.LookupSRV("", "", in)
	if err != nil || len(srvs) == 0 {
		// Not an SRV name (or the lookup failed): keep the name itself
		// rather than silently dropping a peer the user asked to join.
		return append(addrs, in)
	}
	for _, srv := range srvs {
		addrs = append(addrs, srv.Target)
	}
	return addrs
}
// New creates a Clusterer.
//
// When clusterEnabled is false, the Clusterer wraps a local node addressed by
// listenAddr and no gossip is started. Otherwise New:
//   - builds a gossip configuration, using advertiseAddr (if non-empty) as
//     the advertised address and joinAddr as a comma-separated list of peers
//     to join;
//   - passes the port of listenAddr (80 if listenAddr has no port) to
//     ApplyDefaults, presumably as the port for addresses lacking one;
//   - starts a gossip node. If joining the listed peers fails, it falls back
//     to bootstrapping a new cluster;
//   - moves the node into the participant state and logs peer changes.
func New(log log.Logger, reg prometheus.Registerer, clusterEnabled bool, listenAddr, advertiseAddr, joinAddr string) (*Clusterer, error) {
	if !clusterEnabled {
		return &Clusterer{Node: NewLocalNode(listenAddr)}, nil
	}

	gossipConfig := DefaultGossipConfig

	// A listenAddr without a port is tolerated (default 80); a present but
	// non-numeric port is an error.
	defaultPort := 80
	_, portStr, err := net.SplitHostPort(listenAddr)
	if err == nil {
		if defaultPort, err = strconv.Atoi(portStr); err != nil {
			return nil, fmt.Errorf("parsing port of listen address %q: %w", listenAddr, err)
		}
	}

	if advertiseAddr != "" {
		gossipConfig.AdvertiseAddr = advertiseAddr
	}
	if joinAddr != "" {
		gossipConfig.JoinPeers = []string{}
		for _, jaddr := range strings.Split(joinAddr, ",") {
			gossipConfig.JoinPeers = getJoinAddr(gossipConfig.JoinPeers, jaddr)
		}
	}

	if err = gossipConfig.ApplyDefaults(defaultPort); err != nil {
		return nil, fmt.Errorf("applying gossip config defaults: %w", err)
	}

	cli := &http.Client{
		Transport: &http2.Transport{
			// Peers speak HTTP/2 over plaintext TCP (h2c). AllowHTTP permits
			// "http" URLs, and DialTLSContext is overridden to dial without
			// TLS, ignoring the supplied TLS config.
			AllowHTTP: true,
			DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
				// Bound the dial by 30s or the context deadline, whichever
				// is sooner. Unlike net.DialTimeout, DialContext also aborts
				// the dial as soon as ctx is cancelled.
				timeout := 30 * time.Second
				if dur, ok := deadlineDuration(ctx); ok && dur < timeout {
					timeout = dur
				}
				dialer := net.Dialer{Timeout: timeout}
				return dialer.DialContext(ctx, network, addr)
			},
		},
	}

	level.Info(log).Log("msg", "starting a new gossip node", "join-peers", gossipConfig.JoinPeers)
	gossipNode, err := NewGossipNode(log, reg, cli, &gossipConfig)
	if err != nil {
		return nil, fmt.Errorf("creating gossip node: %w", err)
	}

	if err = gossipNode.Start(); err != nil {
		// Joining the configured peers failed. Rather than refusing to run,
		// bootstrap a brand-new cluster. Log this at warn level with the
		// cause, because the node is no longer part of the cluster it was
		// told to join.
		level.Warn(log).Log("msg", "failed to connect to peers; bootstrapping a new cluster", "err", err)
		// NOTE(review): clearing JoinPeers only affects the node if
		// NewGossipNode kept the &gossipConfig pointer rather than copying
		// the config — confirm.
		gossipConfig.JoinPeers = nil
		if err = gossipNode.Start(); err != nil {
			return nil, fmt.Errorf("bootstrapping new cluster: %w", err)
		}
	}

	// NOTE(review): the already-started node is not stopped on this error path.
	if err = gossipNode.ChangeState(context.Background(), peer.StateParticipant); err != nil {
		return nil, fmt.Errorf("changing node state to participant: %w", err)
	}

	gossipNode.Observe(ckit.FuncObserver(func(peers []peer.Peer) (reregister bool) {
		names := make([]string, len(peers))
		for i, p := range peers {
			names[i] = p.Name
		}
		level.Info(log).Log("msg", "peers changed", "new_peers", strings.Join(names, ","))
		// Ask to be re-registered so later membership changes are logged too.
		return true
	}))

	return &Clusterer{Node: gossipNode}, nil
}
// deadlineDuration reports the time remaining until ctx's deadline. The
// boolean is false, and the duration zero, when ctx carries no deadline.
// If the deadline has already passed, the duration is zero or negative.
func deadlineDuration(ctx context.Context) (time.Duration, bool) {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return 0, false
	}
	return time.Until(deadline), true
}