Skip to content

Commit 275b0d9

Browse files
committed
[GR-64498] Implement ML-KEM intrinsics.
PullRequest: graal/20787
2 parents ae016b5 + ece3883 commit 275b0d9

File tree

6 files changed

+254
-12
lines changed

6 files changed

+254
-12
lines changed

compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/hotspot/test/HotSpotCryptoSubstitutionTest.java

+71-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import java.util.Random;
4646

4747
import javax.crypto.Cipher;
48+
import javax.crypto.KEM;
4849
import javax.crypto.KeyGenerator;
4950
import javax.crypto.SecretKey;
5051

@@ -400,13 +401,17 @@ void testWithInstalledIntrinsic(String className, String methodName, String test
400401
Assume.assumeTrue(className + " is not available", false);
401402
return;
402403
}
404+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, methodName)), testSnippetName, args);
405+
}
406+
407+
void testWithInstalledIntrinsic(ResolvedJavaMethod intrinsicMethod, String testSnippetName, Object... args) {
403408
InstalledCode code = null;
404409
try {
405410
ResolvedJavaMethod method = getResolvedJavaMethod(testSnippetName);
406411
Object receiver = method.isStatic() ? null : this;
407412
GraalCompilerTest.Result expect = executeExpected(method, receiver, args);
408-
code = compileAndInstallSubstitution(c, methodName);
409-
assertTrue("Failed to install " + methodName, code != null);
413+
code = compileAndInstallSubstitution(intrinsicMethod);
414+
assertTrue("Failed to install " + intrinsicMethod.getName(), code != null);
410415
testAgainstExpected(method, expect, receiver, args);
411416
} catch (AssumptionViolatedException e) {
412417
// Suppress so that subsequent calls to this method within the
@@ -528,4 +533,68 @@ public void testMLDSASigVer() {
528533
testWithInstalledIntrinsic("sun.security.provider.ML_DSA", "implDilithiumMontMulByConstant", "testSignVer", "ML-DSA-87");
529534
testWithInstalledIntrinsic("sun.security.provider.ML_DSA", "implDilithiumDecomposePoly", "testSignVer", "ML-DSA-87");
530535
}
536+
537+
public boolean testMLKEMEncapsulateDecapsulate(String algorithm) throws GeneralSecurityException {
538+
var kp = generateKeyPair(algorithm);
539+
var senderKem = KEM.getInstance(algorithm);
540+
541+
var encapsulator = senderKem.newEncapsulator(kp.getPublic(), new SeededSecureRandom());
542+
var enc = encapsulator.encapsulate();
543+
SecretKey key = enc.key();
544+
545+
var receiverKem = KEM.getInstance(algorithm);
546+
byte[] ciphertext = enc.encapsulation();
547+
var decapsulator = receiverKem.newDecapsulator(kp.getPrivate());
548+
SecretKey decapsulatedKey = decapsulator.decapsulate(ciphertext);
549+
550+
return key.equals(decapsulatedKey);
551+
}
552+
553+
@Test
554+
public void testMLKEM() {
555+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberNtt != 0L);
556+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberInverseNtt != 0L);
557+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberNttMult != 0L);
558+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberAddPoly2 != 0L);
559+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberAddPoly3 != 0L);
560+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyber12To16 != 0L);
561+
Assume.assumeTrue("ML_KEM not supported", runtime().getVMConfig().stubKyberBarrettReduce != 0L);
562+
563+
Class<?> c;
564+
try {
565+
c = Class.forName("sun.security.provider.ML_KEM");
566+
} catch (ClassNotFoundException e) {
567+
Assume.assumeTrue("sun.security.provider.ML_KEM is not available", false);
568+
return;
569+
}
570+
571+
// ML-KEM-512
572+
testWithInstalledIntrinsic("sun.security.provider.ML_KEM", "implKyberNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
573+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberInverseNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
574+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberNttMult", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
575+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
576+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate",
577+
"ML-KEM-512");
578+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyber12To16", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
579+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyber12To16", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
580+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberBarrettReduce", "testMLKEMEncapsulateDecapsulate", "ML-KEM-512");
581+
// ML-KEM-768
582+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
583+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberInverseNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
584+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberNttMult", "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
585+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
586+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate",
587+
"ML-KEM-768");
588+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyber12To16", "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
589+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberBarrettReduce", "testMLKEMEncapsulateDecapsulate", "ML-KEM-768");
590+
// ML-KEM-1024
591+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
592+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberInverseNtt", "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
593+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberNttMult", "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
594+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
595+
testWithInstalledIntrinsic(getMetaAccess().lookupJavaMethod(getMethod(c, "implKyberAddPoly", short[].class, short[].class, short[].class, short[].class)), "testMLKEMEncapsulateDecapsulate",
596+
"ML-KEM-1024");
597+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyber12To16", "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
598+
testWithInstalledIntrinsic("sun.security.provider.ML-KEM", "implKyberBarrettReduce", "testMLKEMEncapsulateDecapsulate", "ML-KEM-1024");
599+
}
531600
}

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/hotspot/GraalHotSpotVMConfig.java

+8
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,14 @@ public int threadTlabTopOffset() {
540540
public final long stubDilithiumMontMulByConstant = getFieldValue("StubRoutines::_dilithiumMontMulByConstant", Long.class, "address");
541541
public final long stubDilithiumDecomposePoly = getFieldValue("StubRoutines::_dilithiumDecomposePoly", Long.class, "address");
542542

543+
public final long stubKyberNtt = getFieldValue("StubRoutines::_kyberNtt", Long.class, "address");
544+
public final long stubKyberInverseNtt = getFieldValue("StubRoutines::_kyberInverseNtt", Long.class, "address");
545+
public final long stubKyberNttMult = getFieldValue("StubRoutines::_kyberNttMult", Long.class, "address");
546+
public final long stubKyberAddPoly2 = getFieldValue("StubRoutines::_kyberAddPoly_2", Long.class, "address");
547+
public final long stubKyberAddPoly3 = getFieldValue("StubRoutines::_kyberAddPoly_3", Long.class, "address");
548+
public final long stubKyber12To16 = getFieldValue("StubRoutines::_kyber12To16", Long.class, "address");
549+
public final long stubKyberBarrettReduce = getFieldValue("StubRoutines::_kyberBarrettReduce", Long.class, "address");
550+
543551
// Allocation stubs that return null when allocation fails
544552
public final long newInstanceOrNullAddress = getAddress("JVMCIRuntime::new_instance_or_null");
545553
public final long newArrayOrNullAddress = getAddress("JVMCIRuntime::new_array_or_null");

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/hotspot/HotSpotBackend.java

+15
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,21 @@ public static void unsafeArraycopy(Word srcAddr, Word dstAddr, Word size) {
271271
public static final HotSpotForeignCallDescriptor DILITHIUM_DECOMPOSE_POLY = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_dilithiumDecomposePoly", int.class,
272272
WordBase.class, WordBase.class, WordBase.class, int.class, int.class);
273273

274+
public static final HotSpotForeignCallDescriptor KYBER_NTT = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberNtt", int.class,
275+
WordBase.class, WordBase.class);
276+
public static final HotSpotForeignCallDescriptor KYBER_INVERSE_NTT = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberInverseNtt", int.class,
277+
WordBase.class, WordBase.class);
278+
public static final HotSpotForeignCallDescriptor KYBER_NTT_MULT = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberNttMult", int.class,
279+
WordBase.class, WordBase.class, WordBase.class, WordBase.class);
280+
public static final HotSpotForeignCallDescriptor KYBER_ADD_POLY_2 = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberAddPoly_2", int.class,
281+
WordBase.class, WordBase.class, WordBase.class);
282+
public static final HotSpotForeignCallDescriptor KYBER_ADD_POLY_3 = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberAddPoly_3", int.class,
283+
WordBase.class, WordBase.class, WordBase.class, WordBase.class);
284+
public static final HotSpotForeignCallDescriptor KYBER_12_TO_16 = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyber12To16", int.class,
285+
WordBase.class, int.class, WordBase.class, int.class);
286+
public static final HotSpotForeignCallDescriptor KYBER_BARRETT_REDUCE = new HotSpotForeignCallDescriptor(LEAF, HAS_SIDE_EFFECT, any(), "_kyberBarrettReduce", int.class,
287+
WordBase.class);
288+
274289
public static final HotSpotForeignCallDescriptor SHAREDRUNTIME_NOTIFY_JVMTI_VTHREAD_START = new HotSpotForeignCallDescriptor(SAFEPOINT, HAS_SIDE_EFFECT, any(),
275290
"notify_jvmti_vthread_start", void.class,
276291
Object.class, boolean.class, Word.class);

0 commit comments

Comments
 (0)