Skip to content

Commit 98b3ea6

Browse files
committed
Stubs about GCD
Signed-off-by: Kakadu <Kakadu@pm.me>
1 parent bb44c96 commit 98b3ea6

File tree

4 files changed

+288
-0
lines changed

4 files changed

+288
-0
lines changed

GCD1.ml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
(* [naive_gcd u v] is the greatest common divisor of [u] and [v], computed
   with Euclid's algorithm.

   The recursive step has to continue with the remainder [u mod v]; recursing
   on the quotient [u / v] yields wrong answers (and a division by zero as
   soon as [u < v]).  [naive_gcd u 0] is [u], so a zero divisor never reaches
   [mod]. *)
let rec naive_gcd u v = if v = 0 then u else naive_gcd v (u mod v)
2+
3+
(* Bindings to the C++ GCD stubs. Each takes two OCaml ints and returns an
   int. [@@noalloc] tells the compiler that the stub neither allocates on the
   OCaml heap nor raises, which allows a cheaper call sequence. *)

(* Bound to the C symbol "caml_naive_gcd". *)
external c_naive_gcd : int -> int -> int = "caml_naive_gcd" [@@noalloc]
4+
(* Bound to the C symbol "caml_binary_gcd". *)
external c_binary_gcd : int -> int -> int = "caml_binary_gcd" [@@noalloc]
5+
6+
(* Bound to the C symbol "caml_hybrid_binary_gcd". *)
external c_hybrid_binary_gcd : int -> int -> int = "caml_hybrid_binary_gcd"
7+
[@@noalloc]
8+
9+
(* Builds one benchmark entry [(name, function, argument)] for a two-argument
   GCD function [f]: the first operand is applied right away, so the entry
   carries the partial application [f u] together with the second operand. *)
let test ~name f =
  let first_operand = 1100087778366101931 in
  let second_operand = 679891637638612258 in
  (name, f first_operand, second_operand)
10+
11+
(* Throughput comparison of the OCaml implementation against the three C++
   stubs; the samples returned by [Benchmark.throughputN] are printed as a
   table. *)
let () =
  let candidates =
    [
      test ~name:"OCaml naive_gcd" naive_gcd;
      test ~name:"C naive_gcd" c_naive_gcd;
      test ~name:"C binary_gcd" c_binary_gcd;
      test ~name:"C hybrid_binary_gcd" c_hybrid_binary_gcd;
    ]
  in
  let samples = Benchmark.throughputN ~style:Nil ~repeat:1 1 candidates in
  Benchmark.tabulate samples
20+
21+
(* Latency comparison of the same four implementations; the samples returned
   by [Benchmark.latencyN] (called with [~repeat:10] and the count [4L]) are
   printed as a table. *)
let () =
  let candidates =
    [
      test ~name:"OCaml naive_gcd" naive_gcd;
      test ~name:"C naive_gcd" c_naive_gcd;
      test ~name:"C binary_gcd" c_binary_gcd;
      test ~name:"C hybrid_binary_gcd" c_hybrid_binary_gcd;
    ]
  in
  Benchmark.latencyN ~style:Nil ~repeat:10 4L candidates |> Benchmark.tabulate
30+
(* It looks like core_bench should be better *)

GCDs.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <iostream>
2+
#include <numeric>
3+
#include "gcd.h"
4+
5+
extern "C" {
#include <caml/memory.h>
#include <caml/mlvalues.h>

// WRAP(name) defines the OCaml-callable stub caml_<name>(u, v), which
// untags both OCaml integers, forwards them as uint64_t to the C++ template
// <name> from gcd.h and re-tags the result.
//
// The operands are untagged with Long_val rather than Int_val: Int_val
// narrows the value to a C `int`, which would truncate the 63-bit OCaml
// integers (such as the benchmark operands) before the GCD is computed.
// Negative OCaml integers are reinterpreted as large unsigned values.
#define WRAP(name)                                  \
  value caml_##name(value u, value v) {             \
    CAMLparam2(u, v);                               \
    CAMLreturn(Val_long(name(                       \
        static_cast<uint64_t>(Long_val(u)),         \
        static_cast<uint64_t>(Long_val(v)))));      \
  }

WRAP(naive_gcd)
WRAP(binary_gcd)
WRAP(hybrid_binary_gcd)
// WRAP(extended_gcd)
// WRAP(extended_one_gcd)
// WRAP(binary_extended_gcd)

}

dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,14 @@
6161
(libraries benchmark)
6262
(ocamlopt_flags
6363
(:standard -S -g)))
64+
65+
; Benchmark executable GCD1: built from the single module GCD1, linked against
; the `benchmark` library, with GCDs.cpp compiled as a C++ foreign stub.
(executable
66+
(name GCD1)
67+
(modules GCD1)
68+
(libraries benchmark)
69+
; C++ stub GCDs: gcd.h relies on C++20 facilities (<bit>, concepts), hence
; -std=c++20.
(foreign_stubs
70+
(language cxx)
71+
(flags -std=c++20 -march=native -O3 -Wall)
72+
(names GCDs))
73+
; Standard ocamlopt flags plus -S (keep the generated assembly) and -g.
(ocamlopt_flags
74+
(:standard -S -g)))

gcd.h

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
// https://github.com/lemire/Code-used-on-Daniel-Lemire-s-blog/blob/master/2024/04/13/module/gcd.h
2+
#include <bit>
3+
#include <utility>
4+
5+
// computes the greatest common divisor between u and v
6+
template <std::unsigned_integral int_type>
7+
int_type naive_gcd(int_type u, int_type v) {
8+
return (u % v) == 0 ? v : naive_gcd(v, u % v);
9+
}
10+
11+
// computes the greatest common divisor between u and v
12+
template <std::unsigned_integral int_type>
13+
int_type binary_gcd(int_type u, int_type v) {
14+
if (u == 0) {
15+
return v;
16+
}
17+
if (v == 0) {
18+
return u;
19+
}
20+
auto shift = std::countr_zero(u | v);
21+
u >>= std::countr_zero(u);
22+
do {
23+
v >>= std::countr_zero(v);
24+
if (u > v) {
25+
std::swap(u, v);
26+
}
27+
v = v - u;
28+
} while (v != 0);
29+
return u << shift;
30+
}
31+
32+
// credit: Paolo Bonzini
33+
template <std::unsigned_integral int_type>
34+
int_type binary_gcd_noswap(int_type u, int_type v) {
35+
if (u == 0) {
36+
return v;
37+
}
38+
if (v == 0) {
39+
return u;
40+
}
41+
auto shift = std::countr_zero(u | v);
42+
u >>= std::countr_zero(u);
43+
do {
44+
int_type t = v >> std::countr_zero(v);
45+
if (u > t)
46+
v = u - t, u = t;
47+
else
48+
v = t - u;
49+
} while (v != 0);
50+
return u << shift;
51+
}
52+
53+
template <class T> T hybrid_binary_gcd(T u, T v) {
54+
if (u < v) {
55+
std::swap(u, v);
56+
}
57+
if (v == 0) {
58+
return u;
59+
}
60+
u %= v;
61+
if (u == 0) {
62+
return v;
63+
}
64+
auto zu = std::countr_zero(u);
65+
auto zv = std::countr_zero(v);
66+
auto shift = std::min(zu, zv);
67+
u >>= zu;
68+
v >>= zv;
69+
do {
70+
T u_minus_v = u - v;
71+
if (u > v)
72+
u = v, v = u_minus_v;
73+
else
74+
v = v - u;
75+
v >>= std::countr_zero(u_minus_v);
76+
} while (v != 0);
77+
return u << shift;
78+
}
79+
80+
// Result of the extended GCD routines: the gcd together with two
// coefficients. All three members use the unsigned int_type.
template <std::unsigned_integral int_type> struct bezout {
81+
int_type gcd; // greatest common divisor of the two inputs
82+
int_type x; // coefficient paired with the first input
83+
int_type y; // coefficient paired with the second input
84+
};
85+
86+
// Two consecutive terms of a sequence, used by extended_gcd and
// extended_one_gcd to carry the previous and the current value of each
// sequence through the loop.
template <std::unsigned_integral int_type> struct pair {
87+
int_type old_value; // previous term
88+
int_type new_value; // current term
89+
};
90+
91+
// computes the greatest common divisor between a and b,
92+
// as well as the Bézout coefficients x and y such as
93+
// a*x + b*y = gcd(a,b)
94+
template <std::unsigned_integral int_type>
95+
bezout<int_type> extended_gcd(int_type u, int_type v) {
96+
pair<int_type> r = {u, v};
97+
pair<int_type> s = {1, 0};
98+
pair<int_type> t = {0, 1};
99+
while (r.new_value != 0) {
100+
auto quotient = r.old_value / r.new_value;
101+
r = {r.new_value, r.old_value - quotient * r.new_value};
102+
s = {s.new_value, s.old_value - quotient * s.new_value};
103+
t = {t.new_value, t.old_value - quotient * t.new_value};
104+
}
105+
return {r.old_value, s.old_value, t.old_value};
106+
}
107+
108+
// This computes just one of the Bézout coefficients
109+
template <std::unsigned_integral int_type>
110+
bezout<int_type> extended_one_gcd(int_type u, int_type v) {
111+
pair<int_type> r = {u, v};
112+
pair<int_type> s = {1, 0};
113+
while (r.new_value != 0) {
114+
auto quotient = r.old_value / r.new_value;
115+
r = {r.new_value, r.old_value - quotient * r.new_value};
116+
s = {s.new_value, s.old_value - quotient * s.new_value};
117+
}
118+
return {r.old_value, s.old_value, 0};
119+
}
120+
121+
// Binary extended GCD: returns the gcd of a and b together with two
// coefficients, using only shifts, halvings and subtractions.
// The coefficients are tracked in the signed counterpart of int_type and
// converted back to int_type on return, so a negative coefficient comes back
// in wrapped form.
// From section 14.61 in https://cacr.uwaterloo.ca/hac/
122+
// Warning: signed integer overflow may occur if
123+
// std::max(a, b) >= std::numeric_limits<int_type>::max() / 16
124+
template <std::unsigned_integral int_type>
125+
bezout<int_type> binary_extended_gcd(int_type a, int_type b) {
126+
using sint_type = typename std::make_signed<int_type>::type;
127+
128+
// Trivial cases: one of the inputs is zero.
if (a == 0)
129+
return {b, 0, !!b}; // {0, 0, 0} if b == 0 else {b, 0, 1}
130+
if (b == 0)
131+
return {a, 1, 0};
132+
133+
// Work with a <= b; remember the swap so it can be undone at the end.
bool swapped = false;
134+
if (a > b) {
135+
swapped = true;
136+
std::swap(a, b);
137+
}
138+
139+
// Strip the power of two shared by a and b; it is shifted back into the gcd
// on return.
auto r = std::countr_zero(a | b);
140+
a >>= r;
141+
b >>= r;
142+
143+
// x and y are the working values; (s, t) are the coefficients tracked
// alongside x and (u, v) those tracked alongside y.
sint_type x = (sint_type)a;
144+
sint_type y = (sint_type)b;
145+
sint_type s = 1;
146+
sint_type t = 0;
147+
sint_type u = 0;
148+
sint_type v = 1;
149+
while (x) {
150+
while ((x & 1) == 0) { // x is even
151+
x /= 2;
152+
if (((s | t) & 1) == 0) {
153+
s /= 2;
154+
t /= 2;
155+
} else {
156+
// s and t are not both even: add b / subtract a before halving.
s = std::midpoint(s, (sint_type)b);
157+
t = std::midpoint(t, -(sint_type)a);
158+
}
159+
}
160+
while ((y & 1) == 0) { // y is even
161+
y /= 2;
162+
if (((u | v) & 1) == 0) {
163+
u /= 2;
164+
v /= 2;
165+
} else {
166+
// u and v are not both even: add b / subtract a before halving.
u = std::midpoint(u, (sint_type)b);
167+
v = std::midpoint(v, -(sint_type)a);
168+
}
169+
}
170+
// Subtract the smaller working value from the larger one, and the
// matching coefficients with it.
if (x >= y) {
171+
x -= y;
172+
s -= u;
173+
t -= v;
174+
} else {
175+
y -= x;
176+
u -= s;
177+
v -= t;
178+
}
179+
}
180+
181+
// Undo the initial swap so the coefficients line up with the caller's
// argument order.
if (swapped) {
182+
std::swap(a, b);
183+
std::swap(u, v);
184+
std::swap(s, t);
185+
}
186+
187+
// Enable below if you want to make sure that
188+
// |x| + |y| is the minimal (primary)
189+
// and x <= y (secondarily)
190+
191+
// if (y > 1) {
192+
// a /= (int_type)y;
193+
// b /= (int_type)y;
194+
// }
195+
// if (a && (int_type)std::abs(v) >= a) {
196+
// sint_type _ = v / (sint_type)a;
197+
// v -= _ * (sint_type)a;
198+
// u += _ * (sint_type)b;
199+
// }
200+
// if (b && (int_type)std::abs(u) >= b) {
201+
// sint_type _ = u / (sint_type)b;
202+
// u -= _ * (sint_type)b;
203+
// v += _ * (sint_type)a;
204+
// }
205+
// {
206+
// sint_type u_ = u + (sint_type)b;
207+
// sint_type v_ = v - (sint_type)a;
208+
// if (std::abs(u_) + std::abs(v_) <= std::abs(u) + std::abs(v)) {
209+
// u = u_;
210+
// v = v_;
211+
// }
212+
// }
213+
// {
214+
// sint_type u_ = u - (sint_type)b;
215+
// sint_type v_ = v + (sint_type)a;
216+
// if (std::abs(u_) + std::abs(v_) <= std::abs(u) + std::abs(v)) {
217+
// u = u_;
218+
// v = v_;
219+
// }
220+
// }
221+
222+
// The loop ended with x == 0, so y holds the gcd of the stripped inputs;
// shift the common power of two back in.
return {(int_type)y << r, (int_type)u, (int_type)v};
223+
}

0 commit comments

Comments
 (0)