Skip to content

Commit a66625d

Browse files
swolchokpytorchmergebot
authored andcommitted
[PyTorch] Optimize DictType::annotation_str_impl (pytorch#96498)
stringstream construction is expensive, and we can exactly reserve space for the output string while doing the same number of string copies. (If we wanted to improve performance further, we could introduce annotation_str_out to append the output to a given std::string and thus avoid copying subtype annotation strings. It occurs to me that the existing approach is quadratic in the number of layers of nesting, so we should probably do this!) Differential Revision: [D43919651](https://our.internmc.facebook.com/intern/diff/D43919651/) Pull Request resolved: pytorch#96498 Approved by: https://github.com/Skylion007
1 parent 000cfeb commit a66625d

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

aten/src/ATen/core/jit_type.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -999,12 +999,7 @@ struct TORCH_API DictType : public SharedType {
999999
types.push_back(std::move(value));
10001000
}
10011001

1002-
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1003-
std::stringstream ss;
1004-
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", ";
1005-
ss << getValueType()->annotation_str(std::move(printer)) << "]";
1006-
return ss.str();
1007-
}
1002+
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
10081003

10091004
std::vector<TypePtr> types;
10101005
bool has_free_variables;

aten/src/ATen/core/type.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,21 @@ TypePtr DictType::get(std::string identifier, TypePtr key, TypePtr value) {
304304
return containerTypePtrs[map_key];
305305
}
306306

307+
std::string DictType::annotation_str_impl(TypePrinter printer) const {
308+
auto keyAnnotation = getKeyType()->annotation_str(printer);
309+
auto valueAnnotation = getValueType()->annotation_str(std::move(printer));
310+
311+
std::string result;
312+
result.reserve(5 /* "Dict[" */ + keyAnnotation.size() + 2 /* ", " */ + valueAnnotation.size() + 1 /* "]" */);
313+
result = "Dict[";
314+
result += keyAnnotation;
315+
result.push_back(',');
316+
result.push_back(' ');
317+
result += valueAnnotation;
318+
result.push_back(']');
319+
return result;
320+
}
321+
307322
AnyListTypePtr AnyListType::get() {
308323
static AnyListTypePtr value(new AnyListType());
309324
return value;

test/test_jit.py

+4
Original file line numberDiff line numberDiff line change
@@ -11528,6 +11528,10 @@ def test_tuple_str(self):
1152811528
tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()])
1152911529
self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]")
1153011530

11531+
def test_dict_str(self):
11532+
dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get())
11533+
self.assertEqual(dict_type.annotation_str, "Dict[str, str]")
11534+
1153111535
def test_none_type_str(self):
1153211536
none_type = torch._C.NoneType.get()
1153311537
g = {'NoneType' : type(None)}

0 commit comments

Comments
 (0)