Skip to content

Commit 9349296

Browse files
committed
Update tests, and array validation logic.
1 parent ad0da18 commit 9349296

File tree

8 files changed

+67
-15
lines changed

8 files changed

+67
-15
lines changed

arrayfire.cabal

+2
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ test-suite test
138138
directory,
139139
hspec,
140140
hspec-discover,
141+
QuickCheck,
142+
quickcheck-classes,
141143
vector
142144
default-language:
143145
Haskell2010

default.nix

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
{ pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
22
# Latest arrayfire is not yet procured w/ nix.
33
let
4-
af = pkgs.callPackage ./nix {};
5-
pkg = pkgs.haskellPackages.callCabal2nix "arrayfire" ./. { inherit af; };
4+
pkg = pkgs.haskellPackages.callCabal2nix "arrayfire" ./. {
5+
af = null;
6+
quickcheck-classes = pkgs.haskellPackages.quickcheck-classes_0_6_4_0;
7+
};
68
in
7-
with pkgs.haskell.lib;
8-
enableCabalFlag pkg "disable-default-paths"
9+
pkg

src/ArrayFire/Array.hs

+11-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
module ArrayFire.Array where
3838

3939
import Control.Exception
40+
import Control.Monad
4041
import Data.Proxy
4142
import Data.Vector.Storable hiding (mapM_, take, concat, concatMap)
4243
import qualified Data.Vector.Storable as V
@@ -106,7 +107,7 @@ cube (x,y,z)
106107
. concat
107108
. fmap concat
108109
. take x
109-
. (fmap . take) y
110+
. fmap (take y)
110111
. (fmap . fmap . take) z
111112

112113
-- | Smart constructor for creating a tensor 'Array'
@@ -158,7 +159,14 @@ mkArray
158159
-- ^ Returned array
159160
{-# NOINLINE mkArray #-}
160161
mkArray dims xs =
161-
unsafePerformIO . mask_ $ do
162+
unsafePerformIO $ do
163+
when (Prelude.length (take size xs) < size) $ do
164+
let msg = "Invalid elements provided. "
165+
<> "Expected "
166+
<> show size
167+
<> " elements received "
168+
<> show (Prelude.length xs)
169+
throwIO (AFException SizeError 203 msg)
162170
dataPtr <- castPtr <$> newArray (Prelude.take size xs)
163171
let ndims = fromIntegral (Prelude.length dims)
164172
alloca $ \arrayPtr -> do
@@ -169,7 +177,7 @@ mkArray dims xs =
169177
arr <- peek arrayPtr
170178
Array <$> newForeignPtr af_release_array_finalizer arr
171179
where
172-
size = Prelude.product (fromIntegral <$> dims)
180+
size = Prelude.product dims
173181
dType = afType (Proxy @ array)
174182

175183
-- af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type);

src/ArrayFire/Data.hs

+5
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,11 @@ moddims (Array fptr) dims =
428428
-- ArrayFire Array
429429
-- [4 1 1 1]
430430
-- 1.0000 2.0000 1.0000 2.0000
431+
--
432+
-- >>> flat $ cube @Int (2,2,2) [[[1,1],[1,1]],[[1,1],[1,1]]]
433+
-- ArrayFire Array
434+
-- [8 1 1 1]
435+
-- 1 1 1 1 1 1 1 1
431436
flat
432437
:: Array a
433438
-> Array a

test/ArrayFire/ArraySpec.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ spec =
4545
setManualEvalFlag False
4646
(`shouldBe` False) =<< getManualEvalFlag
4747
it "Should return the number of elements" $ do
48-
let arr = mkArray @Int [9,9,1,1] []
48+
let arr = mkArray @Int [9,9,1,1] [1..]
4949
getElements arr `shouldBe` 81
5050
-- it "Should give an empty array" $ do
5151
-- let arr = mkArray @Int [-1,1,1,1] []
@@ -57,9 +57,9 @@ spec =
5757
it "Should get number of dims specified" $ do
5858
let arr = mkArray @Int [1,1,1,1] [1]
5959
getNumDims arr `shouldBe` 1
60-
let arr = mkArray @Int [2,3,4,5] [1]
60+
let arr = mkArray @Int [2,3,4,5] [1..]
6161
getNumDims arr `shouldBe` 4
62-
let arr = mkArray @Int [2,3,4] [1]
62+
let arr = mkArray @Int [2,3,4] [1..]
6363
getNumDims arr `shouldBe` 3
6464
it "Should get value of dims specified" $ do
6565
let arr = mkArray @Int [2,3,4,5] (repeat 1)

test/ArrayFire/LAPACKSpec.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ spec =
1111
it "Should have LAPACK available" $ do
1212
A.isLAPACKAvailable `shouldBe` True
1313
it "Should perform svd" $ do
14-
let (s,v,d) = A.svd $ A.matrix @Double (4,2) [[1..],[1..]]
14+
let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2], [3,4], [5,6], [7,8] ]
1515
A.getDims s `shouldBe` (4,4,1,1)
1616
A.getDims v `shouldBe` (2,1,1,1)
1717
A.getDims d `shouldBe` (2,2,1,1)

test/ArrayFire/SignalSpec.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ spec :: Spec
1313
spec =
1414
describe "Signal spec" $ do
1515
it "Should do FFT in place" $ do
16-
A.fftInPlace (A.matrix @(Complex Double) (10,10) [[1 :+ 1]]) 10.2
16+
A.fftInPlace (A.matrix @(Complex Double) (1,1) [[1 :+ 1]]) 10.2
1717
`shouldReturn` ()
1818
it "Should do FFT" $ do
1919
A.fft (A.matrix @(Complex Float) (1,1) [[1 :+ 1]]) 1 1

test/Main.hs

+39-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,43 @@
1+
{-# LANGUAGE TypeApplications #-}
2+
{-# LANGUAGE ScopedTypeVariables #-}
13
module Main where
24

3-
import Spec (spec)
4-
import Test.Hspec (hspec)
5+
import Control.Monad
6+
7+
import Data.Proxy
8+
import Spec (spec)
9+
import Test.Hspec (hspec)
10+
import Test.QuickCheck
11+
import Test.QuickCheck.Classes
12+
13+
import qualified ArrayFire as A
14+
import ArrayFire (Array)
15+
16+
import System.IO.Unsafe
17+
18+
instance (A.AFType a, Arbitrary a) => Arbitrary (Array a) where
19+
arbitrary = pure $ unsafePerformIO (A.randu [2,2])
520

621
main :: IO ()
7-
main = hspec spec
22+
main = do
23+
-- checks (Proxy :: Proxy (A.Array (A.Complex Float)))
24+
-- checks (Proxy :: Proxy (A.Array (A.Complex Double)))
25+
-- checks (Proxy :: Proxy (A.Array Double))
26+
-- checks (Proxy :: Proxy (A.Array Float))
27+
-- checks (Proxy :: Proxy (A.Array Double))
28+
-- checks (Proxy :: Proxy (A.Array A.Int16))
29+
-- checks (Proxy :: Proxy (A.Array A.Int32))
30+
-- checks (Proxy :: Proxy (A.Array A.CBool))
31+
-- checks (Proxy :: Proxy (A.Array Word))
32+
-- checks (Proxy :: Proxy (A.Array A.Word8))
33+
-- checks (Proxy :: Proxy (A.Array A.Word16))
34+
-- checks (Proxy :: Proxy (A.Array A.Word32))
35+
-- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Double))
36+
-- lawsCheck $ semigroupLaws (Proxy :: Proxy (A.Array Float))
37+
hspec spec
38+
39+
checks proxy = do
40+
lawsCheck (numLaws proxy)
41+
lawsCheck (eqLaws proxy)
42+
lawsCheck (ordLaws proxy)
43+
-- lawsCheck (semigroupLaws proxy)

0 commit comments

Comments
 (0)