Skip to content

Enable MTLS by default - Webhook #523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions config/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ rules:
- subjectaccessreviews
verbs:
- create
- apiGroups:
- config.openshift.io
resources:
- ingresses
verbs:
- get
- apiGroups:
- ""
resources:
Expand Down
25 changes: 25 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"sigs.k8s.io/yaml"

routev1 "github.com/openshift/api/route/v1"
clientset "github.com/openshift/client-go/config/clientset/versioned"

"github.com/project-codeflare/codeflare-operator/pkg/config"
"github.com/project-codeflare/codeflare-operator/pkg/controllers"
Expand All @@ -75,6 +76,8 @@ func init() {
utilruntime.Must(routev1.Install(scheme))
}

// +kubebuilder:rbac:groups=config.openshift.io,resources=ingresses,verbs=get;

func main() {
var configMapName string
flag.StringVar(&configMapName, "config", "codeflare-operator-config",
Expand Down Expand Up @@ -117,6 +120,7 @@ func main() {
KubeRay: &config.KubeRayConfiguration{
RayDashboardOAuthEnabled: ptr.To(true),
IngressDomain: "",
MTLSEnabled: ptr.To(true),
},
}

Expand Down Expand Up @@ -155,6 +159,13 @@ func main() {
certsReady := make(chan struct{})
exitOnError(setupCertManagement(mgr, namespace, certsReady), "unable to setup cert-controller")

if cfg.KubeRay.IngressDomain == "" {
configClient, err := clientset.NewForConfig(kubeConfig)
exitOnError(err, "unable to create Route Client Set")
cfg.KubeRay.IngressDomain, err = getClusterDomain(ctx, configClient)
exitOnError(err, cfg.KubeRay.IngressDomain)
}

go setupControllers(mgr, kubeClient, cfg, isOpenShift(ctx, kubeClient.DiscoveryClient), certsReady)

setupLog.Info("setting up health endpoints")
Expand Down Expand Up @@ -332,3 +343,17 @@ func isOpenShift(ctx context.Context, dc discovery.DiscoveryInterface) bool {
logger.Info("We detected being on Vanilla Kubernetes!")
return false
}

func getClusterDomain(ctx context.Context, configClient *clientset.Clientset) (string, error) {
ingress, err := configClient.ConfigV1().Ingresses().Get(ctx, "cluster", metav1.GetOptions{})
if err != nil {
return "", fmt.Errorf("failed to get Ingress object: %v", err)
}

domain := ingress.Spec.Domain
if domain == "" {
return "", fmt.Errorf("domain is not set in the Ingress object")
}

return domain, nil
}
2 changes: 2 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type KubeRayConfiguration struct {
RayDashboardOAuthEnabled *bool `json:"rayDashboardOAuthEnabled,omitempty"`

IngressDomain string `json:"ingressDomain"`

MTLSEnabled *bool `json:"mTLSEnabled,omitempty"`
}

type ControllerManager struct {
Expand Down
217 changes: 210 additions & 7 deletions pkg/controllers/raycluster_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package controllers

import (
"context"
"strconv"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"

Expand All @@ -36,6 +37,7 @@ import (
const (
oauthProxyContainerName = "oauth-proxy"
oauthProxyVolumeName = "proxy-tls-secret"
initContainerName = "create-cert"
)

// log is for logging in this package.
Expand Down Expand Up @@ -66,17 +68,44 @@ var _ webhook.CustomValidator = &rayClusterWebhook{}
func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) error {
rayCluster := obj.(*rayv1.RayCluster)

if !ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
return nil
}
if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))

rayclusterlog.V(2).Info("Adding OAuth sidecar container")
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))

rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
}

rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
if ptr.Deref(w.Config.MTLSEnabled, true) {
rayclusterlog.V(2).Info("Adding create-cert Init Containers")
// HeadGroupSpec //
// Append the list of environment variables for the ray-head container
for _, envVar := range envVarList() {
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
}

// Append the create-cert Init Container
rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, rayHeadInitContainer(rayCluster, w.Config.IngressDomain), withContainerName(initContainerName))

// Append the CA volumes
for _, caVol := range caVolumes(rayCluster) {
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
}
// WorkerGroupSpec //
// Append the list of environment variables for the worker container
for _, envVar := range envVarList() {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
}

// Append the CA volumes
for _, caVol := range caVolumes(rayCluster) {
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
}
// Append the create-cert Init Container
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(), withContainerName(initContainerName))

rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
}

return nil
}
Expand Down Expand Up @@ -117,6 +146,14 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r
allErrors = append(allErrors, validateHeadGroupServiceAccountName(rayCluster)...)
}

// Init Container related errors
if ptr.Deref(w.Config.MTLSEnabled, true) {
allErrors = append(allErrors, validateHeadInitContainer(rayCluster, w.Config.IngressDomain)...)
allErrors = append(allErrors, validateWorkerInitContainer(rayCluster)...)
allErrors = append(allErrors, validateHeadEnvVars(rayCluster)...)
allErrors = append(allErrors, validateWorkerEnvVars(rayCluster)...)
allErrors = append(allErrors, validateCaVolumes(rayCluster)...)
}
return warnings, allErrors.ToAggregate()
}

Expand Down Expand Up @@ -225,3 +262,169 @@ func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
},
}
}

func initCaVolumeMounts() []corev1.VolumeMount {
return []corev1.VolumeMount{
{
Name: "ca-vol",
MountPath: "/home/ray/workspace/ca",
ReadOnly: true,
},
{
Name: "server-cert",
MountPath: "/home/ray/workspace/tls",
ReadOnly: false,
},
}
}

func envVarList() []corev1.EnvVar {
return []corev1.EnvVar{
{
Name: "MY_POD_IP",
ValueFrom: &corev1.EnvVarSource{
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "status.podIP",
},
},
},
{
Name: "RAY_USE_TLS",
Value: "1",
},
{
Name: "RAY_TLS_SERVER_CERT",
Value: "/home/ray/workspace/tls/server.crt",
},
{
Name: "RAY_TLS_SERVER_KEY",
Value: "/home/ray/workspace/tls/server.key",
},
{
Name: "RAY_TLS_CA_CERT",
Value: "/home/ray/workspace/tls/ca.crt",
},
}
}

func caVolumes(rayCluster *rayv1.RayCluster) []corev1.Volume {
return []corev1.Volume{
{
Name: "ca-vol",
VolumeSource: corev1.VolumeSource{
Secret: &corev1.SecretVolumeSource{
SecretName: `ca-secret-` + rayCluster.Name,
},
},
},
{
Name: "server-cert",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
}
}

func rayHeadInitContainer(rayCluster *rayv1.RayCluster, domain string) corev1.Container {
rayClientRoute := "rayclient-" + rayCluster.Name + "-" + rayCluster.Namespace + "." + domain
// Service name for basic interactive
svcDomain := rayCluster.Name + "-head-svc." + rayCluster.Namespace + ".svc"

initContainerHead := corev1.Container{
Name: "create-cert",
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
Command: []string{
"sh",
"-c",
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = ` + rayClientRoute + `\nDNS.6 = ` + svcDomain + `">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
},
VolumeMounts: initCaVolumeMounts(),
}
return initContainerHead
}

func rayWorkerInitContainer() corev1.Container {
initContainerWorker := corev1.Container{
Name: "create-cert",
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
Command: []string{
"sh",
"-c",
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
},
VolumeMounts: initCaVolumeMounts(),
}
return initContainerWorker
}

func validateHeadInitContainer(rayCluster *rayv1.RayCluster, domain string) field.ErrorList {
var allErrors field.ErrorList

if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, rayHeadInitContainer(rayCluster, domain), byContainerName,
field.NewPath("spec", "headGroupSpec", "template", "spec", "initContainers"),
"create-cert Init Container is immutable"); err != nil {
allErrors = append(allErrors, err)
}

return allErrors
}

func validateWorkerInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(), byContainerName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"),
"create-cert Init Container is immutable"); err != nil {
allErrors = append(allErrors, err)
}

return allErrors
}

func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

for _, caVol := range caVolumes(rayCluster) {
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, byVolumeName,
field.NewPath("spec", "headGroupSpec", "template", "spec", "volumes"),
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
allErrors = append(allErrors, err)
}
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"),
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
allErrors = append(allErrors, err)
}
}

return allErrors
}

func validateHeadEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

for _, envVar := range envVarList() {
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env, envVar, byEnvVarName,
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(0), "env"),
"RAY_TLS related environment variables are immutable"); err != nil {
allErrors = append(allErrors, err)
}
}

return allErrors
}

func validateWorkerEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

for _, envVar := range envVarList() {
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Env, envVar, byEnvVarName,
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(0), "env"),
"RAY_TLS related environment variables are immutable"); err != nil {
allErrors = append(allErrors, err)
}
}

return allErrors
}
11 changes: 11 additions & 0 deletions pkg/controllers/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,14 @@ func withVolumeName(name string) compare[corev1.Volume] {
return v1.Name == name
}
}

var byEnvVarName = compare[corev1.EnvVar](
func(e1, e2 corev1.EnvVar) bool {
return e1.Name == e2.Name
})

func withEnvVarName(name string) compare[corev1.EnvVar] {
return func(e1, e2 corev1.EnvVar) bool {
return e1.Name == name
}
}
Loading