import copy
import json
import logging
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, TypeVar, TYPE_CHECKING
import re
import matplotlib
import matplotlib.cm as cm
import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
import yaml
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from networkx.algorithms.dag import is_directed_acyclic_graph
from plotly.subplots import make_subplots
from elphick.geomet import Sample
from elphick.geomet.base import MC
from elphick.geomet.config.config_read import get_column_config
from elphick.geomet.flowsheet.operation import NodeType, OP, PartitionOperation, Operation
from elphick.geomet.plot import parallel_plot, comparison_plot
from elphick.geomet.utils.layout import digraph_linear_layout
from elphick.geomet.flowsheet.loader import streams_from_dataframe
# if TYPE_CHECKING:
from elphick.geomet.flowsheet.stream import Stream
# Generic type variable bound to Flowsheet, intended for type hints that play nicely with subclasses.
# It is used as the return annotation of the alternate constructors (from_objects, from_dataframe, from_dict, ...).
FS = TypeVar('FS', bound='Flowsheet')
[docs]
class Flowsheet:
[docs]
    def __init__(self, name: str = 'Flowsheet'):
        """Create an empty flowsheet.

        Args:
            name: Name of the flowsheet, held as the name of the underlying directed graph.
        """
        logger_name: str = __class__.__name__
        self._logger: logging.Logger = logging.getLogger(logger_name)
        self.graph: nx.DiGraph = nx.DiGraph(name=name)
    @property
    def name(self) -> str:
        """The flowsheet name, delegated to the name of the underlying graph."""
        graph_name = self.graph.name
        return graph_name

    @name.setter
    def name(self, value: str):
        """Rename the flowsheet by renaming the underlying graph."""
        graph = self.graph
        graph.name = value
    @property
    def healthy(self) -> bool:
        """True only when all nodes are healthy and all streams are healthy."""
        if not self.all_nodes_healthy:
            return False
        return self.all_streams_healthy
    @property
    def all_nodes_healthy(self) -> bool:
        """True when every node that reports a balance status is balanced.

        Nodes whose operation reports ``is_balanced`` as None are left out of the check.
        """
        node_view = self.graph.nodes
        balance_flags: List = [node_view[node]['mc'].is_balanced for node in node_view]
        return all([flag for flag in balance_flags if flag is not None])
    @property
    def all_streams_healthy(self) -> bool:
        """Check if all streams are healthy

        Returns False as soon as any edge lacks a (truthy) ``mc`` object; otherwise returns True only if the
        ``status.ok`` of every stream is truthy.
        """
        edge_objects = [attrs['mc'] for _, _, attrs in self.graph.edges(data=True)]
        # an edge without an mc object cannot be considered healthy
        if not all(edge_objects):
            return False
        return all([edge_obj.status.ok for edge_obj in edge_objects])
[docs]
    @classmethod
    def from_objects(cls, objects: list[MC],
                     name: Optional[str] = 'Flowsheet') -> FS:
        """Instantiate from a list of objects

        This method is only suitable for objects that have the `_nodes` property set, such as objects that have
        been created from math operations, which preserve relationships between objects (via the nodes property).

        Each object becomes an edge from its first node to its second node, and every node receives a generic
        Operation wired to its incoming and outgoing objects. Node labels are then renumbered to consecutive
        integers. The ``nodes`` property of every object is updated to the new integers (the supplied objects are
        modified), and each Operation is renamed to the string form of its integer node.

        Args:
            objects: List of MassComposition objects, such as Sample, IntervalSample, Stream or BlockModel
            name: name of the flowsheet/network

        Returns:
            A new instance of the class whose graph carries the supplied objects on its edges.

        Raises:
            KeyError: If the index types of the objects are not consistent, or an object has no nodes set.
        """
        cls._check_indexes(objects)

        edge_specs: list = []
        for mc_obj in objects:
            if mc_obj.nodes is None:
                raise KeyError(f'Stream {mc_obj.name} does not have the node property set')
            obj_nodes = mc_obj.nodes
            edge_specs.append((obj_nodes[0], obj_nodes[1], {'mc': mc_obj, 'name': mc_obj.name}))

        graph = nx.DiGraph(name=name)
        graph.add_edges_from(edge_specs)

        # one generic Operation per node, then wire up its incoming/outgoing objects
        operations: dict = {node: Operation(name=node) for node in graph.nodes}
        nx.set_node_attributes(graph, operations, 'mc')
        for node, operation in operations.items():
            operation.inputs = [graph.get_edge_data(src, dst)['mc'] for src, dst in graph.in_edges(node)]
            operation.outputs = [graph.get_edge_data(src, dst)['mc'] for src, dst in graph.out_edges(node)]

        graph = nx.convert_node_labels_to_integers(graph)
        # keep each object's nodes property and each operation's name in step with the integer labels
        for src, dst, edge_attrs in graph.edges(data=True):
            edge_attrs['mc'].nodes = [src, dst]
        for node, node_attrs in graph.nodes(data=True):
            node_attrs['mc'].name = str(node)

        flowsheet = cls()
        flowsheet.graph = graph
        return flowsheet
[docs]
    @classmethod
    def from_dataframe(cls, df: pd.DataFrame, name: Optional[str] = 'Flowsheet',
                       mc_name_col: Optional[str] = None, n_jobs: int = 1) -> FS:
        """Instantiate from a DataFrame

        The streams are extracted with ``streams_from_dataframe`` and assembled with ``from_objects``, so the
        extracted objects must have their nodes set.

        Args:
            df: The DataFrame
            name: name of the network
            mc_name_col: The column specified contains the names of objects to create.
              If None the DataFrame is assumed to be wide and the mc objects will be extracted from column prefixes.
            n_jobs: The number of parallel jobs to run.  If -1, will use all available cores.

        Returns:
            Flowsheet: An instance of the Flowsheet class initialized from the provided DataFrame.
        """
        streams: list[Sample] = streams_from_dataframe(df=df, mc_name_col=mc_name_col, n_jobs=n_jobs)
        # from_objects is a classmethod - call it on the class instead of constructing a throwaway instance
        return cls.from_objects(objects=streams, name=name)
[docs]
    @classmethod
    def from_dict(cls, config: dict) -> FS:
        """Create a flowsheet from a dictionary

        Layout of the keys read by this method::

            FLOWSHEET:
              flowsheet: {name: ...}
              streams: {<key>: {name: ..., node_in: ..., node_out: ...}, ...}
              operations: {<node>: {type: 'Operation' | 'PartitionOperation', ...}, ...}

        Each stream becomes an edge from ``node_in`` to ``node_out`` with no data (``mc`` is None). Each node
        gets an Operation (or PartitionOperation) built from its ``operations`` entry. Nodes without an entry get
        a default ``Operation(name=node)``, so that every node carries an ``mc`` object. Node labels are finally
        renumbered to consecutive integers.

        Args:
            config: dictionary containing the flowsheet configuration. It is not modified.

        Returns:
            A Flowsheet object with no data on the edges

        Raises:
            ValueError: If the dictionary does not contain the 'FLOWSHEET' root node.
        """
        if 'FLOWSHEET' not in config:
            raise ValueError("Dictionary does not contain 'FLOWSHEET' root node")
        flowsheet_config = config['FLOWSHEET']
        # tolerate a missing/empty operations section - such nodes get a default Operation below
        operations_config: dict = flowsheet_config.get('operations') or {}

        # create the edges (streams) - no data yet, only the stream name
        bunch_of_edges: list = [
            (stream_config['node_in'], stream_config['node_out'], {'mc': None, 'name': stream_config['name']})
            for stream_config in flowsheet_config['streams'].values()]
        graph = nx.DiGraph(name=flowsheet_config['flowsheet']['name'])
        graph.add_edges_from(bunch_of_edges)

        operation_objects: dict = {}
        for node in graph.nodes:
            node_config = operations_config.get(node)
            if node_config is None:
                # every node needs an 'mc' operation object - other methods index graph.nodes[n]['mc']
                operation = Operation(name=node)
            elif node_config.get('type', 'Operation') == 'PartitionOperation':
                # copy before injecting the output stream names so the caller's config is not mutated
                partition_config = dict(node_config)
                partition_config['output_stream_names'] = [d['name'] for _, _, d in
                                                           graph.out_edges(node, data=True)]
                operation = PartitionOperation.from_dict(partition_config)
            else:
                operation = Operation.from_dict(node_config)
            # set the input and output streams (all None at this point) on the operation object
            operation.inputs = [graph.get_edge_data(u, v)['mc'] for u, v in graph.in_edges(node)]
            operation.outputs = [graph.get_edge_data(u, v)['mc'] for u, v in graph.out_edges(node)]
            operation_objects[node] = operation
        nx.set_node_attributes(graph, operation_objects, 'mc')
        graph = nx.convert_node_labels_to_integers(graph)
        obj = cls()
        obj.graph = graph
        return obj
    # @classmethod
    # def from_dict_todo(cls, config: dict) -> FS:
    # TODO: This method is not yet implemented - fails because the Operations do not have inputs or outputs set.
    #     flowsheet = cls()
    #
    #     # Process streams
    #     for stream_name, stream_data in config['FLOWSHEET']['streams'].items():
    #         stream = Stream.from_dict(stream_data)
    #         flowsheet.add_stream(stream)
    #
    #     # Process operations
    #     for operation_name, operation_data in config['FLOWSHEET']['operations'].items():
    #         operation_type = operation_data.get('type', 'Operation')
    #         if operation_type == 'PartitionOperation':
    #             operation = PartitionOperation.from_dict(operation_data)
    #         else:
    #             operation = Operation.from_dict(operation_data)
    #         flowsheet.add_operation(operation)
    #
    #     return flowsheet
[docs]
    @classmethod
    def from_yaml(cls, file_path: Path) -> FS:
        """Create a flowsheet from yaml

        The file is read as UTF-8 and parsed with ``yaml.safe_load`` before being passed to ``from_dict``.

        Args:
            file_path: path to the yaml file

        Returns:
            A Flowsheet object with no data on the edges

        Raises:
            ValueError: If the file content has no 'FLOWSHEET' root node (including an empty file).
        """
        # explicit encoding - the platform default (e.g. cp1252 on Windows) may not decode UTF-8 yaml
        with open(file_path, 'r', encoding='utf-8') as file:
            # an empty document loads as None; map it to {} so from_dict raises its clear ValueError
            config = yaml.safe_load(file) or {}
        return cls.from_dict(config)
[docs]
    @classmethod
    def from_json(cls, file_path: Path) -> FS:
        """Create a flowsheet from json

        The file is read as UTF-8 and parsed with ``json.load`` before being passed to ``from_dict``.

        Args:
            file_path: path to the json file

        Returns:
            A Flowsheet object with no data on the edges
        """
        # JSON text is UTF-8; do not depend on the platform default encoding
        with open(file_path, 'r', encoding='utf-8') as file:
            config = json.load(file)
        return cls.from_dict(config)
[docs]
    def add_stream(self, stream: 'Stream'):
        """Add a stream to the flowsheet.

        The stream becomes an edge from its first node to its second node. The object is stored on the edge
        under ``mc`` and its name under ``name``.
        """
        source_node, target_node = stream.nodes[0], stream.nodes[1]
        self.graph.add_edge(source_node, target_node, mc=stream, name=stream.name)
[docs]
    def add_operation(self, operation: 'Operation'):
        """Add an operation to the flowsheet as a node keyed by the operation name, stored under ``mc``."""
        node_key = operation.name
        self.graph.add_node(node_key, mc=operation)
[docs]
    def unhealthy_stream_records(self) -> pd.DataFrame:
        """Return records for unhealthy streams

        Return the records for all streams that are not healthy. The out-of-range records of every unhealthy
        stream are stacked row-wise and labelled by a leading ``stream`` column. They are then joined (on the index)
        with the wide flowsheet records, whose columns are named ``{stream}_{variable}``, for additional context.

        Returns:
            DataFrame: A DataFrame containing the unhealthy stream records. When no stream is unhealthy, it is
            empty with only a ``stream`` column.
        """
        unhealthy_streams: list = [d['mc'] for _, _, d in self.graph.edges(data=True) if not d['mc'].status.ok]
        if not unhealthy_streams:
            # pd.concat raises on an empty list - report "nothing unhealthy" instead
            return pd.DataFrame(columns=['stream'])
        # stack row-wise: a column-wise concat duplicates the 'stream' column when several streams are unhealthy
        unhealthy_data: pd.DataFrame = pd.concat(
            [strm.status.oor.assign(stream=strm.name) for strm in unhealthy_streams], axis=0)
        # move the stream label to the front explicitly (after concat it is not necessarily the last column)
        unhealthy_data = unhealthy_data[['stream'] + [col for col in unhealthy_data.columns if col != 'stream']]
        # append the flowsheet records for additional context
        records: pd.DataFrame = self.to_dataframe()
        records = records.unstack(level='name').swaplevel(axis=1).sort_index(axis=1, level=0, sort_remaining=False)
        records.columns = [f"{col[0]}_{col[1]}" for col in records.columns]
        return unhealthy_data.merge(records, left_index=True, right_index=True, how='left')
[docs]
    def unhealthy_node_records(self) -> pd.DataFrame:
        """Return unhealthy nodes

        Return the records for all nodes that are not healthy, i.e. balance nodes whose ``is_balanced`` is
        falsy. Nodes reporting ``is_balanced`` as None are skipped, as they are in ``all_nodes_healthy``. The
        unbalanced records of each such node are stacked row-wise and labelled by a leading ``node`` column.

        Returns:
            DataFrame: A DataFrame containing the unhealthy node records. When no node is unhealthy, it is empty
            with only a ``node`` column.
        """
        unhealthy_ops: list = []
        for n in self.graph.nodes:
            operation = self.graph.nodes[n]['mc']
            if operation.node_type != NodeType.BALANCE:
                continue
            balanced = operation.is_balanced
            # None is excluded (as in all_nodes_healthy); `is False` would miss numpy booleans
            if balanced is not None and not balanced:
                unhealthy_ops.append(operation)
        if not unhealthy_ops:
            # pd.concat raises on an empty list - report "nothing unhealthy" instead
            return pd.DataFrame(columns=['node'])
        # stack row-wise: a column-wise concat duplicates the 'node' column when several nodes are unhealthy
        unhealthy_data: pd.DataFrame = pd.concat(
            [op.unbalanced_records.assign(node=op.name) for op in unhealthy_ops], axis=0)
        # move the node label to the front explicitly (after concat it is not necessarily the last column)
        unhealthy_data = unhealthy_data[['node'] + [col for col in unhealthy_data.columns if col != 'node']]
        # todo: append  the streams around the node
        return unhealthy_data
[docs]
    def copy_without_stream_data(self):
        """Copy without stream data

        Return a new Flowsheet with the same nodes, edges and graph attributes (including the name). On every
        edge ``mc`` is set to None; all other edge attributes are deep-copied.

        Note:
            Node attribute dicts are copied shallowly, so the Operation objects are shared with this flowsheet.
        """
        new_flowsheet = Flowsheet(name=self.name)
        new_graph = nx.DiGraph()
        # carry over graph-level attributes (e.g. the name) - a bare DiGraph() would silently drop them
        new_graph.graph.update(copy.deepcopy(self.graph.graph))
        # Copy nodes with Operation objects (shallow - the Operation instances are shared)
        for node, node_attrs in self.graph.nodes(data=True):
            new_graph.add_node(node, **node_attrs.copy())
        # Copy edges with mc attribute set to None
        for u, v, edge_attrs in self.graph.edges(data=True):
            new_attrs = {key: (None if key == 'mc' else copy.deepcopy(value)) for key, value in edge_attrs.items()}
            new_graph.add_edge(u, v, **new_attrs)
        new_flowsheet.graph = new_graph
        return new_flowsheet
[docs]
    def solve(self):
        """Solve missing streams

        Fill in edges whose ``mc`` is None. The method repeatedly sweeps the BALANCE nodes in topological order and
        asks each node's operation to ``solve()`` for its empty inputs or outputs. Solved objects are copied onto
        the empty edges and renamed to the edge name. Sweeping stops when no stream is missing, or when a full sweep
        reduces the missing count no further.

        Raises:
            ValueError: If the graph is not a directed acyclic graph, or if streams are still missing when the
                sweeps stop making progress.
        """
        if not is_directed_acyclic_graph(self.graph):
            self._logger.error("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
            self._logger.debug(f"Graph nodes: {self.graph.nodes(data=True)}")
            self._logger.debug(f"Graph edges: {self.graph.edges(data=True)}")
            raise ValueError("Graph is not a Directed Acyclic Graph (DAG), so cannot be solved.")
        # Check the number of missing mc's on edges in the network
        missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
        prev_missing_count = missing_count + 1  # Initialize with a value greater than missing_count
        # keep sweeping while something is missing and the previous sweep made progress
        while 0 < missing_count < prev_missing_count:
            prev_missing_count = missing_count
            for node in nx.topological_sort(self.graph):
                if self.graph.nodes[node]['mc'].node_type == NodeType.BALANCE:
                    if self.graph.nodes[node]['mc'].has_empty_input:
                        mc: MC = self.graph.nodes[node]['mc'].solve()
                        # copy the solved object to the empty input edges
                        # NOTE(review): the same object is assigned to every empty input edge (and renamed each
                        # time); presumably solve() only succeeds when exactly one input is missing - confirm.
                        for predecessor in self.graph.predecessors(node):
                            edge_data = self.graph.get_edge_data(predecessor, node)
                            if edge_data and edge_data['mc'] is None:
                                edge_data['mc'] = mc
                                edge_data['mc'].name = edge_data['name']
                                # presumably refreshes the predecessor operation's inputs/outputs from the graph
                                self.set_operation_data(predecessor)
                        # NOTE(review): unlike the output branch below, set_operation_data(node) is not called
                        # here - confirm the node's own operation state is refreshed elsewhere (e.g. by solve()).
                    if self.graph.nodes[node]['mc'].has_empty_output:
                        # There are two cases to be managed, 1. a single output missing,
                        # 2. a partition operation that returns two outputs
                        if isinstance(self.graph.nodes[node]['mc'], PartitionOperation):
                            partition_stream: str = self.graph.nodes[node]['mc'].partition['partition_stream']
                            mc1, mc2 = self.graph.nodes[node]['mc'].solve()
                            # copy the solved object to the empty output edges:
                            # mc1 goes to the edge named partition_stream, mc2 to any other empty output edge
                            for successor in self.graph.successors(node):
                                edge_data = self.graph.get_edge_data(node, successor)
                                if edge_data and edge_data['mc'] is None:
                                    edge_data['mc'] = mc1 if edge_data['name'] == partition_stream else mc2
                                    edge_data['mc'].name = edge_data['name']
                                    self.set_operation_data(successor)
                        else:
                            mc: MC = self.graph.nodes[node]['mc'].solve()
                            # copy the solved object to the empty output edges
                            for successor in self.graph.successors(node):
                                edge_data = self.graph.get_edge_data(node, successor)
                                if edge_data and edge_data['mc'] is None:
                                    edge_data['mc'] = mc
                                    edge_data['mc'].name = edge_data['name']
                                    self.set_operation_data(successor)
                        self.set_operation_data(node)
            # recount after the sweep; the loop exits if this did not decrease
            missing_count: int = sum([1 for u, v, d in self.graph.edges(data=True) if d['mc'] is None])
            self._logger.info(f"Missing count: {missing_count}")
        if missing_count > 0:
            self._logger.error(f"Failed to solve the flowsheet. Missing count: {missing_count}")
            raise ValueError(
                f"Failed to solve the flowsheet. Some streams are still missing. Missing count: {missing_count}")
[docs]
    def query(self, expr: str, stream_name: Optional[str] = None, inplace=False) -> 'Flowsheet':
        """Reduce the Flowsheet Stream records with a query

        The query is evaluated against the data of a single reference stream. The index of the matching records
        is then used to filter every stream in the flowsheet.

        Args:
            expr: The query string (pandas ``DataFrame.query`` syntax) evaluated on the reference stream's data,
                e.g. ``'var_name > 0.5'``; names refer to the columns of that stream's data.
            stream_name: The name of the reference stream. If None, the first input stream is used.
            inplace: If True, apply the filter in place on the same object, otherwise return a new instance.

        Returns:
            A Flowsheet object where the stream records conform to the query
        """
        if stream_name is not None:
            reference: MC = self.get_stream_by_name(name=stream_name)
        else:
            reference: MC = self.get_input_streams()[0]
        kept_index: pd.Index = reference.data.query(expr).index
        return self._filter(kept_index, inplace)
[docs]
    def filter_by_index(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
        """Filter the Flowsheet Stream records by a given index.

        Public entry point delegating to the private ``_filter`` method.

        Args:
            index: The index to filter the data.
            inplace: If True, apply the filter in place on the same object, otherwise return a new instance.

        Returns:
            A Flowsheet object where the stream records are filtered by the given index.
        """
        filtered = self._filter(index=index, inplace=inplace)
        return filtered
    def _filter(self, index: pd.Index, inplace: bool = False) -> 'Flowsheet':
        """Private method to filter the Flowsheet Stream records by a given index.

        In place, each edge object's own ``filter_by_index`` is called. Otherwise, a stream-less copy of the
        flowsheet is built and each edge object is re-created with its attributes deep-copied. ``_mass_data`` and
        ``_supplementary_data`` are sliced with ``.loc[index]``, and ``aggregate`` is recomputed.

        Args:
            index: The index to filter the data.
            inplace: If True, apply the filter in place on the same object, otherwise return a new instance.

        Returns:
            A Flowsheet object where the stream records are filtered by the given index.
        """
        if inplace:
            for *_, edge_attrs in self.graph.edges(data=True):
                stream = edge_attrs.get('mc')
                if stream is not None:
                    stream.filter_by_index(index)
            return self

        # NOTE(review): copy_without_stream_data shares the node Operation objects, so in the returned copy the
        # operations may still reference the unfiltered streams of this flowsheet - confirm.
        filtered: Flowsheet = self.copy_without_stream_data()
        sliced_attrs = ('_mass_data', '_supplementary_data')
        for u, v, edge_attrs in self.graph.edges(data=True):
            source_mc: MC = edge_attrs.get('mc')
            if source_mc is None:
                continue
            new_mc = source_mc.__class__(name=source_mc.name)
            for attr_name, attr_value in source_mc.__dict__.items():
                if attr_name in sliced_attrs and attr_value is not None:
                    attr_value = attr_value.loc[index]
                setattr(new_mc, attr_name, copy.deepcopy(attr_value))
            new_mc.aggregate = new_mc.weight_average()
            filtered.graph[u][v]['mc'] = new_mc
        return filtered
[docs]
    def get_output_streams(self) -> list[MC]:
        """Get the output (product) streams (edge objects)

        An edge is an output when its target node has total degree 1, i.e. that edge is the only one touching it.

        Returns:
            List of MassComposition-like objects

        Raises:
            ValueError: If no output stream is found.
        """
        degree_by_node = dict(self.graph.degree())
        outputs: list[MC] = [attrs['mc'] for _, target, attrs in self.graph.edges(data=True)
                             if degree_by_node[target] == 1]
        if len(outputs) == 0:
            raise ValueError("No output streams found")
        return outputs
    @staticmethod
    def _check_indexes(streams):
        """Check that all streams share the same index type.

        Args:
            streams: Objects exposing a ``_mass_data`` DataFrame.

        Raises:
            KeyError: If the streams use more than one index type. An empty collection passes.
        """
        index_types = {type(s._mass_data.index) for s in streams}
        # more than one distinct type is an inconsistency; zero types (no streams) is not
        if len(index_types) > 1:
            raise KeyError("stream index types are not consistent")
[docs]
    def plot(self, orientation: str = 'horizontal') -> plt.Figure:
        """Plot the network with matplotlib

        Edges are labelled with the stream name. They are drawn gray when the stream object is present (truthy)
        and its status is ok, otherwise red. Balance nodes are green when balanced and red when not; all other
        nodes are gray.

        Args:
            orientation: 'horizontal'|'vertical' network layout

        Returns:
            The matplotlib Figure.
        """
        fig, ax = plt.subplots()
        pos = digraph_linear_layout(self.graph, orientation=orientation)

        edge_labels: Dict = {}
        edge_colors: List = []
        for src, dst, attrs in self.graph.edges(data=True):
            stream = attrs['mc']
            edge_labels[(src, dst)] = attrs['name'] if stream is None else stream.name
            edge_colors.append('gray' if (stream and stream.status.ok) else 'red')

        node_colors: List = []
        for node in self.graph.nodes:
            operation = self.graph.nodes[node]['mc']
            if operation.node_type != NodeType.BALANCE:
                node_colors.append('gray')
            else:
                node_colors.append('green' if operation.is_balanced else 'red')

        nx.draw(self.graph, pos=pos, ax=ax, with_labels=True, font_weight='bold',
                node_color=node_colors, edge_color=edge_colors)
        nx.draw_networkx_edge_labels(self.graph, pos=pos, ax=ax, edge_labels=edge_labels, font_color='black')
        ax.set_title(self._plot_title(html=False), fontsize=10)
        return fig
    def _plot_title(self, html: bool = True, compact: bool = False):
        """Build the plot title: the flowsheet name plus the node and stream health flags.

        Args:
            html: If True return HTML markup (unhealthy flags coloured red); otherwise plain text with newlines.
            compact: Currently unused; retained for API compatibility.

        Returns:
            The title string.
        """
        # evaluate each health check once - both properties walk the whole graph
        nodes_healthy: bool = self.all_nodes_healthy
        streams_healthy: bool = self.all_streams_healthy
        title = (f"{self.name}<br><sup>Nodes Healthy: "
                 f"<span style='color: {'red' if not nodes_healthy else 'black'}'>{nodes_healthy}</span>, "
                 f"Streams Healthy: "
                 f"<span style='color: {'red' if not streams_healthy else 'black'}'>{streams_healthy}</span></sup>")
        if not html:
            title = title.replace('<br><br>', '\n').replace('<br>', '\n').replace('<sup>', '').replace('</sup>', '')
            title = re.sub(r'<span style=.*?>(.*?)</span>', r'\1', title)
        return title
[docs]
    def report(self, apply_formats: bool = False) -> pd.DataFrame:
        """Summary Report

        Total Mass and weight averaged composition: one row per stream (edge), taken from its ``aggregate``.

        Args:
            apply_formats: If True, convert the values to strings using the configured column formats.

        Returns:
            DataFrame indexed by stream name.

        Raises:
            KeyError: If any stream is missing or has no data.
        """
        chunks: List[pd.DataFrame] = []
        for n, nbrs in self.graph.adj.items():
            for nbr, eattr in nbrs.items():
                stream = eattr['mc']
                if stream is None or stream.data.empty:
                    # use the object's name attribute (as elsewhere in this class), not stream['name']
                    edge_name: str = stream.name if stream is not None else eattr['name']
                    raise KeyError(f"Cannot generate report on empty dataset: {edge_name}")
                chunks.append(stream.aggregate.assign(name=stream.name))
        rpt: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name')
        if apply_formats:
            fmts: Dict = self._get_column_formats(rpt.columns)
            for k, v in fmts.items():
                # turn a %-style spec such as '%.2f' into the str.format spec '{:,.2f}' (thousands separators)
                rpt[k] = rpt[k].apply((v.replace('%', '{:,') + '}').format)
        return rpt
    def _get_column_formats(self, columns: List[str], strip_percent: bool = False) -> Dict[str, str]:
        """Look up the configured format strings for the given columns.

        Formats are read from the configuration of the first input stream.

        Args:
            columns: The columns to lookup format strings for. Only these columns are returned, in this order;
                columns without a configured format are omitted.
            strip_percent: If True strip the % symbol from the ends of each format (for plotly tables)

        Returns:
            Dict mapping column name to format string.
        """
        strm = self.get_input_streams()[0]
        d_format: dict = get_column_config(config_dict=strm.config, var_map=strm.variable_map, config_key='format')
        # restrict to (and order by) the requested columns: returning every configured format caused a KeyError in
        # report() for configured-but-absent columns, and formats that did not line up with table columns
        d_format = {col: d_format[col] for col in columns if col in d_format}
        if strip_percent:
            d_format = {k: v.strip('%') for k, v in d_format.items()}
        return d_format
[docs]
    def plot_balance(self, facet_col_wrap: int = 3,
                     color: Optional[str] = 'node') -> go.Figure:
        """Plot input versus output across all nodes in the network

        For every BALANCE node, the operation's ``add('in')`` and ``add('out')`` results are reshaped to long
        form. They are then paired into ``in`` and ``out`` columns per node/index/variable and passed to
        ``comparison_plot``.

        Args:
            facet_col_wrap: the number of subplots per row before wrapping
            color: The optional variable to color by; defaults to 'node' and is passed straight through to
                comparison_plot (the behaviour for None is decided there - TODO confirm).

        Returns:
            The plotly Figure returned by comparison_plot.
        """
        # prepare the data
        # NOTE(review): pd.concat raises if there are no BALANCE nodes (empty chunk lists).
        chunks_in: List = []
        chunks_out: List = []
        for n in self.graph.nodes:
            if self.graph.nodes[n]['mc'].node_type == NodeType.BALANCE:
                # presumably add('in'/'out') sums the operation's input/output streams - confirm
                chunks_in.append(self.graph.nodes[n]['mc'].add('in').assign(**{'direction': 'in', 'node': n}))
                chunks_out.append(self.graph.nodes[n]['mc'].add('out').assign(**{'direction': 'out', 'node': n}))
        df_in: pd.DataFrame = pd.concat(chunks_in)
        index_names = ['direction', 'node'] + df_in.index.names
        # long form: one row per (direction, node, original index..., variable)
        df_in = df_in.reset_index().melt(id_vars=index_names)
        df_out: pd.DataFrame = pd.concat(chunks_out).reset_index().melt(id_vars=index_names)
        df_plot: pd.DataFrame = pd.concat([df_in, df_out])
        # append the id columns to the positional index from melt, then pivot direction into 'in'/'out' columns.
        # NOTE(review): the in/out pairing relies on df_in and df_out melting to the same positional row order.
        df_plot = df_plot.set_index(index_names + ['variable'], append=True).unstack(['direction'])
        df_plot.columns = df_plot.columns.droplevel(0)
        # move the last len(index_names) index levels (node, original index..., variable) back to columns,
        # leaving the positional level as the index
        df_plot.reset_index(level=list(np.arange(-1, -len(index_names) - 1, -1)), inplace=True)
        df_plot['node'] = pd.Categorical(df_plot['node'])
        # plot
        fig = comparison_plot(data=df_plot,
                              x='in', y='out',
                              facet_col_wrap=facet_col_wrap,
                              color=color)
        return fig
[docs]
    def plot_network(self, orientation: str = 'horizontal') -> go.Figure:
        """Plot the network with plotly

        Args:
            orientation: 'horizontal'|'vertical' network layout

        Returns:
            A plotly Figure of the edges, nodes and edge annotations. The axes are hidden, the background is
            transparent, and the title shows the flowsheet name and health flags.
        """
        pos = digraph_linear_layout(self.graph, orientation=orientation)
        edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
        title = self._plot_title()
        fig = go.Figure(data=[*edge_traces, node_trace, edge_annotation_trace],
                        layout=go.Layout(
                            # title.font replaces the deprecated 'titlefont' property (removed in plotly 6)
                            title=dict(text=title, font=dict(size=16)),
                            showlegend=False,
                            hovermode='closest',
                            margin=dict(b=20, l=5, r=5, t=40),
                            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                            paper_bgcolor='rgba(0,0,0,0)',
                            plot_bgcolor='rgba(0,0,0,0)'
                        ),
                        )
        return fig
[docs]
    def plot_sankey(self,
                    width_var: str = 'mass_dry',
                    color_var: Optional[str] = None,
                    edge_colormap: Optional[str] = 'copper_r',
                    vmin: Optional[float] = None,
                    vmax: Optional[float] = None,
                    ) -> go.Figure:
        """Plot the Network as a sankey

        Args:
            width_var: The variable that determines the sankey width
            color_var: The optional variable that determines the sankey edge color
            edge_colormap: The optional colormap.  Used with color_var.
            vmin: The value that maps to the minimum color
            vmax: The value that maps to the maximum color

        Returns:
            A plotly Figure holding a single Sankey trace, titled with the flowsheet name and health flags.
        """
        # relabel the nodes as consecutive integers (their position in the node order) before building the sankey
        index_of = {node: position for position, node in enumerate(self.graph.nodes)}
        integer_graph = nx.relabel_nodes(self.graph, index_of)
        sankey_args = self._generate_sankey_args(integer_graph, color_var, edge_colormap, width_var, vmin, vmax)
        node, link = self._get_sankey_node_link_dicts(sankey_args)
        fig = go.Figure(data=[go.Sankey(node=node, link=link)])
        fig.update_layout(title_text=self._plot_title(), font_size=10)
        return fig
[docs]
    def table_plot(self,
                   plot_type: str = 'sankey',
                   cols_exclude: Optional[List] = None,
                   table_pos: str = 'left',
                   table_area: float = 0.4,
                   table_header_color: str = 'cornflowerblue',
                   table_odd_color: str = 'whitesmoke',
                   table_even_color: str = 'lightgray',
                   sankey_width_var: str = 'mass_dry',
                   sankey_color_var: Optional[str] = None,
                   sankey_edge_colormap: Optional[str] = 'copper_r',
                   sankey_vmin: Optional[float] = None,
                   sankey_vmax: Optional[float] = None,
                   network_orientation: Optional[str] = 'horizontal'
                   ) -> go.Figure:
        """Plot with table of edge averages

        A table of the ``report()`` rows is placed next to either a sankey or a network plot.

        Args:
            plot_type: The type of plot ['sankey', 'network']
            cols_exclude: List of columns to exclude from the table
            table_pos: Position of the table ['left', 'right', 'top', 'bottom']
            table_area: The proportion of width or height to allocate to the table [0, 1]
            table_header_color: Color of the table header
            table_odd_color: Color of the odd table rows
            table_even_color: Color of the even table rows
            sankey_width_var: If plot_type is sankey, the variable that determines the sankey width
            sankey_color_var: If plot_type is sankey, the optional variable that determines the sankey edge color
            sankey_edge_colormap: If plot_type is sankey, the optional colormap.  Used with sankey_color_var.
            sankey_vmin: The value that maps to the minimum color
            sankey_vmax: The value that maps to the maximum color
            network_orientation: The orientation of the network layout 'vertical'|'horizontal'

        Returns:
            The combined plotly Figure.

        Raises:
            ValueError: If plot_type or table_pos is not one of the supported values.
        """
        plot_types: List[str] = ['sankey', 'network']
        if plot_type not in plot_types:
            raise ValueError(f'The supplied plot_type is not in {plot_types}')
        table_positions: List[str] = ['top', 'bottom', 'left', 'right']
        if table_pos not in table_positions:
            raise ValueError(f'The supplied table_pos is not in {table_positions}')

        subplot_kwargs, table_kwargs, plot_kwargs = self._get_position_kwargs(table_pos, table_area, plot_type)
        fig = make_subplots(**subplot_kwargs, print_grid=False)

        df: pd.DataFrame = self.report().reset_index()
        if cols_exclude:
            df = df[[c for c in df.columns if c not in cols_exclude]]
        # the leading (name) column is text; the remaining columns use the configured number formats
        cell_formats: List[str] = ['%s'] + list(self._get_column_formats(df.columns, strip_percent=True).values())
        n_cols = len(df.columns)
        widths = [2] + [1] * (n_cols - 1)
        row_fills = [table_odd_color if row % 2 == 0 else table_even_color for row in range(len(df))]
        fig.add_table(
            header=dict(values=list(df.columns),
                        fill_color=table_header_color,
                        align='center',
                        font=dict(color='black', size=12)),
            columnwidth=widths,
            cells=dict(values=df.transpose().values.tolist(),
                       align='left', format=cell_formats,
                       fill_color=[row_fills * n_cols]),
            **table_kwargs)

        if plot_type == 'sankey':
            # relabel the nodes as consecutive integers (their position in the node order) for the sankey
            index_of = {node: position for position, node in enumerate(self.graph.nodes)}
            integer_graph = nx.relabel_nodes(self.graph, index_of)
            sankey_args = self._generate_sankey_args(integer_graph, sankey_color_var,
                                                     sankey_edge_colormap,
                                                     sankey_width_var,
                                                     sankey_vmin,
                                                     sankey_vmax)
            node, link = self._get_sankey_node_link_dicts(sankey_args)
            fig.add_trace(go.Sankey(node=node, link=link), **plot_kwargs)
        elif plot_type == 'network':
            pos = digraph_linear_layout(self.graph, orientation=network_orientation)
            edge_traces, node_trace, edge_annotation_trace = self._get_scatter_node_edges(pos)
            fig.add_traces(data=[*edge_traces, node_trace, edge_annotation_trace], **plot_kwargs)
            hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
            fig.update_layout(showlegend=False, hovermode='closest',
                              xaxis=dict(hidden_axis),
                              yaxis=dict(hidden_axis),
                              paper_bgcolor='rgba(0,0,0,0)',
                              plot_bgcolor='rgba(0,0,0,0)'
                              )

        fig.update_layout(title_text=self._plot_title(compact=True), font_size=12)
        return fig
[docs]
    def to_dataframe(self, stream_names: Optional[list[str]] = None, tidy: bool = True,
                     as_mass: bool = False) -> pd.DataFrame:
        """Return a tidy dataframe
        Adds the mc name to the index so indexes are unique.
        Args:
            stream_names: Optional List of names of Stream/MassComposition objects (network edges) for export
            tidy: If True, the data will be returned in a tidy format, otherwise wide
            as_mass: If True, the mass data will be returned instead of the mass-composition data
        Returns:
        """
        chunks: List[pd.DataFrame] = []
        for u, v, data in self.graph.edges(data=True):
            if (stream_names is None) or ((stream_names is not None) and (data['mc'].name in stream_names)):
                if as_mass:
                    chunks.append(data['mc'].mass_data.assign(name=data['mc'].name))
                else:
                    chunks.append(data['mc'].data.assign(name=data['mc'].name))
        results: pd.DataFrame = pd.concat(chunks, axis='index').set_index('name', append=True)
        if not tidy:  # wide format
            results = results.unstack(level='name')
            column_order: list[str] = [f'{name}_{attr}' for name in results.columns.levels[1] for attr in
                                       results.columns.levels[0]]
            results.columns = [f'{col[1]}_{col[0]}' for col in results.columns]
            results = results[column_order]
        return results 
[docs]
    def plot_parallel(self,
                      names: Optional[Union[str, List[str]]] = None,
                      color: Optional[str] = None,
                      vars_include: Optional[List[str]] = None,
                      vars_exclude: Optional[List[str]] = None,
                      title: Optional[str] = None,
                      include_dims: Optional[Union[bool, List[str]]] = True,
                      plot_interval_edges: bool = False) -> go.Figure:
        """Create an interactive parallel plot

        Useful to explore multidimensional data like mass-composition data

        Args:
            names: Optional stream name, or list of stream names, to plot.  None plots every stream.
            color: Optional color variable
            vars_include: Optional List of variables to include in the plot
            vars_exclude: Optional List of variables to exclude in the plot
            title: Optional plot title.  Defaults to the flowsheet name.
            include_dims: Optional boolean or list of dimension to include in the plot.  True will show all dims.
            plot_interval_edges: If True, interval edges will be plotted instead of interval mid

        Returns:
            The plotly Figure created by ``parallel_plot``.
        """
        if isinstance(names, str):
            # wrap a single name so it is matched exactly, not as a substring, when filtering streams
            names = [names]
        df: pd.DataFrame = self.to_dataframe(stream_names=names)
        if not title and hasattr(self, 'name'):
            title = self.name
        fig = parallel_plot(data=df, color=color, vars_include=vars_include, vars_exclude=vars_exclude, title=title,
                            include_dims=include_dims, plot_interval_edges=plot_interval_edges)
        return fig
    def _generate_sankey_args(self, int_graph, color_var, edge_colormap, width_var, v_min, v_max):
        """Assemble the node and link arguments for a plotly Sankey diagram.

        Args:
            int_graph: The flowsheet graph with its nodes relabelled to integers.
            color_var: Optional report column used to colour the edges.  When None the edges are not coloured.
            edge_colormap: Palette name passed to ``seaborn.color_palette(..., as_cmap=True)``; only used
              when ``color_var`` is set.
            width_var: Aggregate column used for the link widths.
            v_min: Value mapped to the minimum colour.  None uses the floor of the report minimum of ``color_var``.
            v_max: Value mapped to the maximum colour.  None uses the ceiling of the report maximum of ``color_var``.

        Returns:
            A dict with the keys 'node_color', 'edge_color', 'edge_custom_data', 'edge_labels', 'labels',
            'source', 'target' and 'value'.  'edge_color' is None when ``color_var`` is None.
        """
        # run the report once - it feeds both the colour range and the hover data
        rpt: pd.DataFrame = self.report()
        cmap = None
        if color_var is not None:
            cmap = sns.color_palette(edge_colormap, as_cmap=True)
            # explicit None checks so that a user-supplied bound of 0 is honoured
            if v_min is None:
                v_min = np.floor(rpt[color_var].min())
            if v_max is None:
                v_max = np.ceil(rpt[color_var].max())
        d_custom_data: Dict = self._rpt_to_html(df=rpt)

        # nodes: balance nodes are green/red by balance status, all others blue
        node_colors: List = []
        node_labels: List = []
        for n in int_graph.nodes:
            operation = int_graph.nodes[n]['mc']
            node_labels.append(operation.name)
            if operation.node_type == NodeType.BALANCE:
                node_colors.append('green' if operation.is_balanced else 'red')
            else:
                node_colors.append('blue')

        # links
        source: List = []
        target: List = []
        value: List = []
        edge_custom_data: List = []
        edge_labels: List = []
        edge_colors: List = []
        for u, v, data in int_graph.edges(data=True):
            stream = data['mc']
            edge_labels.append(stream.name)
            source.append(u)
            target.append(v)
            value.append(float(stream.aggregate[width_var].iloc[0]))
            edge_custom_data.append(d_custom_data[stream.name])
            if color_var is not None:
                val: float = float(stream.aggregate[color_var].iloc[0])
                edge_colors.append(f'rgba{self._color_from_float(v_min, v_max, val, cmap)}')

        d_sankey: Dict = {'node_color': node_colors,
                          'edge_color': edge_colors if color_var is not None else None,
                          'edge_custom_data': edge_custom_data,
                          'edge_labels': edge_labels,
                          'labels': node_labels,
                          'source': source,
                          'target': target,
                          'value': value}
        return d_sankey
    @staticmethod
    def _get_sankey_node_link_dicts(d_sankey: Dict):
        node: Dict = dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=d_sankey['labels'],
            color=d_sankey['node_color'],
            customdata=d_sankey['labels']
        )
        link: Dict = dict(
            source=d_sankey['source'],  # indices correspond to labels, eg A1, A2, A1, B1, ...
            target=d_sankey['target'],
            value=d_sankey['value'],
            color=d_sankey['edge_color'],
            label=d_sankey['edge_labels'],  # over-written by hover template
            customdata=d_sankey['edge_custom_data'],
            hovertemplate='<b><i>%{label}</i></b><br />Source: %{source.customdata}<br />'
                          'Target: %{target.customdata}<br />%{customdata}'
        )
        return node, link
    def _get_scatter_node_edges(self, pos):
        """Build the scatter traces for the 'network' flowsheet plot.

        Args:
            pos: Mapping of graph node -> (x, y) layout coordinates.

        Returns:
            A tuple (edge_traces, node_trace, edge_annotation_trace):
              - edge_traces: one line-with-arrow trace per edge, grey when ``status.ok`` is True, red when False.
              - node_trace: one marker per node, coloured by ``is_balanced`` (None grey, True green, False red).
              - edge_annotation_trace: small markers at edge midpoints carrying the stream name as hover text.
        """
        status_colour: Dict = {True: 'grey', False: 'red'}
        balance_colour: Dict = {None: 'grey', True: 'green', False: 'red'}

        # edges - remember each stream's midpoint for its hover label
        midpoints: Dict = {}
        edge_traces = []
        for src, dst, attrs in self.graph.edges(data=True):
            stream = attrs['mc']
            (x_start, y_start), (x_end, y_end) = pos[src], pos[dst]
            midpoints[stream.name] = {'pos': np.mean([pos[src], pos[dst]], axis=0)}
            colour = status_colour[stream.status.ok]
            edge_traces.append(go.Scatter(x=[x_start, x_end], y=[y_start, y_end],
                                          line=dict(width=2, color=colour),
                                          hoverinfo='none',
                                          mode='lines+markers',
                                          text=str(stream.name),
                                          marker=dict(symbol="arrow", color=colour, size=16,
                                                      angleref="previous", standoff=15)))

        # nodes
        xs, ys, colours, texts, labels = [], [], [], [], []
        for node in self.graph.nodes():
            x, y = pos[node]
            operation = self.graph.nodes[node]['mc']
            xs.append(x)
            ys.append(y)
            colours.append(balance_colour[operation.is_balanced])
            texts.append(node)
            labels.append(operation.name)
        node_trace = go.Scatter(x=xs, y=ys,
                                mode='markers+text',
                                hoverinfo='none',
                                marker=dict(color=colours, size=30, line_width=2),
                                text=texts,
                                customdata=labels,
                                hovertemplate='%{customdata}<extra></extra>')

        # edge annotations at the midpoints
        centres = [entry['pos'] for entry in midpoints.values()]
        edge_annotation_trace = go.Scatter(x=[c[0] for c in centres], y=[c[1] for c in centres],
                                           mode='markers',
                                           hoverinfo='text',
                                           marker=dict(color='grey', size=3, line_width=1),
                                           text=list(midpoints))
        return edge_traces, node_trace, edge_annotation_trace
    @staticmethod
    def _get_position_kwargs(table_pos, table_area, plot_type):
        """Helper to manage location dependencies
        Args:
            table_pos: position of the table: left|right|top|bottom
            table_area: fraction of the plot to assign to the table [0, 1]
        Returns:
        """
        name_type_map: Dict = {'sankey': 'sankey', 'network': 'xy'}
        specs = [[{"type": 'table'}, {"type": name_type_map[plot_type]}]]
        widths: Optional[List[float]] = [table_area, 1.0 - table_area]
        subplot_kwargs: Dict = {'rows': 1, 'cols': 2, 'specs': specs}
        table_kwargs: Dict = {'row': 1, 'col': 1}
        plot_kwargs: Dict = {'row': 1, 'col': 2}
        if table_pos == 'left':
            subplot_kwargs['column_widths'] = widths
        elif table_pos == 'right':
            subplot_kwargs['column_widths'] = widths[::-1]
            subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}, {"type": 'table'}]]
            table_kwargs['col'] = 2
            plot_kwargs['col'] = 1
        else:
            subplot_kwargs['rows'] = 2
            subplot_kwargs['cols'] = 1
            table_kwargs['col'] = 1
            plot_kwargs['col'] = 1
            if table_pos == 'top':
                subplot_kwargs['row_heights'] = widths
                subplot_kwargs['specs'] = [[{"type": 'table'}], [{"type": name_type_map[plot_type]}]]
                table_kwargs['row'] = 1
                plot_kwargs['row'] = 2
            elif table_pos == 'bottom':
                subplot_kwargs['row_heights'] = widths[::-1]
                subplot_kwargs['specs'] = [[{"type": name_type_map[plot_type]}], [{"type": 'table'}]]
                table_kwargs['row'] = 2
                plot_kwargs['row'] = 1
        if plot_type == 'network':  # different arguments for different plots
            plot_kwargs = {f'{k}s': v for k, v in plot_kwargs.items()}
        return subplot_kwargs, table_kwargs, plot_kwargs
    def _rpt_to_html(self, df: pd.DataFrame) -> Dict:
        custom_data: Dict = {}
        fmts: Dict = self._get_column_formats(df.columns)
        for i, row in df.iterrows():
            str_data: str = '<br />'
            for k, v in dict(row).items():
                str_data += f"{k}: {v:{fmts[k][1:]}}<br />"
            custom_data[i] = str_data
        return custom_data
    @staticmethod
    def _color_from_float(vmin: float, vmax: float, val: float,
                          cmap: Union[ListedColormap, LinearSegmentedColormap]) -> Tuple[int, int, int, int]:
        """Map ``val`` onto ``cmap`` and return the colour as integer RGBA channels in 0-255.

        The value is normalised linearly between ``vmin`` and ``vmax`` (matplotlib ``Normalize``), then looked up
        in the colormap with ``bytes=True``.  The same path is used for every colormap type, so listed colormaps
        of any length are supported and the result is always a valid ``rgba(r, g, b, a)`` colour.

        Args:
            vmin: The value mapped to the bottom of the colormap.
            vmax: The value mapped to the top of the colormap.
            val: The value to convert.
            cmap: A matplotlib colormap (ListedColormap, LinearSegmentedColormap or any other Colormap).

        Returns:
            A tuple (r, g, b, a) of ints in [0, 255].

        Raises:
            NotImplementedError: If ``cmap`` is not a matplotlib Colormap.
        """
        if not isinstance(cmap, matplotlib.colors.Colormap):
            raise NotImplementedError("Unrecognised colormap type")
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        # bytes=True yields uint8 channels; values outside [vmin, vmax] take the colormap's under/over colours
        r, g, b, a = cmap(norm(val), bytes=True)
        return int(r), int(g), int(b), int(a)
[docs]
    def set_node_names(self, node_names: Dict[int, str]):
        """Set the names of network nodes with a Dict
        """
        for node in node_names.keys():
            if ('mc' in self.graph.nodes[node].keys()) and (node in node_names.keys()):
                self.graph.nodes[node]['mc'].name = node_names[node] 
[docs]
    def set_stream_data(self, stream_data: dict[str, Optional[MC]]):
        """Set the data (MassComposition) of network edges (streams) with a Dict"""
        for stream_name, stream_data in stream_data.items():
            stream_found = False
            nodes_to_refresh = set()
            for u, v, data in self.graph.edges(data=True):
                if 'mc' in data.keys() and (data['mc'].name if data['mc'] is not None else data['name']) == stream_name:
                    self._logger.info(f'Setting data on stream {stream_name}')
                    data['mc'] = stream_data
                    stream_found = True
                    nodes_to_refresh.update([u, v])
            if not stream_found:
                self._logger.warning(f'Stream {stream_name} not found in graph')
            else:
                # refresh the node status
                for node in nodes_to_refresh:
                    self.graph.nodes[node]['mc'].inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
                                                           self.graph.in_edges(node)]
                    self.graph.nodes[node]['mc'].outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in
                                                            self.graph.out_edges(node)] 
[docs]
    def set_operation_data(self, node):
        """Set the input and output data for a node.
        Uses the data on the edges (streams) connected to the node to refresh the data and check for node balance.
        """
        node_data: Operation = self.graph.nodes[node]['mc']
        node_data.inputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.in_edges(node)]
        node_data.outputs = [self.graph.get_edge_data(e[0], e[1])['mc'] for e in self.graph.out_edges(node)]
        node_data.check_balance() 
[docs]
    def streams_to_dict(self) -> Dict[str, MC]:
        """Export the Stream objects to a Dict
        Returns:
            A dictionary keyed by name containing MassComposition objects
        """
        streams: Dict[str, MC] = {}
        for u, v, data in self.graph.edges(data=True):
            if 'mc' in data.keys():
                streams[data['mc'].name] = data['mc']
        return streams 
[docs]
    def nodes_to_dict(self) -> Dict[int, OP]:
        """Export the MCNode objects to a Dict
        Returns:
            A dictionary keyed by integer containing MCNode objects
        """
        nodes: Dict[int, OP] = {}
        for node in self.graph.nodes.keys():
            if 'mc' in self.graph.nodes[node].keys():
                nodes[node] = self.graph.nodes[node]['mc']
        return nodes 
    def set_nodes(self, stream: str, nodes: Tuple[int, int]):
        mc: MC = self.get_stream_by_name(stream)
        mc._nodes = nodes
        self._update_graph(mc)
[docs]
    def reset_nodes(self, stream: Optional[str] = None):
        """Reset stream nodes to break relationships
        Args:
            stream: The optional stream (edge) within the network.
              If None all streams nodes on the network will be reset.
        Returns:
        """
        if stream is None:
            streams: Dict[str, MC] = self.streams_to_dict()
            for k, v in streams.items():
                streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
            self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
        else:
            mc: MC = self.get_stream_by_name(stream)
            mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
            self._update_graph(mc) 
    def _update_graph(self, mc: MC):
        """Rebuild the graph after a stream object has been modified.

        The stream on every edge is collected, except that an edge whose stream has the same name as ``mc`` is
        given ``mc`` itself.  A fresh graph is then built from that list with ``Flowsheet.from_objects``.

        Args:
            mc: The stream object
        """
        # brutal approach - rebuild from streams
        def _pick(attrs):
            current = attrs.get('mc')
            return mc if (current and current.name == mc.name) else attrs['mc']

        strms: List[Union[Stream, MC]] = [_pick(attrs) for _, _, attrs in self.graph.edges(data=True)]
        self.graph = Flowsheet(name=self.name).from_objects(objects=strms).graph
[docs]
    def get_stream_by_name(self, name: str) -> MC:
        """Get the Stream object from the network by its name
        Args:
            name: The string name of the Stream object stored on an edge in the network.
        Returns:
        """
        res: Optional[Union[Stream, MC]] = None
        for u, v, a in self.graph.edges(data=True):
            if a.get('mc') and a['mc'].name == name:
                res = a['mc']
        if not res:
            raise ValueError(f"The specified name: {name} is not found on the network.")
        return res 
    def set_stream_parent(self, stream: str, parent: str):
        mc: MC = self.get_stream_by_name(stream)
        mc.set_parent_node(self.get_stream_by_name(parent))
        self._update_graph(mc)
    def set_stream_child(self, stream: str, child: str):
        mc: MC = self.get_stream_by_name(stream)
        mc.set_child_node(self.get_stream_by_name(child))
        self._update_graph(mc)
[docs]
    def reset_stream_nodes(self, stream: Optional[str] = None):
        """Reset stream nodes to break relationships
        Args:
            stream: The optional stream (edge) within the network.
              If None all streams nodes on the network will be reset.
        Returns:
        """
        if stream is None:
            streams: Dict[str, MC] = self.streams_to_dict()
            for k, v in streams.items():
                streams[k] = v.set_nodes([uuid.uuid4(), uuid.uuid4()])
            self.graph = Flowsheet(name=self.name).from_objects(objects=list(streams.values())).graph
        else:
            mc: MC = self.get_stream_by_name(stream)
            mc.set_nodes([uuid.uuid4(), uuid.uuid4()])
            self._update_graph(mc)