Spaces:
Build error
Build error
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend for server environments | |
| import networkx as nx | |
| import json | |
| import numpy as np | |
| from loguru import logger | |
| import os | |
| import tempfile | |
| from datetime import datetime | |
class DAGVisualizer:
    """Render robot task dependency graphs (DAGs) as PNG images.

    Task data is converted into a ``networkx.DiGraph``:
    - Nodes are numbered ``1..N`` in task-list order.
    - Edges point from a dependency to the task that depends on it.

    Rendered images are written to the system temporary directory, and their
    paths are returned to the caller.
    """

    # Matplotlib settings applied on construction ("IEEE-style" parameters).
    _RC_PARAMS = {
        'font.family': 'DejaVu Sans',  # Use available font instead of Times New Roman
        'font.size': 10,
        'axes.linewidth': 1.2,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 14,
    }

    def __init__(self):
        # NOTE: this updates the process-wide plt.rcParams, so it also affects
        # figures created elsewhere in the same process.
        plt.rcParams.update(self._RC_PARAMS)

    def create_dag_from_tasks(self, task_data):
        """
        Create a directed graph from task data.

        Args:
            task_data: Dictionary containing tasks with structure like:
                {
                    "tasks": [
                        {
                            "task": "task_name",
                            "instruction_function": {
                                "name": "function_name",
                                "robot_ids": ["robot1", "robot2"],
                                "dependencies": ["dependency_task"],
                                "object_keywords": ["object1", "object2"]
                            }
                        }
                    ]
                }
                ``robot_ids``, ``dependencies`` and ``object_keywords`` are
                optional; missing or null values are treated as empty lists.

        Returns:
            NetworkX DiGraph whose nodes are 1-based task IDs with attributes
            ``name``, ``function``, ``robots`` and ``objects``. Returns None
            when ``task_data`` is empty or has no ``"tasks"`` key.

        Notes:
            If two tasks share a name, dependencies on that name resolve to
            the later task. Dependencies naming an unknown task are logged
            and skipped.
        """
        if not task_data or "tasks" not in task_data:
            logger.warning("Invalid task data structure")
            return None

        tasks = task_data["tasks"]
        G = nx.DiGraph()

        # Pass 1: add every task as a node so dependencies may refer to tasks
        # that appear later in the list.
        task_mapping = {}
        for task_id, task in enumerate(tasks, start=1):
            func = task["instruction_function"]
            task_mapping[task["task"]] = task_id
            G.add_node(task_id,
                       name=task["task"],
                       function=func["name"],
                       robots=func.get("robot_ids") or [],
                       objects=func.get("object_keywords") or [])

        # Pass 2: add dependency -> dependent edges.
        for task_id, task in enumerate(tasks, start=1):
            for dep in task["instruction_function"].get("dependencies") or []:
                dep_id = task_mapping.get(dep)
                if dep_id is None:
                    logger.warning(
                        f"Task {task_id} depends on unknown task {dep!r}; dependency ignored")
                    continue
                G.add_edge(dep_id, task_id)
        return G

    @staticmethod
    def _compute_layers(G):
        """Map each node to the length of the longest predecessor chain ending at it."""
        layer_of = {}
        in_progress = set()

        def visit(node):
            if node in layer_of:
                return layer_of[node]
            if node in in_progress:
                # Back edge of a dependency cycle (or a self-dependency):
                # ignore it so the walk terminates.
                return -1
            in_progress.add(node)
            layer = max((visit(pred) + 1 for pred in G.predecessors(node)), default=0)
            in_progress.discard(node)
            # Memoising makes this O(V + E); re-walking every path is exponential.
            layer_of[node] = layer
            return layer

        for node in G.nodes():
            visit(node)
        return layer_of

    def calculate_layout(self, G, layer_height=3.0, node_width=4.0):
        """
        Calculate a top-down hierarchical layout based on dependencies.

        A node's layer is the length of the longest dependency chain leading
        to it; tasks with no predecessors are in layer 0. Layer 0 is drawn at
        the top. Within a layer, nodes are spread horizontally, centred on
        x=0, in sorted node order.

        Args:
            G: Directed graph (must provide ``nodes()`` and ``predecessors()``).
            layer_height: Vertical distance between consecutive layers.
            node_width: Horizontal distance between nodes in the same layer.

        Returns:
            dict mapping node -> (x, y); an empty dict for a None or empty graph.
        """
        if not G:
            return {}

        layers = {}
        for node, layer in self._compute_layers(G).items():
            layers.setdefault(layer, []).append(node)

        # Measure y from the deepest layer so layer 0 ends up at the top.
        top = max(layers)
        pos = {}
        for layer_idx, nodes in layers.items():
            y = layer_height * (top - layer_idx)
            start_x = -(len(nodes) - 1) * node_width / 2
            for i, node in enumerate(sorted(nodes)):
                pos[node] = (start_x + i * node_width, y)
        return pos

    @staticmethod
    def _node_color(G, node):
        """Pick a fill colour: start (no deps), end (no dependents) or intermediate."""
        if G.in_degree(node) == 0:
            return '#F24236'
        if G.out_degree(node) == 0:
            return '#A23B72'
        return '#F18F01'

    @staticmethod
    def _task_info_text(node, attrs):
        """Build the multi-line annotation shown beside a task node."""
        lines = [f"Task {node}: {attrs['function'].replace('_', ' ').title()}"]
        robots = attrs.get('robots')
        if robots:
            robot_text = ", ".join(
                str(r).replace('robot_', '').replace('_', ' ').title() for r in robots)
            lines.append(f"Robots: {robot_text}")
        objects = attrs.get('objects')
        if objects:
            lines.append(f"Objects: {', '.join(map(str, objects))}")
        # Joining (instead of appending "\n" per line) avoids a trailing blank
        # line in the text box when a section is absent.
        return "\n".join(lines)

    @staticmethod
    def _short_label(node, function_name, max_len=15):
        """Return a 'T<id>' label plus a title-cased name truncated to max_len chars."""
        name = function_name.replace('_', ' ').title()
        if len(name) > max_len:
            name = name[:max_len - 3] + "..."
        return f"T{node}\n{name}"

    @staticmethod
    def _save_figure(fig, prefix, **savefig_kwargs):
        """Save ``fig`` to a new, uniquely named PNG in the temp dir and return its path.

        mkstemp guarantees a fresh file, so two renders in the same second
        cannot overwrite each other. The timestamp is kept in the name for
        readability.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fd, image_path = tempfile.mkstemp(prefix=f"{prefix}_{timestamp}_",
                                          suffix=".png",
                                          dir=tempfile.gettempdir())
        os.close(fd)  # matplotlib reopens the file by path
        try:
            fig.savefig(image_path, **savefig_kwargs)
        except Exception:
            # Don't leave the empty placeholder created by mkstemp behind.
            try:
                os.remove(image_path)
            except OSError:
                pass
            raise
        return image_path

    def create_dag_visualization(self, task_data, title="Robot Task Dependency Graph"):
        """
        Create a detailed DAG visualization from task data and return the image path.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``).
            title: Title for the graph.

        Returns:
            str: Path to the generated PNG file, or None if there are no tasks
            or rendering failed (the error is logged).
        """
        fig = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            fig, ax = plt.subplots(1, 1, figsize=(max(12, len(G.nodes()) * 2), 8))

            # Dependency edges, slightly curved so opposing arrows stay distinguishable.
            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='#2E86AB',
                                   arrows=True,
                                   arrowsize=20,
                                   arrowstyle='->',
                                   width=2,
                                   alpha=0.8,
                                   connectionstyle="arc3,rad=0.1")

            # Nodes coloured by role: start / end / intermediate.
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color=[self._node_color(G, n) for n in G.nodes()],
                                   node_size=3500,
                                   alpha=0.9,
                                   edgecolors='black',
                                   linewidths=2)

            nx.draw_networkx_labels(G, pos, labels={n: f"T{n}" for n in G.nodes()}, ax=ax,
                                    font_size=18,
                                    font_weight='bold',
                                    font_color='white')

            # Detailed info box beside each task, linked by a dashed connector.
            for i, node in enumerate(G.nodes()):
                x, y = pos[node]
                info_text = self._task_info_text(node, G.nodes[node])

                # Alternate left/right and up/down to reduce overlapping boxes.
                offset_x = 2.2 if i % 2 == 0 else -2.2
                offset_y = 0.5 if i % 4 < 2 else -0.5
                h_align = 'left' if offset_x > 0 else 'right'

                ax.text(x + offset_x, y + offset_y, info_text,
                        bbox=dict(boxstyle="round,pad=0.4",
                                  facecolor='white',
                                  edgecolor='gray',
                                  alpha=0.95,
                                  linewidth=1),
                        fontsize=12,
                        verticalalignment='center',
                        horizontalalignment=h_align,
                        weight='bold')
                ax.plot([x, x + offset_x], [y, y + offset_y],
                        linestyle='--', color='gray', alpha=0.6, linewidth=1)

            # Expand axis limits so the info boxes beside outer nodes fit.
            x_vals = [coord[0] for coord in pos.values()]
            y_vals = [coord[1] for coord in pos.values()]
            ax.set_xlim(min(x_vals) - 4.0, max(x_vals) + 4.0)
            ax.set_ylim(min(y_vals) - 2.0, max(y_vals) + 2.0)

            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_aspect('equal')
            ax.margins(0.2)
            ax.axis('off')
            # No legend: it covered the task info boxes.

            fig.tight_layout()
            image_path = self._save_figure(fig, 'dag_visualization',
                                           dpi=400, bbox_inches='tight',
                                           pad_inches=0.1, facecolor='white',
                                           edgecolor='none')
            logger.info(f"DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            logger.exception(f"Error creating DAG visualization: {e}")
            return None
        finally:
            # Always release the figure, including on the error path.
            if fig is not None:
                plt.close(fig)

    def create_simplified_dag_visualization(self, task_data, title="Robot Task Graph"):
        """
        Create a simplified DAG visualization suitable for smaller displays.

        Args:
            task_data: Task data dictionary (see ``create_dag_from_tasks``).
            title: Title for the graph.

        Returns:
            str: Path to the generated PNG file, or None if there are no tasks
            or rendering failed (the error is logged).
        """
        fig = None
        try:
            G = self.create_dag_from_tasks(task_data)
            if not G or len(G.nodes()) == 0:
                logger.warning("No tasks found or invalid graph structure")
                return None

            pos = self.calculate_layout(G)
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))

            nx.draw_networkx_edges(G, pos, ax=ax,
                                   edge_color='black',
                                   arrows=True,
                                   arrowsize=15,
                                   arrowstyle='->',
                                   width=1.5)
            nx.draw_networkx_nodes(G, pos, ax=ax,
                                   node_color='lightblue',
                                   node_size=3000,
                                   edgecolors='black',
                                   linewidths=1.5)

            labels = {n: self._short_label(n, G.nodes[n]['function']) for n in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels=labels, ax=ax,
                                    font_size=11,
                                    font_weight='bold')

            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

            fig.tight_layout()
            image_path = self._save_figure(fig, 'simple_dag', dpi=400, bbox_inches='tight')
            logger.info(f"Simplified DAG visualization saved to: {image_path}")
            return image_path
        except Exception as e:
            logger.exception(f"Error creating simplified DAG visualization: {e}")
            return None
        finally:
            # Always release the figure, including on the error path.
            if fig is not None:
                plt.close(fig)