Skip to content

Commit 22401b8

Browse files
suofacebook-github-bot
authored andcommitted
port all JIT tests to gtest (pytorch#45264)
Summary: Pull Request resolved: pytorch#45264 Context for why we are porting to gtest in: pytorch#45018. This PR completes the process of porting and removes unused files/macros. Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D23901392 Pulled By: suo fbshipit-source-id: 89526890e1a49462f3f77718f4ee273c5bc578ba
1 parent 5a0514e commit 22401b8

31 files changed

+1336
-1638
lines changed

aten/src/ATen/test/thread_init_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <ATen/ATen.h>
22
#include <ATen/Parallel.h>
3-
#include <test/cpp/jit/test_base.h>
3+
#include <test/cpp/tensorexpr/test_base.h>
44
#include <thread>
55

66

test/cpp/jit/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@ endif()
1919

2020
# Build the cpp gtest binary containing the cpp-only tests.
2121
set(JIT_TEST_SRCS
22-
${JIT_TEST_ROOT}/gtest.cpp
2322
${JIT_TEST_ROOT}/test_alias_analysis.cpp
2423
${JIT_TEST_ROOT}/test_argument_spec.cpp
2524
${JIT_TEST_ROOT}/test_autodiff.cpp
26-
${JIT_TEST_ROOT}/test_base.cpp
27-
${JIT_TEST_ROOT}/test_base.h
2825
${JIT_TEST_ROOT}/test_class_import.cpp
2926
${JIT_TEST_ROOT}/test_class_parser.cpp
3027
${JIT_TEST_ROOT}/test_class_type.cpp

test/cpp/jit/README.md

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,44 @@
11
# JIT C++ Tests
22

3-
## How to add a new test
3+
## Adding a new test
44
First, create a new test file. Test files should have be placed in this
55
directory, with a name that starts with `test_`, like `test_foo.cpp`.
66

7-
Here is an example test file you can copy-paste.
7+
In general a single test suite
8+
9+
Add your test file to the `JIT_TEST_SRCS` list in `test/cpp/jit/CMakeLists.txt`.
10+
11+
A test file may look like:
812
```cpp
9-
#include <test/cpp/jit/test_base.h>
13+
#include <gtest/gtest.h>
1014

11-
// Tests go in torch::jit
12-
namespace torch {
13-
namespace jit {
15+
using namespace ::torch::jit
1416

15-
// 1. Test cases are void() functions.
16-
// 2. They start with the prefix `test`
17-
void testCaseOne() {
18-
// ...
17+
TEST(FooTest, BarBaz) {
18+
// ...
1919
}
2020

21-
void testCaseTwo() {
22-
// ...
23-
}
24-
}
21+
// Append '_CUDA' to the test case name will automatically filter it out if CUDA
22+
// is not compiled.
23+
TEST(FooTest, NeedsAGpu_CUDA) {
24+
// ...
2525
}
26-
```
2726

28-
Then, register your test in `tests.h`:
29-
```cpp
30-
// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests
31-
#define TH_FORALL_TESTS(_) \
32-
_(ADFormulas) \
33-
_(Attributes) \
34-
...
35-
_(CaseOne) // note that the `test` prefix is omitted.
36-
_(CaseTwo)
37-
```
38-
39-
We glob all the test files together in `CMakeLists.txt` so that you don't
40-
have to edit it every time you add a test. Unfortunately, this means that in
41-
order to get the build to pick up your new test file, you need to re-run
42-
cmake:
43-
```
44-
python setup.py build --cmake
27+
// Similarly, if only one GPU is detected, tests with `_MultiCUDA` at the end
28+
// will not be run.
29+
TEST(FooTest, NeedsMultipleGpus_MultiCUDA) {
30+
// ...
31+
}
4532
```
4633
47-
## Why do we have two different test runners?
48-
We have two different ways of running our cpp tests:
49-
1. With `gtest`, from a standalone binary.
50-
2. With Python, from `TestJit.test_cpp` and `TestJit.test_cpp_cuda` (in
51-
`test/test_jit.py`)
52-
53-
We want both because we need to test things from a pure-C++ environment and
54-
with all our various Python patch-points enabled.
55-
56-
## How do I run the tests?
34+
## Building and running the tests
5735
The following commands assume you are in PyTorch root.
5836
59-
1. With `gtest`:
60-
```bash
61-
# (re)build the test binary
62-
ninja build/bin/test_jit
63-
# run
64-
build/bin/test_jit --gtest_filter='glob_style_filter*'
65-
```
66-
2. With Python:
67-
```
68-
python test/test_jit.py TestJit.test_cpp TestJit.test_cpp_cuda
69-
```
37+
```bash
38+
# ... Build PyTorch from source, e.g.
39+
python setup.py develop
40+
# (re)build just the binary
41+
ninja -C build bin/test_jit
42+
# run tests
43+
build/bin/test_jit --gtest_filter='glob_style_filter*'
44+
```

test/cpp/jit/gtest.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

test/cpp/jit/test_base.cpp

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/cpp/jit/test_base.h

Lines changed: 0 additions & 47 deletions
This file was deleted.

test/cpp/jit/test_class_parser.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <gtest/gtest.h>
22

3-
#include <test/cpp/jit/test_base.h>
43
#include <torch/csrc/jit/frontend/parser.h>
54
#include <torch/csrc/jit/frontend/resolver.h>
65

test/cpp/jit/test_class_type.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
#include <test/cpp/jit/test_base.h>
1+
#include <gtest/gtest.h>
2+
23
#include <test/cpp/jit/test_utils.h>
34
#include <torch/torch.h>
45

56
namespace torch {
67
namespace jit {
78

8-
void testClassTypeAddRemoveAttr() {
9+
TEST(ClassTypeTest, AddRemoveAttr) {
910
auto cu = std::make_shared<CompilationUnit>();
1011
auto cls = ClassType::create("foo.bar", cu, true);
1112
cls->addAttribute("attr1", TensorType::get(), true);
@@ -32,12 +33,12 @@ void testClassTypeAddRemoveAttr() {
3233
cls->addAttribute("attr1", IntType::get());
3334
}
3435

35-
void testClassTypeAddRemoveConstant() {
36+
TEST(ClassTypeTest, AddRemoveConstant) {
3637
auto cu = std::make_shared<CompilationUnit>();
3738
auto cls = ClassType::create("foo.bar", cu);
3839
cls->addConstant("const1", IValue(1));
3940
cls->addConstant("const2", IValue(2));
40-
cls->addConstant("const3", IValue(2));
41+
cls->addConstant("const3", IValue(3));
4142
ASSERT_EQ(cls->numConstants(), 3);
4243
ASSERT_TRUE(cls->hasConstant("const1"));
4344
ASSERT_TRUE(cls->hasConstant("const2"));
@@ -46,7 +47,7 @@ void testClassTypeAddRemoveConstant() {
4647

4748
ASSERT_EQ(cls->getConstant("const1").toInt(), 1);
4849
ASSERT_EQ(cls->getConstant("const2").toInt(), 2);
49-
ASSERT_EQ(cls->getConstant("const2").toInt(), 3);
50+
ASSERT_EQ(cls->getConstant("const3").toInt(), 3);
5051

5152
cls->unsafeRemoveConstant("const2");
5253
ASSERT_TRUE(cls->hasConstant("const1"));

0 commit comments

Comments
 (0)