diff --git a/go/tasks/plugins/k8s/ray/auth.go b/go/tasks/plugins/k8s/ray/auth.go new file mode 100644 index 000000000..89dc8e394 --- /dev/null +++ b/go/tasks/plugins/k8s/ray/auth.go @@ -0,0 +1,58 @@ +package ray + +import ( + "fmt" + "io/ioutil" + + "github.com/pkg/errors" + restclient "k8s.io/client-go/rest" +) + +type ClusterConfig struct { + Name string `json:"name" pflag:",Friendly name of the remote cluster"` + Endpoint string `json:"endpoint" pflag:", Remote K8s cluster endpoint"` + Auth Auth `json:"auth" pflag:"-, Auth setting for the cluster"` + Enabled bool `json:"enabled" pflag:", Boolean flag to enable or disable"` +} + +type Auth struct { + TokenPath string `json:"tokenPath" pflag:", Token path"` + CaCertPath string `json:"caCertPath" pflag:", Certificate path"` +} + +func (auth Auth) GetCA() ([]byte, error) { + cert, err := ioutil.ReadFile(auth.CaCertPath) + if err != nil { + return nil, errors.Wrap(err, "failed to read k8s CA cert from configured path") + } + return cert, nil +} + +func (auth Auth) GetToken() (string, error) { + token, err := ioutil.ReadFile(auth.TokenPath) + if err != nil { + return "", errors.Wrap(err, "failed to read k8s bearer token from configured path") + } + return string(token), nil +} + +// KubeClientConfig ... +func KubeClientConfig(host string, auth Auth) (*restclient.Config, error) { + tokenString, err := auth.GetToken() + if err != nil { + return nil, errors.New(fmt.Sprintf("Failed to get auth token: %+v", err)) + } + + caCert, err := auth.GetCA() + if err != nil { + return nil, errors.New(fmt.Sprintf("Failed to get auth CA: %+v", err)) + } + + tlsClientConfig := restclient.TLSClientConfig{} + tlsClientConfig.CAData = caCert + return &restclient.Config{ + Host: host, + TLSClientConfig: tlsClientConfig, + BearerToken: tokenString, + }, nil +} diff --git a/go/tasks/plugins/k8s/ray/config.go b/go/tasks/plugins/k8s/ray/config.go index e141708ab..387fe862f 100644 --- a/go/tasks/plugins/k8s/ray/config.go +++ b/go/tasks/plugins/k8s/ray/config.go @@ -40,6 +40,9 @@ type Config struct { // NodeIPAddress the IP address of the head node. By default, this is pod ip address. NodeIPAddress string `json:"nodeIPAddress,omitempty"` + + // Remote Ray Cluster Config + RemoteClusterConfig ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for array jobs"` } func GetConfig() *Config { diff --git a/go/tasks/plugins/k8s/ray/ray.go b/go/tasks/plugins/k8s/ray/ray.go index d7e274250..3516eacdc 100644 --- a/go/tasks/plugins/k8s/ray/ray.go +++ b/go/tasks/plugins/k8s/ray/ray.go @@ -371,5 +371,19 @@ func init() { ResourceToWatch: &rayv1alpha1.RayJob{}, Plugin: rayJobResourceHandler{}, IsDefault: false, + CustomKubeClient: func(ctx context.Context) (pluginsCore.KubeClient, error) { + remoteConfig := GetConfig().RemoteClusterConfig + if !remoteConfig.Enabled { + // use controller-runtime KubeClient + return nil, nil + } + + kubeConfig, err := KubeClientConfig(remoteConfig.Endpoint, remoteConfig.Auth) + if err != nil { + return nil, err + } + + return k8s.NewDefaultKubeClient(kubeConfig) + }, }) }