import numpy
from optimeed.core.tools import printIfShown, SHOW_WARNING
class Data:
    """Store the information necessary to plot a 2D trace.

    This class only holds coordinates and display settings; it has to be combined
    with a GUI (e.g. pyqtgraph) to be useful.
    """

    def __init__(self, x: list, y: list, x_label='', y_label='',
                 legend='', is_scattered=False,
                 transfo_x=lambda selfData, x: x, transfo_y=lambda selfData, y: y,
                 xlim=None, ylim=None,
                 permutations=None, sort_output=False,
                 color=None,
                 symbol='o', symbolsize=8, fillsymbol=True, outlinesymbol=1.8,
                 linestyle='-', width=2):
        """
        :param x: x (list) coordinates. If None => empty list. Whenever len(x) != len(y), the indices of y are used instead
        :param y: y (list) coordinates. If None => empty list
        :param x_label: label of x axis
        :param y_label: label of y axis
        :param legend: legend associated with the data
        :param is_scattered: boolean. If True: plot will be scattered
        :param transfo_x: function (data, x) -> x applied to every x coordinate. Useful to change units for instance
        :param transfo_y: function (data, y) -> y applied to every y coordinate. Useful to change units for instance
        :param xlim: [x_min, x_max] view box
        :param ylim: [y_min, y_max] view box
        :param permutations: optional display order of the points (see :meth:`set_permutations`); takes precedence over sort_output
        :param sort_output: boolean. If True: points are displayed sorted by ascending (transformed) x
        :param color: Color to use for the trace (either rgb tuple or matlab-like abbreviation, e.g. 'r')
        :param symbol: symbol to display the points of the trace (e.g.: 'o', 't1', ...)
        :param symbolsize: size of the symbols
        :param fillsymbol: paint a solid symbol or not (bool)
        :param outlinesymbol: make the outline of the symbol darker (>1) or clearer (<1). No outline = 1
        :param linestyle: style of the line linking points of the trace
        :param width: width of the line
        """
        if x is None:
            x = []
        if y is None:
            y = []
        self.x = x
        self.x_label = x_label
        self.y = y
        self.y_label = y_label
        self.legend = legend
        self.isScattered = is_scattered
        self.sort_output = sort_output
        self.symbol = symbol
        self.fillsymbol = fillsymbol
        self.outlineSymbol = outlinesymbol
        self.color = color
        self.transfo_x = transfo_x
        self.transfo_y = transfo_y
        self.xlim = xlim
        self.ylim = ylim
        self.points_to_plot = None  # None => plot all points, [] => plot none
        self.linestyle = linestyle
        # Must exist before set_permutations, which only assigns it on success.
        self.permutations = None
        self.set_permutations(permutations)
        self.width = width
        self.symbolsize = symbolsize

    def set_data(self, x: list, y: list):
        """Overwrite current datapoints with a new set."""
        self.x = x
        self.y = y

    def get_x(self):
        """Get x coordinates of datapoints.

        When x and y do not have the same length (e.g. x left empty), the indices
        of y (``range(len(y))``) are returned instead.
        """
        the_list = self.x
        if len(self.y) != len(self.x):
            the_list = range(len(self.y))
        return the_list

    def get_symbolsize(self):
        """Get size of the symbols."""
        return self.symbolsize

    def symbol_isfilled(self):
        """Check if symbols have to be filled or not."""
        return self.fillsymbol

    def get_symbolOutline(self):
        """Get color factor of outline of symbols."""
        return self.outlineSymbol

    def get_length_data(self):
        """Get number of stored points: the larger of len(x) and len(y)."""
        return max(len(self.x), len(self.y))

    def get_xlim(self):
        """Get x limits of viewbox."""
        return self.xlim

    def get_ylim(self):
        """Get y limits of viewbox."""
        return self.ylim

    def get_y(self):
        """Get y coordinates of datapoints."""
        return self.y

    def get_color(self):
        """Get color of the line."""
        return self.color

    def get_width(self):
        """Get width of the line."""
        return self.width

    def get_number_of_points(self):
        """Get number of displayed points (after filtering by the indices to plot)."""
        return len(self.get_permutations())

    def get_plot_data(self):
        """
        Call this method to get the x and y coordinates of the points that have to be displayed.
        => After transformation, and after permutations.
        :return: x (list), y (list)
        """
        x = [self.transfo_x(self, i) for i in self.get_x()]
        y = [self.transfo_y(self, i) for i in self.get_y()]
        # Sorting (if sort_output) is done on the transformed x coordinates.
        permutations = self.get_permutations(x=x)
        x, y = [x[perm_i] for perm_i in permutations], [y[perm_i] for perm_i in permutations]
        min_length = min(len(x), len(y))
        return x[:min_length], y[:min_length]

    def get_permutations(self, x=None):
        """Return the transformation 'permutation':
            xplot[i] = xdata[permutation[i]]

        Priority: user-defined permutations, else argsort of x if sort_output, else identity.
        When only a subset of points is plotted (see :meth:`set_indices_points_to_plot`),
        the result is restricted to those indices.

        :param x: coordinates used for sorting; defaults to :meth:`get_x`
        :return: sequence of data indices (list or range)
        """
        indices_to_plot = self.get_indices_points_to_plot()
        if x is None:
            x = self.get_x()
        if self.permutations is not None:
            permutations = self.permutations
        elif self.sort_output:
            permutations = list(numpy.argsort(x))
        else:
            permutations = range(len(x))
        if len(indices_to_plot) == self.get_length_data():
            return permutations
        return [permutations[index] for index in indices_to_plot]

    def get_invert_permutations(self):
        """Return the inverse of permutations:
            xdata[i] = xplot[revert[i]]

        Only a true inverse when every point is plotted; when a subset is plotted
        this returns the ranks of the plotted data indices instead.
        """
        return numpy.argsort(self.get_permutations())

    def get_dataIndex_from_graphIndex(self, index_graph_point):
        """
        From an index given in graph, recovers the index of the data.
        :param index_graph_point: Index in the graph
        :return: index of the data
        """
        return self.get_permutations()[index_graph_point]

    def get_dataIndices_from_graphIndices(self, index_graph_point_list):
        """
        Same as get_dataIndex_from_graphIndex but with a list in entry.
        Computes the permutations only once, which is cheaper for huge datasets.
        :param index_graph_point_list: List of Index in the graph
        :return: List of index of the data
        """
        permutations = self.get_permutations()
        return [permutations[index] for index in index_graph_point_list]

    def get_graphIndex_from_dataIndex(self, index_data):
        """
        From an index given in the data, recovers the index of the graph.
        :param index_data: Index in the data
        :return: index of the graph
        :raises ValueError: if the data point is not displayed
        """
        return self.get_permutations().index(index_data)

    def get_graphIndices_from_dataIndices(self, index_data_list):
        """
        Same as get_graphIndex_from_dataIndex but with a list in entry.
        Builds the data -> graph lookup once, which is cheaper for huge datasets.
        :param index_data_list: List of Index in the data
        :return: List of index of the graph
        :raises KeyError: if one of the data points is not displayed
        """
        # Build the lookup directly from the permutations rather than argsort-ing them:
        # argsort is only an inverse when all points are plotted, and gives wrong
        # indices (or IndexError) once a subset is selected via points_to_plot.
        graph_index_of = {data_index: graph_index
                          for graph_index, data_index in enumerate(self.get_permutations())}
        return [graph_index_of[index] for index in index_data_list]

    def set_permutations(self, permutations):
        """
        Set permutations between datapoints of the trace
        :param permutations: list of indices to plot (example: [0, 2, 1] means that the first point will be plotted, then the third, then the second one).
                             Ignored (with an error message) if its length differs from the data.
        """
        if permutations is not None:
            if len(self.get_x()) == len(permutations):
                if self.sort_output:
                    print("Warning : Permutations due to flag 'sort_output' are overridden by user")
                self.permutations = permutations
            else:
                print("Error : Permutations have not the same length as the data")

    def get_x_label(self):
        """ Get x label of the trace """
        return self.x_label

    def get_y_label(self):
        """ Get y label of the trace """
        return self.y_label

    def get_legend(self):
        """ Get name of the trace """
        return self.legend

    def get_symbol(self):
        """ Get symbol """
        return self.symbol

    def add_point(self, x, y):
        """Add point(s) to trace (inputs can be list or numeral).

        Non-list inputs are wrapped in a one-element list; the stored x and y
        lists are extended in place.
        """
        if not isinstance(x, list):
            x = [x]
        if not isinstance(y, list):
            y = [y]
        self.x.extend(x)
        self.y.extend(y)

    def delete_point(self, index_point):
        """Delete a point from the datapoints.

        If x is empty (x implicitly given by the indices), only y is modified.
        """
        if len(self.x):
            del self.x[index_point]
            del self.y[index_point]
        else:
            del self.y[index_point]

    def is_scattered(self):
        """Check if the trace has to be plotted as a scatter plot."""
        return self.isScattered

    def set_indices_points_to_plot(self, indices):
        """Set indices points to plot (None => all points)."""
        self.points_to_plot = indices

    def get_indices_points_to_plot(self):
        """Get indices points to plot (all indices when none were set)."""
        indices = self.points_to_plot
        if indices is None:
            indices = range(self.get_length_data())
        return indices

    def get_linestyle(self):
        """Get linestyle (as a string)."""
        return str(self.linestyle)

    def __str__(self):
        """Tabular view: a header with both labels, then one row per point."""
        rows = [self.x_label + '\t' + self.y_label + '\n']
        # zip stops at the shorter sequence. get_x() falls back to range(len(y)) when
        # lengths differ, so iterating up to get_length_data() could run past its end.
        rows.extend(str(x_i) + '\t\t\t' + str(y_i) + '\n'
                    for x_i, y_i in zip(self.get_x(), self.get_y()))
        return ''.join(rows)

    def export_str(self):
        """Method to save the points constituting the trace (as displayed: transformed and permuted)."""
        lines = [
            "# $NEW TRACE 2D$\n",
            "# $LEGEND$: ${}$\n".format(self.get_legend()),
            "# $LABEL X$: ${}$\n".format(self.get_x_label()),
            "# $LABEL Y$: ${}$\n".format(self.get_y_label()),
            "# $BEGIN DATA$\n",
        ]
        x, y = self.get_plot_data()
        lines.extend("{}\t{}\n".format(x_i, y_i) for x_i, y_i in zip(x, y))
        lines.append("# $END DATA$\n")
        return "".join(lines)

    def set_color(self, theColor):
        """Set color of the trace (rgb tuple or matlab-like abbreviation)."""
        self.color = theColor
class Graph:
    """Simple container holding several traces, each stored under an integer id."""

    def __init__(self):
        self.traces = dict()  # id -> Data
        self.curr_id = 0      # next id to hand out

    def add_trace(self, data):
        """Add a trace to the graph
        :param data: :class:`~Data`
        :return: id of the created trace
        """
        new_id = self.curr_id
        self.curr_id = new_id + 1
        self.traces[new_id] = data
        return new_id

    def remove_trace(self, idTrace):
        """Delete a trace from the graph (a warning is emitted if the id is unknown)
        :param idTrace: id of the trace to delete
        """
        if idTrace not in self.traces:
            printIfShown(f"Key {idTrace} not found", SHOW_WARNING)
            return
        del self.traces[idTrace]

    def get_trace(self, idTrace) -> Data:
        """Get data object of idTrace
        :param idTrace: id of the trace to get
        :return: :class:`~Data`
        """
        return self.traces[idTrace]

    def get_all_traces(self):
        """Get the traces of the graph as a dict {id: :class:`~Data`} (iterating it yields the ids)."""
        return self.traces

    def export_str(self):
        """Serialise the graph: a header followed by each trace's export, each trace followed by a blank line.
        :return: str
        """
        pieces = ["# $NEW GRAPH$\n\n"]
        for trace_id in self.get_all_traces():
            pieces.append(self.get_trace(trace_id).export_str())
            pieces.append("\n")
        return "".join(pieces)
class Graphs:
    """Contains several :class:`Graph`, keyed by integer id.

    Registered callbacks (see :meth:`add_update_method`) are called whenever the
    collection is modified (unless a method is told not to).
    """

    def __init__(self):
        self.graphs = dict()         # id -> Graph
        self.curr_id = 0             # next id to hand out
        self.updateMethods = set()   # callbacks without arguments

    def updateChildren(self):
        """Call every registered update callback."""
        for updateMethod in self.updateMethods:
            updateMethod()

    def add_trace_firstGraph(self, data, updateChildren=True):
        """
        Same as add_trace, but only if graphs has only one id
        :param data: :class:`~Data`
        :param updateChildren: Automatically calls callback functions
        :return: id of the created trace, or None if there is not exactly one graph
        """
        all_ids = self.get_all_graphs_ids()
        if len(all_ids) == 1:
            return self.add_trace(all_ids[0], data, updateChildren=updateChildren)
        printIfShown("Cannot add trace .. graphs are multiple", SHOW_WARNING)
        return None

    def add_trace(self, idGraph, data, updateChildren=True):
        """Add a trace to the graph
        :param idGraph: id of the graph
        :param data: :class:`~Data`
        :param updateChildren: Automatically calls callback functions
        :return: id of the created trace
        """
        idTrace = self.get_graph(idGraph).add_trace(data)
        if updateChildren:
            self.updateChildren()
        return idTrace

    def remove_trace(self, idGraph, idTrace, updateChildren=True):
        """Remove the trace from the graph
        :param idGraph: id of the graph
        :param idTrace: id of the trace to remove
        :param updateChildren: Automatically calls callback functions
        """
        self.get_graph(idGraph).remove_trace(idTrace)
        if updateChildren:
            self.updateChildren()

    def get_first_graph(self):
        """Get the first graph (the one with the first id)
        :return: :class:`~Graph`
        :raises IndexError: if there is no graph
        """
        return self.get_graph(self.get_all_graphs_ids()[0])

    def get_graph(self, idGraph):
        """Get graph object at idgraph
        :param idGraph: id of the graph to get
        :return: :class:`~Graph`
        """
        return self.graphs[idGraph]

    def get_all_graphs_ids(self):
        """Get all ids of the graphs
        :return: list of id graphs
        """
        return list(self.graphs.keys())

    def get_all_graphs(self):
        """Get all graphs. Return dict {id: :class:`~Graph`}"""
        return self.graphs

    def add_graph(self, updateChildren=True):
        """Add a new graph
        :param updateChildren: Automatically calls callback functions
        :return: id of the created graph
        """
        idGraph = self.curr_id
        self.graphs[idGraph] = Graph()
        self.curr_id += 1
        if updateChildren:
            self.updateChildren()
        return idGraph

    def remove_graph(self, idGraph, updateChildren=True):
        """Delete a graph (a warning is emitted if the id is unknown)
        :param idGraph: id of the graph to delete
        :param updateChildren: Automatically calls callback functions
        """
        try:
            del self.graphs[idGraph]
        except KeyError:
            printIfShown("Key {} not found".format(idGraph), SHOW_WARNING)
        if updateChildren:
            self.updateChildren()

    def add_update_method(self, childObject):
        """Add a callback each time a graph is modified.
        :param childObject: method without arguments
        """
        self.updateMethods.add(childObject)

    def export_str(self):
        """Export all the graphs in text
        :return: str"""
        parts = []
        for graphId in self.get_all_graphs_ids():
            parts.append(self.get_graph(graphId).export_str())
            parts.append("\n\n\n")
        return "".join(parts)

    def merge(self, otherGraphs):
        """Merge the graphs of otherGraphs into this object.

        The graphs of self come first, then those of otherGraphs; all of them are
        renumbered 0..n-1. Graph objects are shared with otherGraphs, not copied.

        :param otherGraphs: :class:`Graphs` to merge in
        :return: [mapping_self, mapping_other], each a dict {old id: new id}
        """
        # One independent dict per source ([{}]*2 would alias the same dict twice).
        mappings = [dict(), dict()]
        new_graphs = dict()
        curr_id = 0
        sources = (self.get_all_graphs(), otherGraphs.get_all_graphs())
        for index, all_graphs in enumerate(sources):
            for graphId, graph in all_graphs.items():
                new_graphs[curr_id] = graph
                mappings[index][graphId] = curr_id
                curr_id += 1
        self.graphs = new_graphs
        # Keep the id counter past the merged ids so add_graph cannot overwrite them.
        self.curr_id = curr_id
        self.updateChildren()
        return mappings

    def reset(self):
        """Remove all graphs (ids keep increasing from the previous counter)."""
        self.graphs = dict()
        self.updateChildren()

    def is_empty(self):
        """Check whether there is no graph."""
        return len(self.get_all_graphs()) == 0