Skip to content

Commit 73571fe

Browse files
committed
Basic matplotlib output for scatterplots
1 parent 371a193 commit 73571fe

File tree

4 files changed

+177
-1
lines changed

4 files changed

+177
-1
lines changed

Orange/widgets/io.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010

1111
from Orange.data.io import FileFormat
12+
from Orange.widgets.utils.matplotlib_export import scene_code
1213

1314
# Importing WebviewWidget can fail if neither QWebKit (old, deprecated) nor
1415
# QWebEngine (bleeding-edge, hard to install) are available
@@ -181,6 +182,35 @@ def write_image(cls, filename, scene):
181182
super().write_image(filename, scene)
182183

183184

185+
class MatplotlibFormat(FileFormat):
186+
EXTENSIONS = ('.py',)
187+
DESCRIPTION = 'Python Code (with Matplotlib)'
188+
PRIORITY = 300
189+
190+
@classmethod
191+
def write_image(cls, filename, scene):
192+
code = scene_code(scene) + "\n\nplt.show()"
193+
with open(filename, "wt") as f:
194+
f.write(code)
195+
196+
@classmethod
197+
def write(cls, filename, scene):
198+
if type(scene) == dict:
199+
scene = scene['scene']
200+
cls.write_image(filename, scene)
201+
202+
203+
class MatplotlibPDFFormat(MatplotlibFormat):
204+
EXTENSIONS = ('.matplotlib.pdf',) # file formats with same extension are not supported
205+
DESCRIPTION = 'Portable Document Format (from Matplotlib)'
206+
PRIORITY = 200
207+
208+
@classmethod
209+
def write_image(cls, filename, scene):
210+
code = scene_code(scene) + "\n\nplt.savefig({})".format(repr(filename))
211+
exec(code, {}) # will generate a pdf
212+
213+
184214
if hasattr(QtGui, "QPdfWriter"):
185215
class PdfFormat(ImgFormat):
186216
EXTENSIONS = ('.pdf', )

Orange/widgets/tests/test_io.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import os
22
import tempfile
33
import unittest
4+
from unittest.mock import patch
45

56
from AnyQt.QtWidgets import QGraphicsScene, QGraphicsRectItem
6-
from Orange.widgets.tests.base import GuiTest
7+
8+
import Orange
9+
from Orange.tests import named_file
10+
from Orange.widgets.tests.base import GuiTest, WidgetTest
711

812
from Orange.widgets import io as imgio
13+
from Orange.widgets.io import MatplotlibFormat, MatplotlibPDFFormat
14+
from Orange.widgets.visualize.owscatterplot import OWScatterPlot
15+
916

1017
@unittest.skipUnless(hasattr(imgio, "PdfFormat"), "QPdfWriter not available")
1118
class TestIO(GuiTest):
@@ -18,3 +25,33 @@ def test_pdf(self):
1825
imgio.PdfFormat.write_image(fname, sc)
1926
finally:
2027
os.unlink(fname)
28+
29+
30+
class TestMatplotlib(WidgetTest):
31+
32+
def test_python(self):
33+
iris = Orange.data.Table("iris")
34+
self.widget = self.create_widget(OWScatterPlot)
35+
self.send_signal(OWScatterPlot.Inputs.data, iris[::10])
36+
with named_file("", suffix=".py") as fname:
37+
with patch("Orange.widgets.utils.filedialogs.open_filename_dialog_save",
38+
lambda *x: (fname, MatplotlibFormat, None)):
39+
self.widget.save_graph()
40+
with open(fname, "rt") as f:
41+
code = f.read()
42+
self.assertIn("plt.show()", code)
43+
self.assertIn("plt.scatter", code)
44+
# test if the runs
45+
exec(code.replace("plt.show()", ""), {})
46+
47+
def test_pdf(self):
48+
iris = Orange.data.Table("iris")
49+
self.widget = self.create_widget(OWScatterPlot)
50+
self.send_signal(OWScatterPlot.Inputs.data, iris[::10])
51+
with named_file("", suffix=".pdf") as fname:
52+
with patch("Orange.widgets.utils.filedialogs.open_filename_dialog_save",
53+
lambda *x: (fname, MatplotlibPDFFormat, None)):
54+
self.widget.save_graph()
55+
with open(fname, "rb") as f:
56+
code = f.read()
57+
self.assertTrue(code.startswith(b"%PDF"))
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from itertools import chain
2+
3+
import numpy as np
4+
5+
from pyqtgraph.graphicsItems.ScatterPlotItem import ScatterPlotItem
6+
7+
8+
def numpy_repr(a):
9+
""" A numpy repr without summarization """
10+
opts = np.get_printoptions()
11+
try:
12+
np.set_printoptions(threshold=10**10)
13+
return repr(a)
14+
finally:
15+
np.set_printoptions(**opts)
16+
17+
18+
def scatterplot_code(scatterplot_item):
19+
x = scatterplot_item.data['x']
20+
y = scatterplot_item.data['y']
21+
sizes = scatterplot_item.data["size"]
22+
23+
code = []
24+
25+
code.append("# data")
26+
code.append("x = {}".format(numpy_repr(x)))
27+
code.append("y = {}".format(numpy_repr(y)))
28+
29+
code.append("# style")
30+
code.append("sizes = {}".format(numpy_repr(sizes)))
31+
32+
def colortuple(color):
33+
return color.redF(), color.greenF(), color.blueF(), color.alphaF()
34+
35+
edgecolors = np.array([colortuple(a.color()) for a in scatterplot_item.data["pen"]])
36+
facecolors = np.array([colortuple(a.color()) for a in scatterplot_item.data["brush"]])
37+
38+
code.append("edgecolors = {}".format(numpy_repr(edgecolors)))
39+
code.append("facecolors = {}".format(numpy_repr(facecolors)))
40+
41+
# possible_markers for scatterplot are in .graph.CurveSymbols
42+
def matplotlib_marker(m):
43+
if m == "t":
44+
return "^"
45+
elif m == "t2":
46+
return ">"
47+
elif m == "t3":
48+
return "<"
49+
elif m == "star":
50+
return "*"
51+
elif m == "+":
52+
return "P"
53+
elif m == "x":
54+
return "X"
55+
return m
56+
57+
# TODO labels are missing
58+
59+
# each marker requires one call to matplotlib's scatter!
60+
markers = np.array([matplotlib_marker(m) for m in scatterplot_item.data["symbol"]])
61+
for m in set(markers):
62+
indices = np.where(markers == m)[0]
63+
if np.all(indices == np.arange(x.shape[0])):
64+
# indices are unused
65+
code.append("plt.scatter(x=x, y=y, s=sizes**2/4, marker={},".format(repr(m)))
66+
code.append(" facecolors=facecolors, edgecolors=edgecolors)")
67+
else:
68+
code.append("indices = {}".format(numpy_repr(indices)))
69+
code.append("plt.scatter(x=x[indices], y=y[indices], s=sizes[indices]**2/4, "
70+
"marker={},".format(repr(m)))
71+
code.append(" facecolors=facecolors[indices], "
72+
"edgecolors=edgecolors[indices])")
73+
74+
return "\n".join(code)
75+
76+
77+
def scene_code(scene):
78+
79+
code = []
80+
81+
code.append("import matplotlib.pyplot as plt")
82+
code.append("from numpy import array")
83+
84+
code.append("")
85+
code.append("plt.clf()")
86+
87+
code.append("")
88+
89+
for item in scene.items:
90+
if isinstance(item, ScatterPlotItem):
91+
code.append(scatterplot_code(item))
92+
93+
# TODO currently does not work for graphs without axes and for multiple axes!
94+
for position, set_ticks, set_label in [("bottom", "plt.xticks", "plt.xlabel"),
95+
("left", "plt.yticks", "plt.ylabel")]:
96+
axis = scene.getAxis(position)
97+
code.append("{}({})".format(set_label, repr(str(axis.labelText))))
98+
99+
# textual tick labels
100+
if axis._tickLevels is not None:
101+
major_minor = list(chain(*axis._tickLevels))
102+
locs = [a[0] for a in major_minor]
103+
labels = [a[1] for a in major_minor]
104+
code.append("{}({}, {})".format(set_ticks, locs, repr(labels)))
105+
106+
return "\n".join(code)
107+
108+

requirements-gui.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
AnyQt>=0.0.8
33

44
pyqtgraph>=0.10.0
5+
matplotlib>=2.0.0
56

67
# For add-ons' descriptions
78
docutils

0 commit comments

Comments
 (0)