Commit c998bf2 (1 parent: ec68d7d)

Merge thinc-apple-ops into Thinc (#927)
This change adds `AppleOps` to Thinc, to ensure that the AMX unit is always used on Apple Silicon Macs. Before this change, a user would get much worse performance if they forgot to install `thinc-apple-ops`. The `apple_ops` and `_accelerate` modules are built conditionally. When detecting the best CPU implementation, we rely on a `try...except` import to determine whether Apple ops are available. Even though x86_64 Macs do not have an AMX unit, Accelerate is competitive with BLIS, so it does not hurt to enable Apple ops on all Macs.
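As a quick illustration (not part of the diff), the effect of this change is that the best-CPU lookup picks the Apple backend automatically. A minimal sketch, assuming Thinc's public `get_ops` API:

    # Hedged sketch: on a macOS build that includes the new extensions,
    # requesting the best CPU implementation should yield AppleOps with no
    # extra package installed; elsewhere it falls back to NumpyOps.
    from thinc.api import get_ops

    ops = get_ops("cpu")
    print(type(ops).__name__)  # "AppleOps" on macOS, "NumpyOps" otherwise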

File tree: 13 files changed (+279, −22 lines)

.github/workflows/tests.yml

Lines changed: 0 additions & 11 deletions

@@ -152,14 +152,3 @@ jobs:
 
       - name: Run tests with extras
        run: python -m pytest --pyargs thinc --cov=thinc --cov-report=term -p thinc.tests.enable_tensorflow -p thinc.tests.enable_mxnet
-
-      - name: Run tests for thinc-apple-ops
-        run: |
-          pip uninstall -y tensorflow
-          pip install thinc-apple-ops
-          python -m pytest --pyargs thinc_apple_ops
-        if: matrix.os == 'macos-latest' && matrix.python_version == '3.10'
-
-      - name: Run tests with thinc-apple-ops
-        run: python -m pytest --pyargs thinc
-        if: matrix.os == 'macos-latest' && matrix.python_version == '3.10'

setup.py

Lines changed: 14 additions & 2 deletions

@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+import platform
 import sys
 from setuptools.command.build_ext import build_ext
 from sysconfig import get_path
@@ -13,14 +14,16 @@
 # http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options
 Options.docstrings = True
 
+ACCELERATE = "thinc.backends._accelerate"
+APPLE_OPS = ["thinc.backends.apple_ops", ACCELERATE]
 
 PACKAGES = find_packages()
 MOD_NAMES = [
     "thinc.backends.cblas",
     "thinc.backends.numpy_ops",
     "thinc.layers.sparselinear",
     "thinc.layers.premap_ids",
-]
+] + (APPLE_OPS if platform.system() == "Darwin" else [])
 COMPILE_OPTIONS = {
     "msvc": ["/Ox", "/EHsc"],
     "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"],
@@ -78,7 +81,16 @@ def setup_package():
     ext_modules = []
     for name in MOD_NAMES:
         mod_path = name.replace(".", "/") + ".pyx"
-        ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs)
+        if name == ACCELERATE:
+            ext = Extension(
+                name,
+                [mod_path],
+                language="c++",
+                include_dirs=include_dirs,
+                libraries=["blas"],
+            )
+        else:
+            ext = Extension(name, [mod_path], language="c++", include_dirs=include_dirs)
         ext_modules.append(ext)
     print("Cythonizing sources")
     ext_modules = cythonize(
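A note on the build change above: on macOS, `libraries=["blas"]` links against the system BLAS stubs provided by the Accelerate framework, so the `_accelerate` extension needs no third-party dependency; that is why only this one extension receives the extra `libraries` argument.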

thinc/api.py

Lines changed: 6 additions & 1 deletion

@@ -162,6 +162,11 @@
     xp2torch,
 )
 
+try:
+    from .backends import AppleOps
+except ImportError:
+    AppleOps = None
+
 # fmt: off
 __all__ = [
     # .config
@@ -198,7 +203,7 @@
     "has_cupy",
     # .backends
     "get_ops", "set_current_ops", "get_current_ops", "use_ops",
-    "Ops", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator",
+    "Ops", "AppleOps", "CupyOps", "MPSOps", "NumpyOps", "set_gpu_allocator",
     "use_pytorch_for_gpu_memory", "use_tensorflow_for_gpu_memory",
     # .layers
     "Dropout", "Embed", "expand_window", "HashEmbed", "LayerNorm", "Linear",

thinc/backends/__init__.py

Lines changed: 6 additions & 4 deletions

@@ -19,6 +19,11 @@
 from .numpy_ops import NumpyOps
 from .ops import Ops
 
+try:
+    from .apple_ops import AppleOps
+except ImportError:
+    AppleOps = None
+
 context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None)
 context_pools: ContextVar[dict] = ContextVar("context_pools", default={})
 
@@ -83,10 +88,6 @@ def use_tensorflow_for_gpu_memory() -> None:  # pragma: no cover
 
 
 def _import_extra_cpu_backends():
-    try:
-        from thinc_apple_ops import AppleOps
-    except ImportError:
-        pass
     try:
         from thinc_bigendian_ops import BigEndianOps
     except ImportError:
@@ -171,6 +172,7 @@ def _get_thread_state() -> threading.local:
     "use_ops",
     "ParamServer",
     "Ops",
+    "AppleOps",
     "CupyOps",
     "MPSOps",
     "NumpyOps",

thinc/backends/_accelerate.pxd

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+cdef extern from "Accelerate/Accelerate.h":
+    enum CBLAS_ORDER: CblasRowMajor, CblasColMajor
+    enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans
+    enum CBLAS_UPLO: CblasUpper, CblasLower
+    enum CBLAS_DIAG: CblasNonUnit, CblasUnit
+    enum CBLAS_SIDE: CblasLeft, CblasRight
+
+    # BLAS level 1 routines
+
+    void cblas_sswap(int M, float *x, int incX, float *y, int incY) nogil
+    void cblas_sscal(int N, float alpha, float *x, int incX) nogil
+    void cblas_scopy(int N, float *x, int incX, float *y, int incY) nogil
+    void cblas_saxpy(int N, float alpha, float *x, int incX, float *y, int incY) nogil
+    float cblas_sdot(int N, float *x, int incX, float *y, int incY) nogil
+    float cblas_snrm2(int N, float *x, int incX) nogil
+    float cblas_sasum(int N, float *x, int incX) nogil
+    int cblas_isamax(int N, float *x, int incX) nogil
+
+    # BLAS level 2 routines
+    void cblas_sgemv(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, int M, int N,
+                     float alpha, float *A, int lda, float *x, int incX,
+                     float beta, float *y, int incY) nogil
+
+    void cblas_sger(CBLAS_ORDER Order, int M, int N, float alpha, float *x,
+                    int incX, float *y, int incY, float *A, int lda) nogil
+
+    # BLAS level 3 routines
+    void cblas_sgemm(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
+                     CBLAS_TRANSPOSE TransB, int M, int N, int K,
+                     float alpha, float *A, int lda, float *B, int ldb,
+                     float beta, float *C, int ldc) nogil
+
+
+cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
+                float alpha, const float* A, int lda, const float *B,
+                int ldb, float beta, float* C, int ldc) nogil
+
+
+cdef void saxpy(int N, float alpha, const float* X, int incX,
+                float *Y, int incY) nogil
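The two module-level `cdef` functions at the bottom are plain C entry points (`nogil`, no Python call overhead) whose signatures match the function-pointer slots in Thinc's `CBlas` table; `apple_ops.pyx` below installs them via `set_sgemm`/`set_saxpy`, so any kernel that goes through `CBlas` transparently calls into Accelerate.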

thinc/backends/_accelerate.pyx

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+cimport numpy as np
+from libc.stdint cimport uintptr_t
+
+import numpy
+
+
+cpdef np.ndarray gemm(float[:, ::1] A, float[:, ::1] B,
+                      bint trans1=False, bint trans2=False,
+                      np.ndarray out=None):
+    cdef int nM = A.shape[0] if not trans1 else A.shape[1]
+    cdef int nK = A.shape[1] if not trans1 else A.shape[0]
+    cdef int nK_b = B.shape[0] if not trans2 else B.shape[1]
+    cdef int nN = B.shape[1] if not trans2 else B.shape[0]
+
+    cdef float[:, ::1] C = out
+
+    if out is None:
+        out = numpy.empty((nM, nN), dtype="f")
+        C = out
+    else:
+        if C.shape[0] != nM or C.shape[1] != nN:
+            msg = "Shape mismatch for output matrix, was: (%d, %d), expected (%d, %d)"
+            raise ValueError(msg % (C.shape[0], C.shape[1], nM, nN))
+
+
+    if nK != nK_b:
+        msg = "Shape mismatch for gemm: (%d, %d), (%d, %d)"
+        raise ValueError(msg % (nM, nK, nK_b, nN))
+
+    if nM == 0 or nK == 0 or nN == 0:
+        return out
+
+    cblas_sgemm(
+        CblasRowMajor,
+        CblasTrans if trans1 else CblasNoTrans,
+        CblasTrans if trans2 else CblasNoTrans,
+        nM,
+        nN,
+        nK,
+        1.0,
+        &A[0, 0],
+        A.shape[1],
+        &B[0, 0],
+        B.shape[1],
+        0.0,
+        &C[0, 0],
+        C.shape[1]
+    )
+    return out
+
+
+cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
+                float alpha, const float* A, int lda, const float *B,
+                int ldb, float beta, float* C, int ldc) nogil:
+    cblas_sgemm(
+        CblasRowMajor,
+        CblasTrans if TransA else CblasNoTrans,
+        CblasTrans if TransB else CblasNoTrans,
+        M,
+        N,
+        K,
+        alpha,
+        A,
+        lda,
+        B,
+        ldb,
+        beta,
+        C,
+        ldc
+    )
+
+
+cdef void saxpy(int N, float alpha, const float* X, int incX,
+                float *Y, int incY) nogil:
+    cblas_saxpy(N, alpha, X, incX, Y, incY)
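Note that `gemm` passes fixed `alpha=1.0` and `beta=0.0` to `cblas_sgemm`, so it always computes C = op(A) · op(B) and overwrites `out` rather than accumulating into it; the lower-level `sgemm` wrapper exposes `alpha` and `beta` for callers that need the general form.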

thinc/backends/apple_ops.pyx

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+from typing import Optional
+
+import numpy
+
+from ._accelerate import gemm
+
+from ._accelerate cimport saxpy, sgemm
+from .cblas cimport CBlas, set_saxpy, set_sgemm
+
+from .. import registry
+from ..types import Floats2d
+from .numpy_ops import NumpyOps
+
+
+@registry.ops("AppleOps")
+class AppleOps(NumpyOps):
+    """Thinc Ops class that calls into Apple's native libraries for some
+    operations. Other operations fall back to numpy."""
+    name = "apple"
+    xp = numpy
+
+    def cblas(self) -> CBlas:
+        cdef CBlas cblas = CBlas()
+        set_saxpy(cblas, saxpy)
+        set_sgemm(cblas, sgemm)
+        return cblas
+
+    def gemm(
+        self,
+        x: Floats2d,
+        y: Floats2d,
+        out: Optional[Floats2d] = None,
+        trans1: bool = False,
+        trans2: bool = False,
+    ) -> Floats2d:
+        """Perform General Matrix Multiplication (GeMM) and optionally store
+        the result in the specified output variable.
+        """
+        return gemm(x, y, out=out, trans1=trans1, trans2=trans2)
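A short usage sketch (not part of the diff), assuming `get_ops` resolves ops classes by their `name` attribute as elsewhere in Thinc:

    # Hedged sketch: exercise the Accelerate-backed gemm through the ops
    # class registered above ("apple" is its `name` attribute).
    import numpy
    from thinc.api import get_ops

    ops = get_ops("apple")                    # AppleOps instance on macOS
    a = numpy.random.randn(2, 3).astype("f")  # float32, C-contiguous
    b = numpy.random.randn(3, 4).astype("f")
    c = ops.gemm(a, b)                        # dispatches to cblas_sgemm
    assert c.shape == (2, 4)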

thinc/backends/mps_ops.py

Lines changed: 4 additions & 3 deletions

@@ -3,6 +3,7 @@
 import numpy
 
 from .. import registry
+from ..compat import has_apple_ops
 from .numpy_ops import NumpyOps
 from .ops import Ops
 
@@ -12,11 +13,11 @@
     # during type checking.
     _Ops = Ops
 else:
-    try:
-        from thinc_apple_ops import AppleOps
+    if has_apple_ops:
+        from .apple_ops import AppleOps
 
         _Ops = AppleOps
-    except ImportError:
+    else:
         _Ops = NumpyOps
 
 
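`MPSOps`, defined later in this file, inherits from the `_Ops` alias, so on macOS its CPU code paths now come from `AppleOps` rather than plain `NumpyOps` (the test `test_mps_ops_inherits_apple_ops` below checks exactly this).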
thinc/compat.py

Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,4 @@
+import platform
 import warnings
 
 from packaging.version import Version
@@ -119,6 +120,9 @@ def enable_mxnet():
 has_blis = False
 
 
+# AppleOps is available unconditionally on macOS.
+has_apple_ops = platform.system() == "Darwin"
+
 has_gpu = has_cupy_gpu or has_torch_mps_gpu
 
 __all__ = [

thinc/tests/backends/_apple_blas/__init__.py

Whitespace-only changes.

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+import numpy
+import pytest
+
+from thinc.compat import has_apple_ops
+
+try:
+    import thinc.backends._accelerate as accelerate
+except:
+    pass
+
+
+@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
+def test_basic_sgemm():
+    A = numpy.random.randn(5, 4).astype("f")
+    B = numpy.random.randn(4, 7).astype("f")
+    C = accelerate.gemm(A, B)
+    assert C.shape == (A.shape[0], B.shape[1])
+
+    C_out = numpy.empty((5, 7), dtype="f")
+    accelerate.gemm(A, B, out=C_out)
+
+    numpy.testing.assert_allclose(C, C_out)
+
+
+@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
+def test_incorrect_output_size():
+    A = numpy.ndarray((5, 4), dtype="f")
+    B = numpy.ndarray((4, 7), dtype="f")
+
+    with pytest.raises(ValueError, match=r"Shape mismatch for output matrix"):
+        accelerate.gemm(A, B, out=numpy.ndarray((3, 7), dtype="f"))
+
+    with pytest.raises(ValueError, match=r"Shape mismatch for output matrix"):
+        accelerate.gemm(A, B, out=numpy.ndarray((5, 3), dtype="f"))
+
+
+@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
+@pytest.mark.parametrize(
+    "A_shape,B_shape,transA,transB",
+    [
+        [(0, 0), (0, 0), False, False],
+        [(0, 0), (0, 0), True, False],
+        [(0, 0), (0, 0), False, True],
+        [(0, 0), (0, 0), True, True],
+        [(0, 5), (5, 0), False, False],
+        [(5, 0), (5, 0), False, True],
+        [(5, 0), (5, 0), True, False],
+    ],
+)
+def test_zero_size(A_shape, B_shape, transA, transB):
+    A = numpy.ndarray(A_shape, dtype="f")
+    B = numpy.ndarray(B_shape, dtype="f")
+    if not transA and not transB:
+        C = numpy.dot(A, B)
+    elif transA:
+        C = numpy.dot(A.T, B)
+    elif transB:
+        C = numpy.dot(A, B.T)
+    else:
+        C = numpy.dot(A.T, B.T)
+    C_ = accelerate.gemm(A, B, trans1=transA, trans2=transB)
+    assert C.shape == C_.shape
+
+
+@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
+@pytest.mark.parametrize(
+    "A_shape,B_shape,transA,transB",
+    [
+        [(4, 5), (4, 5), False, False],
+        [(5, 4), (4, 5), True, False],
+        [(4, 5), (5, 4), False, True],
+        [(5, 4), (5, 4), True, True],
+    ],
+)
+def test_incorrect_shapes(A_shape, B_shape, transA, transB):
+    A = numpy.ndarray(A_shape, dtype="f")
+    B = numpy.ndarray(B_shape, dtype="f")
+    with pytest.raises(ValueError, match=r"Shape mismatch"):
+        accelerate.gemm(A, B, trans1=transA, trans2=transB)
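To run just these tests on a Mac (a hedged example; the path follows the file tree above):

    python -m pytest thinc/tests/backends/_apple_blas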

thinc/tests/backends/test_mps_ops.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+from thinc.api import NumpyOps, get_ops
+from thinc.compat import has_apple_ops
+
+
+def test_mps_ops_inherits_apple_ops():
+    ops = get_ops("mps")
+    assert isinstance(ops, NumpyOps)
+    if has_apple_ops:
+        # We can't import AppleOps directly, because it's not
+        # available on non-Darwin systems.
+        assert "AppleOps" in [base.__name__ for base in type(ops).__bases__]
