Skip to content

Commit 3abc796

Browse files
committed
Added Mtls patch
(cherry picked from commit de2de96fc88022df783b637ccb145d1d73ba66ff)
1 parent 2fe0a52 commit 3abc796

File tree

5 files changed

+272
-7
lines changed

5 files changed

+272
-7
lines changed

config/rbac/role.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ rules:
4444
- subjectaccessreviews
4545
verbs:
4646
- create
47+
- apiGroups:
48+
- config.openshift.io
49+
resources:
50+
- ingresses
51+
verbs:
52+
- get
4753
- apiGroups:
4854
- ""
4955
resources:

main.go

+25
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ import (
5252
"sigs.k8s.io/yaml"
5353

5454
routev1 "github.com/openshift/api/route/v1"
55+
clientset "github.com/openshift/client-go/config/clientset/versioned"
5556

5657
"github.com/project-codeflare/codeflare-operator/pkg/config"
5758
"github.com/project-codeflare/codeflare-operator/pkg/controllers"
@@ -75,6 +76,8 @@ func init() {
7576
utilruntime.Must(routev1.Install(scheme))
7677
}
7778

79+
// +kubebuilder:rbac:groups=config.openshift.io,resources=ingresses,verbs=get;
80+
7881
func main() {
7982
var configMapName string
8083
flag.StringVar(&configMapName, "config", "codeflare-operator-config",
@@ -117,6 +120,7 @@ func main() {
117120
KubeRay: &config.KubeRayConfiguration{
118121
RayDashboardOAuthEnabled: ptr.To(true),
119122
IngressDomain: "",
123+
MTLSEnabled: ptr.To(true),
120124
},
121125
}
122126

@@ -155,6 +159,13 @@ func main() {
155159
certsReady := make(chan struct{})
156160
exitOnError(setupCertManagement(mgr, namespace, certsReady), "unable to setup cert-controller")
157161

162+
if cfg.KubeRay.IngressDomain == "" {
163+
configClient, err := clientset.NewForConfig(kubeConfig)
164+
exitOnError(err, "unable to create Route Client Set")
165+
cfg.KubeRay.IngressDomain, err = getClusterDomain(ctx, configClient)
166+
exitOnError(err, cfg.KubeRay.IngressDomain)
167+
}
168+
158169
go setupControllers(mgr, kubeClient, cfg, isOpenShift(ctx, kubeClient.DiscoveryClient), certsReady)
159170

160171
setupLog.Info("setting up health endpoints")
@@ -332,3 +343,17 @@ func isOpenShift(ctx context.Context, dc discovery.DiscoveryInterface) bool {
332343
logger.Info("We detected being on Vanilla Kubernetes!")
333344
return false
334345
}
346+
347+
func getClusterDomain(ctx context.Context, configClient *clientset.Clientset) (string, error) {
348+
ingress, err := configClient.ConfigV1().Ingresses().Get(ctx, "cluster", metav1.GetOptions{})
349+
if err != nil {
350+
return "", fmt.Errorf("failed to get Ingress object: %v", err)
351+
}
352+
353+
domain := ingress.Spec.Domain
354+
if domain == "" {
355+
return "", fmt.Errorf("domain is not set in the Ingress object")
356+
}
357+
358+
return domain, nil
359+
}

pkg/config/config.go

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type KubeRayConfiguration struct {
3535
RayDashboardOAuthEnabled *bool `json:"rayDashboardOAuthEnabled,omitempty"`
3636

3737
IngressDomain string `json:"ingressDomain"`
38+
39+
MTLSEnabled *bool `json:"mTLSEnabled,omitempty"`
3840
}
3941

4042
type ControllerManager struct {

pkg/controllers/raycluster_webhook.go

+228-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package controllers
1818

1919
import (
2020
"context"
21+
"strconv"
2122

2223
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
2324

@@ -36,6 +37,7 @@ import (
3637
const (
3738
oauthProxyContainerName = "oauth-proxy"
3839
oauthProxyVolumeName = "proxy-tls-secret"
40+
initContainerName = "create-cert"
3941
)
4042

4143
// log is for logging in this package.
@@ -66,17 +68,47 @@ var _ webhook.CustomValidator = &rayClusterWebhook{}
6668
func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) error {
6769
rayCluster := obj.(*rayv1.RayCluster)
6870

69-
if !ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
70-
return nil
71-
}
71+
if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
72+
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
73+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
7274

73-
rayclusterlog.V(2).Info("Adding OAuth sidecar container")
75+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
7476

75-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers, oauthProxyContainer(rayCluster), withContainerName(oauthProxyContainerName))
77+
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
78+
}
7679

77-
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, oauthProxyTLSSecretVolume(rayCluster), withVolumeName(oauthProxyVolumeName))
80+
if ptr.Deref(w.Config.MTLSEnabled, true) {
81+
rayclusterlog.V(2).Info("Adding create-cert Init Containers")
82+
// HeadGroupSpec //
83+
// Append the list of environment variables for the ray-head container
84+
for index := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
85+
for _, envVar := range envVarList() {
86+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
87+
}
88+
}
89+
90+
// Append the create-cert Init Container
91+
rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, rayHeadInitContainer(rayCluster, w.Config.IngressDomain), withContainerName(initContainerName))
92+
93+
// Append the CA volumes
94+
for _, caVol := range caVolumes(rayCluster) {
95+
rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes = upsert(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
96+
}
97+
// WorkerGroupSpec //
98+
// Append the list of environment variables for the machine-learning container
99+
for index := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
100+
for _, envVar := range envVarList() {
101+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[index].Env, envVar, withEnvVarName(envVar.Name))
102+
}
103+
}
104+
// Append the CA volumes
105+
for _, caVol := range caVolumes(rayCluster) {
106+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, withVolumeName(caVol.Name))
107+
}
108+
// Append the create-cert Init Container
109+
rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers = upsert(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(), withContainerName(initContainerName))
78110

79-
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = rayCluster.Name + "-oauth-proxy"
111+
}
80112

81113
return nil
82114
}
@@ -117,6 +149,14 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r
117149
allErrors = append(allErrors, validateHeadGroupServiceAccountName(rayCluster)...)
118150
}
119151

152+
// Init Container related errors
153+
if ptr.Deref(w.Config.MTLSEnabled, true) {
154+
allErrors = append(allErrors, validateHeadInitContainer(rayCluster, w)...)
155+
allErrors = append(allErrors, validateWorkerInitContainer(rayCluster)...)
156+
allErrors = append(allErrors, validateHeadEnvVars(rayCluster)...)
157+
allErrors = append(allErrors, validateWorkerEnvVars(rayCluster)...)
158+
allErrors = append(allErrors, validateCaVolumes(rayCluster)...)
159+
}
120160
return warnings, allErrors.ToAggregate()
121161
}
122162

@@ -225,3 +265,184 @@ func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
225265
},
226266
}
227267
}
268+
269+
func initCaVolumeMounts() []corev1.VolumeMount {
270+
return []corev1.VolumeMount{
271+
{
272+
Name: "ca-vol",
273+
MountPath: "/home/ray/workspace/ca",
274+
ReadOnly: true,
275+
},
276+
{
277+
Name: "server-cert",
278+
MountPath: "/home/ray/workspace/tls",
279+
ReadOnly: false,
280+
},
281+
}
282+
}
283+
284+
func envVarList() []corev1.EnvVar {
285+
return []corev1.EnvVar{
286+
{
287+
Name: "MY_POD_IP",
288+
ValueFrom: &corev1.EnvVarSource{
289+
FieldRef: &corev1.ObjectFieldSelector{
290+
FieldPath: "status.podIP",
291+
},
292+
},
293+
},
294+
{
295+
Name: "RAY_USE_TLS",
296+
Value: "1",
297+
},
298+
{
299+
Name: "RAY_TLS_SERVER_CERT",
300+
Value: "/home/ray/workspace/tls/server.crt",
301+
},
302+
{
303+
Name: "RAY_TLS_SERVER_KEY",
304+
Value: "/home/ray/workspace/tls/server.key",
305+
},
306+
{
307+
Name: "RAY_TLS_CA_CERT",
308+
Value: "/home/ray/workspace/tls/ca.crt",
309+
},
310+
}
311+
}
312+
313+
func caVolumes(rayCluster *rayv1.RayCluster) []corev1.Volume {
314+
return []corev1.Volume{
315+
{
316+
Name: "ca-vol",
317+
VolumeSource: corev1.VolumeSource{
318+
Secret: &corev1.SecretVolumeSource{
319+
SecretName: `ca-secret-` + rayCluster.Name,
320+
},
321+
},
322+
},
323+
{
324+
Name: "server-cert",
325+
VolumeSource: corev1.VolumeSource{
326+
EmptyDir: &corev1.EmptyDirVolumeSource{},
327+
},
328+
},
329+
}
330+
}
331+
332+
func rayHeadInitContainer(rayCluster *rayv1.RayCluster, domain string) corev1.Container {
333+
rayClientRoute := "rayclient-" + rayCluster.Name + "-" + rayCluster.Namespace + "." + domain
334+
// Service name for basic interactive
335+
svcDomain := rayCluster.Name + "-head-svc." + rayCluster.Namespace + ".svc"
336+
337+
initContainerHead := corev1.Container{
338+
Name: "create-cert",
339+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
340+
Command: []string{
341+
"sh",
342+
"-c",
343+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = ` + rayClientRoute + `\nDNS.6 = ` + svcDomain + `">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
344+
},
345+
VolumeMounts: initCaVolumeMounts(),
346+
}
347+
return initContainerHead
348+
}
349+
350+
func rayWorkerInitContainer() corev1.Container {
351+
initContainerWorker := corev1.Container{
352+
Name: "create-cert",
353+
Image: "quay.io/project-codeflare/ray:latest-py39-cu118",
354+
Command: []string{
355+
"sh",
356+
"-c",
357+
`cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext`,
358+
},
359+
VolumeMounts: initCaVolumeMounts(),
360+
}
361+
return initContainerWorker
362+
}
363+
364+
func validateHeadInitContainer(rayCluster *rayv1.RayCluster, w *rayClusterWebhook) field.ErrorList {
365+
var allErrors field.ErrorList
366+
367+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.InitContainers, rayHeadInitContainer(rayCluster, w.Config.IngressDomain), byContainerName,
368+
field.NewPath("spec", "headGroupSpec", "template", "spec", "initContainers"),
369+
"create-cert Init Container is immutable"); err != nil {
370+
allErrors = append(allErrors, err)
371+
}
372+
373+
return allErrors
374+
}
375+
376+
func validateWorkerInitContainer(rayCluster *rayv1.RayCluster) field.ErrorList {
377+
var allErrors field.ErrorList
378+
379+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers, rayWorkerInitContainer(), byContainerName,
380+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "initContainers"),
381+
"create-cert Init Container is immutable"); err != nil {
382+
allErrors = append(allErrors, err)
383+
}
384+
385+
return allErrors
386+
}
387+
388+
func validateCaVolumes(rayCluster *rayv1.RayCluster) field.ErrorList {
389+
var allErrors field.ErrorList
390+
391+
for _, caVol := range caVolumes(rayCluster) {
392+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Volumes, caVol, byVolumeName,
393+
field.NewPath("spec", "headGroupSpec", "template", "spec", "volumes"),
394+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
395+
allErrors = append(allErrors, err)
396+
}
397+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Volumes, caVol, byVolumeName,
398+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "volumes"),
399+
"ca-vol and server-cert Secret volumes are immutable"); err != nil {
400+
allErrors = append(allErrors, err)
401+
}
402+
}
403+
404+
return allErrors
405+
}
406+
407+
func validateHeadEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
408+
var allErrors field.ErrorList
409+
item := 0
410+
for index, container := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
411+
if container.Name == "ray-head" {
412+
item = index
413+
break
414+
}
415+
}
416+
417+
for _, envVar := range envVarList() {
418+
if err := contains(rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[item].Env, envVar, byEnvVarName,
419+
field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(item), "env"),
420+
"RAY_TLS related environment variables are immutable"); err != nil {
421+
allErrors = append(allErrors, err)
422+
}
423+
}
424+
425+
return allErrors
426+
}
427+
428+
func validateWorkerEnvVars(rayCluster *rayv1.RayCluster) field.ErrorList {
429+
var allErrors field.ErrorList
430+
item := 0
431+
432+
for index, container := range rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers {
433+
if container.Name == "machine-learning" {
434+
item = index
435+
break
436+
}
437+
}
438+
439+
for _, envVar := range envVarList() {
440+
if err := contains(rayCluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[item].Env, envVar, byEnvVarName,
441+
field.NewPath("spec", "workerGroupSpecs", "0", "template", "spec", "containers", strconv.Itoa(item), "env"),
442+
"RAY_TLS related environment variables are immutable"); err != nil {
443+
allErrors = append(allErrors, err)
444+
}
445+
}
446+
447+
return allErrors
448+
}

pkg/controllers/support.go

+11
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,14 @@ func withVolumeName(name string) compare[corev1.Volume] {
140140
return v1.Name == name
141141
}
142142
}
143+
144+
var byEnvVarName = compare[corev1.EnvVar](
145+
func(e1, e2 corev1.EnvVar) bool {
146+
return e1.Name == e2.Name
147+
})
148+
149+
func withEnvVarName(name string) compare[corev1.EnvVar] {
150+
return func(e1, e2 corev1.EnvVar) bool {
151+
return e1.Name == name
152+
}
153+
}

0 commit comments

Comments
 (0)