Skip to content

Commit

Permalink
feat(dgraph): making all internal communications with tls configured (#…
Browse files Browse the repository at this point in the history
…6692)

* making all internal communications with tls configured
  • Loading branch information
aman-bansal authored Oct 28, 2020
1 parent f74f202 commit a63af32
Show file tree
Hide file tree
Showing 111 changed files with 3,596 additions and 67 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ dgraph.iml

#darwin
.DS_Store

vendor
41 changes: 22 additions & 19 deletions conn/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package conn
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"math/rand"
Expand Down Expand Up @@ -66,16 +67,17 @@ type Node struct {
_raft raft.Node

// Fields which are never changed after init.
StartTime time.Time
Cfg *raft.Config
MyAddr string
Id uint64
peers map[uint64]string
confChanges map[uint64]chan error
messages chan sendmsg
RaftContext *pb.RaftContext
Store *raftwal.DiskStorage
Rand *rand.Rand
StartTime time.Time
Cfg *raft.Config
MyAddr string
Id uint64
peers map[uint64]string
confChanges map[uint64]chan error
messages chan sendmsg
RaftContext *pb.RaftContext
Store *raftwal.DiskStorage
Rand *rand.Rand
tlsClientConfig *tls.Config

Proposals proposals

Expand All @@ -84,7 +86,7 @@ type Node struct {
}

// NewNode returns a new Node instance.
func NewNode(rc *pb.RaftContext, store *raftwal.DiskStorage) *Node {
func NewNode(rc *pb.RaftContext, store *raftwal.DiskStorage, tlsConfig *tls.Config) *Node {
snap, err := store.Snapshot()
x.Check(err)

Expand Down Expand Up @@ -135,13 +137,14 @@ func NewNode(rc *pb.RaftContext, store *raftwal.DiskStorage) *Node {
},
// processConfChange etc are not throttled so some extra delta, so that we don't
// block tick when applyCh is full
Applied: y.WaterMark{Name: "Applied watermark"},
RaftContext: rc,
Rand: rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())}),
confChanges: make(map[uint64]chan error),
messages: make(chan sendmsg, 100),
peers: make(map[uint64]string),
requestCh: make(chan linReadReq, 100),
Applied: y.WaterMark{Name: "Applied watermark"},
RaftContext: rc,
Rand: rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())}),
confChanges: make(map[uint64]chan error),
messages: make(chan sendmsg, 100),
peers: make(map[uint64]string),
requestCh: make(chan linReadReq, 100),
tlsClientConfig: tlsConfig,
}
n.Applied.Init(nil)
// This should match up to the Applied index set above.
Expand Down Expand Up @@ -521,7 +524,7 @@ func (n *Node) Connect(pid uint64, addr string) {
n.SetPeer(pid, addr)
return
}
GetPools().Connect(addr)
GetPools().Connect(addr, n.tlsClientConfig)
n.SetPeer(pid, addr)
}

Expand Down
2 changes: 1 addition & 1 deletion conn/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestProposal(t *testing.T) {
store := raftwal.Init(dir)

rc := &pb.RaftContext{Id: 1}
n := NewNode(rc, store)
n := NewNode(rc, store, nil)

peers := []raft.Peer{{ID: n.Id}}
n.SetRaft(raft.StartNode(n.Cfg, peers))
Expand Down
23 changes: 18 additions & 5 deletions conn/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package conn

import (
"context"
"crypto/tls"
"sync"
"time"

Expand All @@ -30,6 +31,7 @@ import (
"go.opencensus.io/plugin/ocgrpc"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

var (
Expand Down Expand Up @@ -138,13 +140,13 @@ func (p *Pools) getPool(addr string) (*Pool, bool) {
}

// Connect creates a Pool instance for the node with the given address or returns the existing one.
func (p *Pools) Connect(addr string) *Pool {
func (p *Pools) Connect(addr string, tlsClientConf *tls.Config) *Pool {
existingPool, has := p.getPool(addr)
if has {
return existingPool
}

pool, err := newPool(addr)
pool, err := newPool(addr, tlsClientConf)
if err != nil {
glog.Errorf("Unable to connect to host: %s", addr)
return nil
Expand All @@ -160,21 +162,32 @@ func (p *Pools) Connect(addr string) *Pool {
glog.Infof("CONNECTING to %s\n", addr)
p.all[addr] = pool
return pool

}

// newPool creates a new "pool" with one gRPC connection, refcount 0.
func newPool(addr string) (*Pool, error) {
conn, err := grpc.Dial(addr,
func newPool(addr string, tlsClientConf *tls.Config) (*Pool, error) {
conOpts := []grpc.DialOption {
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(x.GrpcMaxSize),
grpc.MaxCallSendMsgSize(x.GrpcMaxSize),
grpc.UseCompressor((snappyCompressor{}).Name())),
grpc.WithBackoffMaxDelay(time.Second),
grpc.WithInsecure())
}

if tlsClientConf != nil {
conOpts = append(conOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsClientConf)))
} else {
conOpts = append(conOpts, grpc.WithInsecure())
}

conn, err := grpc.Dial(addr, conOpts...)
if err != nil {
glog.Errorf("unable to connect with %s : %s", addr, err)
return nil, err
}

pl := &Pool{conn: conn, Addr: addr, lastEcho: time.Now(), closer: z.NewCloser(1)}
go pl.MonitorHealth()
return pl, nil
Expand Down
17 changes: 11 additions & 6 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ import (
_ "github.com/vektah/gqlparser/v2/validator/rules" // make gql validator init() all rules
)

const (
tlsNodeCert = "node.crt"
tlsNodeKey = "node.key"
)

var (
bindall bool

Expand Down Expand Up @@ -172,6 +167,11 @@ they form a Raft group and provide synchronous replication.
flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication")
flag.Bool("tls_internal_port_enabled", false, "(optional) enable inter node TLS encryption between cluster nodes.")
flag.String("tls_cert", "", "(optional) The Cert file name in tls_dir which is needed to " +
"connect as a client with the other nodes in the cluster.")
flag.String("tls_key", "", "(optional) The private key file name "+
"in tls_dir needed to connect as a client with the other nodes in the cluster.")

//Custom plugins.
flag.String("custom_tokenizers", "",
Expand Down Expand Up @@ -426,7 +426,7 @@ func setupServer(closer *z.Closer) {
laddr = "0.0.0.0"
}

tlsCfg, err := x.LoadServerTLSConfig(Alpha.Conf, tlsNodeCert, tlsNodeKey)
tlsCfg, err := x.LoadServerTLSConfig(Alpha.Conf, x.TLSNodeCert, x.TLSNodeKey)
if err != nil {
log.Fatalf("Failed to setup TLS: %v\n", err)
}
Expand Down Expand Up @@ -651,6 +651,8 @@ func run() {
abortDur, err := time.ParseDuration(Alpha.Conf.GetString("abort_older_than"))
x.Check(err)

tlsConf, err := x.LoadClientTLSConfigForInternalPort(Alpha.Conf)
x.Check(err)
x.WorkerConfig = x.WorkerOptions{
ExportPath: Alpha.Conf.GetString("export"),
NumPendingProposals: Alpha.Conf.GetInt("pending_proposals"),
Expand All @@ -665,6 +667,9 @@ func run() {
StartTime: startTime,
LudicrousMode: Alpha.Conf.GetBool("ludicrous_mode"),
LudicrousConcurrency: Alpha.Conf.GetInt("ludicrous_concurrency"),
TLSClientConfig: tlsConf,
TLSDir: Alpha.Conf.GetString("tls_dir"),
TLSInterNodeEnabled: Alpha.Conf.GetBool("tls_internal_port_enabled"),
}
x.WorkerConfig.Parse(Alpha.Conf)

Expand Down
14 changes: 11 additions & 3 deletions dgraph/cmd/bulk/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"compress/gzip"
"context"
"fmt"
"google.golang.org/grpc/credentials"
"hash/adler32"
"io"
"io/ioutil"
Expand Down Expand Up @@ -111,13 +112,20 @@ func newLoader(opt *options) *loader {
}

fmt.Printf("Connecting to zero at %s\n", opt.ZeroAddr)

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

zero, err := grpc.DialContext(ctx, opt.ZeroAddr,
tlsConf, err := x.LoadClientTLSConfigForInternalPort(Bulk.Conf)
x.Check(err)
dialOpts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithInsecure())
}
if tlsConf != nil {
dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConf)))
} else {
dialOpts = append(dialOpts, grpc.WithInsecure())
}
zero, err := grpc.DialContext(ctx, opt.ZeroAddr, dialOpts...)
x.Checkf(err, "Unable to connect to zero, Is it running at %s?", opt.ZeroAddr)
st := &state{
opt: opt,
Expand Down
2 changes: 1 addition & 1 deletion dgraph/cmd/bulk/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func init() {
flag.String("badger.cache_percentage", "70,30",
"Cache percentages summing up to 100 for various caches"+
" (FORMAT: BlockCacheSize, IndexCacheSize).")

x.RegisterClientTLSFlags(flag)
// Encryption and Vault options
enc.RegisterFlags(flag)
}
Expand Down
4 changes: 4 additions & 0 deletions dgraph/cmd/live/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ func setup(opts batchMutationOptions, dc *dgo.Dgraph, conf *viper.Viper) *loader
var tlsErr error
tlsConfig, tlsErr = x.SlashTLSConfig(conf.GetString("slash_grpc_endpoint"))
x.Checkf(tlsErr, "Unable to generate TLS Cert Pool")
} else {
var tlsErr error
tlsConfig, tlsErr = x.LoadClientTLSConfigForInternalPort(conf)
x.Check(tlsErr)
}

// compression with zero server actually makes things worse
Expand Down
2 changes: 1 addition & 1 deletion dgraph/cmd/zero/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func startServers(m cmux.CMux) {

// if tls is enabled, make tls encryption based connections as default
if Zero.Conf.GetString("tls_dir") != "" {
tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, "node.crt", "node.key")
tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, x.TLSNodeCert, x.TLSNodeKey)
x.Check(err)
if tlsCfg == nil {
glog.Fatalf("tls_dir is set but tls config provided is not correct. Please define correct variable --tls_dir")
Expand Down
4 changes: 2 additions & 2 deletions dgraph/cmd/zero/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (n *node) handleMemberProposal(member *pb.Member) error {
}

// Create a connection to this server.
go conn.GetPools().Connect(member.Addr)
go conn.GetPools().Connect(member.Addr, n.server.tlsClientConfig)

group.Members[member.Id] = member
// Increment nextGroup when we have enough replicas
Expand Down Expand Up @@ -539,7 +539,7 @@ func (n *node) initAndStartNode() error {
}

case len(opts.peer) > 0:
p := conn.GetPools().Connect(opts.peer)
p := conn.GetPools().Connect(opts.peer, opts.tlsClientConfig)
if p == nil {
return errors.Errorf("Unhealthy connection to %v", opts.peer)
}
Expand Down
31 changes: 25 additions & 6 deletions dgraph/cmd/zero/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package zero

import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
Expand All @@ -33,6 +34,7 @@ import (
"go.opencensus.io/zpages"
"golang.org/x/net/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"github.com/dgraph-io/dgraph/conn"
"github.com/dgraph-io/dgraph/ee/enc"
Expand All @@ -54,6 +56,7 @@ type options struct {
rebalanceInterval time.Duration
tlsDir string
tlsDisabledRoutes []string
tlsClientConfig *tls.Config
totalCache int64
}

Expand Down Expand Up @@ -94,8 +97,14 @@ instances to achieve high-availability.
flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication")
flag.String("tls_disabled_route", "", "comma separated zero endpoint which will be disabled from TLS encryption."+
flag.String("tls_disabled_route", "",
"comma separated zero endpoint which will be disabled from TLS encryption."+
"Valid values are /health,/state,/removeNode,/moveTablet,/assign,/enterpriseLicense,/debug.")
flag.Bool("tls_internal_port_enabled", false, "enable inter node TLS encryption between cluster nodes.")
flag.String("tls_cert", "", "(optional) The Cert file name in tls_dir which is needed to " +
"connect as a client with the other nodes in the cluster.")
flag.String("tls_key", "", "(optional) The private key file name "+
"in tls_dir which is needed to connect as a client with the other nodes in the cluster.")
}

func setupListener(addr string, port int, kind string) (listener net.Listener, err error) {
Expand All @@ -112,23 +121,30 @@ type state struct {

func (st *state) serveGRPC(l net.Listener, store *raftwal.DiskStorage) {
x.RegisterExporters(Zero.Conf, "dgraph.zero")

s := grpc.NewServer(
grpcOpts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(x.GrpcMaxSize),
grpc.MaxSendMsgSize(x.GrpcMaxSize),
grpc.MaxConcurrentStreams(1000),
grpc.StatsHandler(&ocgrpc.ServerHandler{}))
grpc.StatsHandler(&ocgrpc.ServerHandler{}),
}

tlsConf, err := x.LoadServerTLSConfigForInternalPort(Zero.Conf.GetBool("tls_internal_port_enabled"), Zero.Conf.GetString("tls_dir"))
x.Check(err)
if tlsConf != nil {
grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(tlsConf)))
}
s := grpc.NewServer(grpcOpts...)

rc := pb.RaftContext{Id: opts.nodeId, Addr: x.WorkerConfig.MyAddr, Group: 0}
m := conn.NewNode(&rc, store)
m := conn.NewNode(&rc, store, opts.tlsClientConfig)

// Zero followers should not be forwarding proposals to the leader, to avoid txn commits which
// were calculated in a previous Zero leader.
m.Cfg.DisableProposalForwarding = true
st.rs = conn.NewRaftServer(m)

st.node = &node{Node: m, ctx: context.Background(), closer: z.NewCloser(1)}
st.zero = &Server{NumReplicas: opts.numReplicas, Node: st.node}
st.zero = &Server{NumReplicas: opts.numReplicas, Node: st.node, tlsClientConfig: opts.tlsClientConfig}
st.zero.Init()
st.node.server = st.zero

Expand Down Expand Up @@ -173,6 +189,8 @@ func run() {
tlsDisRoutes = strings.Split(Zero.Conf.GetString("tls_disabled_route"), ",")
}

tlsConf, err := x.LoadClientTLSConfigForInternalPort(Zero.Conf)
x.Check(err)
opts = options{
bindall: Zero.Conf.GetBool("bindall"),
portOffset: Zero.Conf.GetInt("port_offset"),
Expand All @@ -184,6 +202,7 @@ func run() {
totalCache: int64(Zero.Conf.GetInt("cache_mb")),
tlsDir: Zero.Conf.GetString("tls_dir"),
tlsDisabledRoutes: tlsDisRoutes,
tlsClientConfig: tlsConf,
}
glog.Infof("Setting Config to: %+v", opts)

Expand Down
Loading

0 comments on commit a63af32

Please sign in to comment.