diff --git a/cadquery/occ_impl/shapes.py b/cadquery/occ_impl/shapes.py index 040712235..4abada27b 100644 --- a/cadquery/occ_impl/shapes.py +++ b/cadquery/occ_impl/shapes.py @@ -756,6 +756,35 @@ def CombinedCenter(objects: Iterable["Shape"]) -> Vector: return Vector(sum_wc.multiply(1.0 / total_mass)) + @staticmethod + def _mass_calc_function(obj: "Shape") -> Any: + """ + Helper to find the correct mass calculation function with special compound handling. + """ + + type_ = shapetype(obj.wrapped) + + # special handling of compounds - first non-compound child is assumed to define the type of the operation + if type_ == ta.TopAbs_COMPOUND: + + # if the compound is not empty check its children + if obj: + # first child + child = next(iter(obj)) + + # if compound, go deeper + while child.ShapeType() == "Compound": + child = next(iter(child)) + + type_ = shapetype(child.wrapped) + + # if the compound is empty assume it was meant to be a solid + else: + type_ = ta.TopAbs_SOLID + + # get the function based on dimensionality of the object + return shape_properties_LUT[type_] + @staticmethod def computeMass(obj: "Shape") -> float: """ @@ -764,13 +793,11 @@ def computeMass(obj: "Shape") -> float: :param obj: Compute the mass of this object """ Properties = GProp_GProps() - calc_function = shape_properties_LUT[shapetype(obj.wrapped)] + calc_function = Shape._mass_calc_function(obj) - if calc_function: - calc_function(obj.wrapped, Properties) - return Properties.Mass() - else: - raise NotImplementedError + calc_function(obj.wrapped, Properties) + + return Properties.Mass() @staticmethod def centerOfMass(obj: "Shape") -> Vector: @@ -780,13 +807,11 @@ def centerOfMass(obj: "Shape") -> Vector: :param obj: Compute the center of mass of this object """ Properties = GProp_GProps() - calc_function = shape_properties_LUT[shapetype(obj.wrapped)] + calc_function = Shape._mass_calc_function(obj) - if calc_function: - calc_function(obj.wrapped, Properties) - return Vector(Properties.CentreOfMass()) - else: - raise NotImplementedError + calc_function(obj.wrapped, Properties) + + return Vector(Properties.CentreOfMass()) @staticmethod def CombinedCenterOfBoundBox(objects: List["Shape"]) -> Vector: diff --git a/tests/test_cadquery.py b/tests/test_cadquery.py index bc4ab87e4..86b919190 100644 --- a/tests/test_cadquery.py +++ b/tests/test_cadquery.py @@ -5904,3 +5904,16 @@ def test_loft_to_vertex(self): # in both cases we get a solid assert w1.solids().size() == 1 assert w2.solids().size() == 1 + + def test_compound_faces_center(self): + sk = Sketch().rect(50, 50).faces() + face1 = sk.val() + face2 = face1.copy().translate(Vector(100, 0, 0)) + compound = Compound.makeCompound([face1, face2]) + expected_center = Shape.CombinedCenter([face1, face2]) + + assert ( + compound.Center() == expected_center + ), "Incorrect center of mass of the compound, expected {}, got {}".format( + expected_center, compound.Center() + )