| | import os |
| | import json |
| |
|
def create_graph(lora_path, lora_name):
    """Render the learning-rate / loss history of a training run to a PNG.

    Reads ``training_graph.json`` from *lora_path*. Each entry of the JSON
    list is read for its ``'epoch'``, ``'learning_rate'`` and ``'loss'`` keys.
    Writes ``training_graph.png`` next to it. Learning rate is drawn on the
    left (blue) axis and loss on a twinned right (red) axis, sharing the
    epoch x-axis.

    If matplotlib is not installed, or the JSON file is missing, a message is
    printed and nothing is written.

    Args:
        lora_path: Directory containing ``training_graph.json``; the PNG is
            written into the same directory.
        lora_name: Name shown in the plot title.

    Returns:
        None.

    Raises:
        KeyError: If a JSON entry lacks one of the keys listed above.
        json.JSONDecodeError: If the JSON file is malformed.
    """
    # matplotlib is an optional dependency; only its import is guarded so that
    # unrelated errors later on are not misreported as "not installed".
    try:
        # Use the object-oriented Figure API instead of pyplot: the figure is
        # not registered in pyplot's global figure manager, so it is released
        # when it goes out of scope (no leak across repeated calls) and it does
        # not depend on an interactive/GUI backend or the main thread.
        from matplotlib.figure import Figure
        from matplotlib.ticker import ScalarFormatter
    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
        return

    peft_model_path = os.path.join(lora_path, 'training_graph.json')
    image_model_path = os.path.join(lora_path, 'training_graph.png')

    if not os.path.exists(peft_model_path):
        print(f"File 'training_graph.json' does not exist in the {lora_path}")
        return

    with open(peft_model_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    x = [item['epoch'] for item in data]
    y1 = [item['learning_rate'] for item in data]
    y2 = [item['loss'] for item in data]

    fig = Figure(figsize=(10, 6))
    ax1 = fig.subplots()

    # Left axis: learning rate.
    ax1.plot(x, y1, 'b-', label='Learning Rate')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Learning Rate', color='b')
    ax1.tick_params('y', colors='b')

    # Right axis: loss, sharing the same epoch x-axis.
    ax2 = ax1.twinx()
    ax2.plot(x, y2, 'r-', label='Loss')
    ax2.set_ylabel('Loss', color='r')
    ax2.tick_params('y', colors='r')

    # Learning rates are typically tiny; force scientific notation with a
    # single shared exponent on the left axis.
    ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    ax1.grid(True)

    # Merge both axes' legend entries into one legend; it is drawn on ax2 so
    # it is rendered above the loss line.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax1.set_title(f'{lora_name} LR and Loss vs Epoch')

    fig.savefig(image_model_path)

    print(f"Graph saved in {image_model_path}")