I'm trying to define a plot class `MyPlot`, based on the abstract class `GenericPlot`, that creates three Matplotlib figures in parallel using pathos multiprocessing:
```python
from attr import define
from abc import abstractmethod
from pathos.multiprocessing import ProcessPool
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
import dill

dill.settings['recurse'] = True


def generate_array(fig):
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    figure_array = np.asarray(canvas.buffer_rgba())
    return figure_array


def parallel_process_function(args):
    function, data = args
    return function(*data)


@define
class GenericPlot():
    x: np.ndarray

    @abstractmethod
    def plot_series(self):
        ...

    @abstractmethod
    def generate_figures(self):
        ...

    def plot_figures(self):
        self.plot_series()
        return self.generate_figures()


@define(slots=False)
class MyPlot(GenericPlot):

    def plot_series(self):
        self.series1 = np.sin(self.x)
        self.series2 = np.cos(self.x)
        self.series3 = np.tan(self.x)

    def create_figure1(self, a):
        fig, axes = plt.subplots(1, 1, figsize=(12, 3))
        axes.plot(self.series1*a)
        axes.plot(self.series2)
        plt.close("all")
        return fig

    def create_figure2(self, a):
        fig, axes = plt.subplots(1, 1, figsize=(12, 3))
        axes.plot(self.series2*a)
        axes.plot(self.series3)
        plt.close("all")
        return fig

    def create_figure3(self, a):
        fig, axes = plt.subplots(1, 1, figsize=(12, 3))
        axes.plot(self.series3*a)
        axes.plot(self.series1)
        plt.close("all")
        return fig

    def create_figure_array(self, create_figure_method, a):
        fig = create_figure_method(a)
        figure_array = generate_array(fig)
        return fig, figure_array

    def parallel_create_figure_array(self, create_figure_method, a):
        return self.create_figure_array(create_figure_method, a)

    def generate_figures(self):
        pool = ProcessPool(nodes=3)
        functions = [self.parallel_create_figure_array,
                     self.parallel_create_figure_array,
                     self.parallel_create_figure_array]
        create_figure_methods = [self.create_figure1, self.create_figure2, self.create_figure3]
        a = [1, 2, 3]
        data = [(method, a) for method, a in zip(create_figure_methods, a)]
        results = pool.map(parallel_process_function, zip(functions, data))
        figures, image_arrays = zip(*results)
        return figures, image_arrays


my_plotter = MyPlot(x=np.linspace(0, 10, 100))
figures, image_arrays = my_plotter.plot_figures()
```

The code runs successfully without the abstract class, i.e. if I only define
```python
@define(slots=False)
class MyPlot(GenericPlot):
    x: np.ndarray
    ...
```

and then run
```python
my_plotter = MyPlot(x=np.linspace(0, 10, 100))
my_plotter.plot_series()
figures, image_arrays = my_plotter.generate_figures()
```

But as soon as I try to use the abstract class (yes, I need it, even though it is not necessary in this example), the `pool.map` call throws `AttributeError: 'MyPlot' object has no attribute 'series1'`.
Any idea why the objects sent to the multiprocessing pool don't carry the instance attributes (such as `series1`) that were set in `plot_series`?
I'm very new to multiprocessing, so I may be missing something very obvious.

Thanks in advance for your help!