Spaces:
Running
Running
import logging
import os
import tempfile
import time

import gradio as gr
import torch
from ase.io import write as ase_write
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifWriter
from slices.core import SLICES

from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
# Set the number of threads PyTorch uses; the UI advertises running on 2 CPUs.
TORCH_NUM_THREADS = 2
torch.set_num_threads(TORCH_NUM_THREADS)
def load_quantized_model(model_path):
    """Load a MatterGPT checkpoint and return a dynamically quantized version.

    The checkpoint is loaded with ``MatterGPTWrapper.from_pretrained``, moved to
    the CPU and switched to inference mode. It is then passed through
    ``torch.quantization.quantize_dynamic`` so that every ``torch.nn.Linear``
    layer uses ``torch.qint8`` weights.

    Args:
        model_path: Directory/identifier accepted by ``from_pretrained``.

    Returns:
        The quantized module produced by ``quantize_dynamic``.
    """
    base_model = MatterGPTWrapper.from_pretrained(model_path)
    # Dynamic int8 quantization targets CPU inference, so prepare the float
    # model for that setting before converting it.
    base_model.to('cpu')
    base_model.eval()
    layers_to_quantize = {torch.nn.Linear}
    return torch.quantization.quantize_dynamic(
        base_model, layers_to_quantize, dtype=torch.qint8
    )
# Locations of the model checkpoint and the tokenizer vocabulary, relative to
# the working directory.
model_path = "./"
tokenizer_path = "Voc_prior"

# Build the quantized model once at import time; every request reuses it.
# The .to()/.eval() calls keep it on CPU in inference mode (load_quantized_model
# already prepares the float model the same way; repeated here defensively).
quantized_model = load_quantized_model(model_path)
quantized_model.to("cpu")
quantized_model.eval()

# Tokenizer mapping between SLICES tokens and integer ids.
tokenizer = SimpleTokenizer(tokenizer_path)
# Initialize the SLICES backend, preferring relax_model="chgnet" (fmax=0.4,
# steps=25). If that cannot be constructed, fall back to relax_model=None so the
# app still starts. Log the reason so the degraded mode is visible in the logs
# instead of being silently swallowed.
try:
    backend = SLICES(relax_model="chgnet", fmax=0.4, steps=25)
except Exception as e:
    logging.getLogger(__name__).warning(
        "SLICES backend with relax_model='chgnet' unavailable (%s: %s); "
        "falling back to relax_model=None",
        type(e).__name__,
        e,
    )
    backend = SLICES(relax_model=None)
def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
    """Sample one SLICES string conditioned on target material properties.

    The prompt is the single start token ``'>'``. The property condition is a
    ``(1, 2)`` float32 tensor ``[[formation_energy, band_gap]]``. Generation runs
    without autograd, and only the first returned sequence is decoded.

    Args:
        quantized_model: Model exposing ``generate(x, prop=..., max_length=...,
            temperature=..., do_sample=..., top_k=..., top_p=...)``.
        tokenizer: Object with a ``stoi`` mapping and a ``decode(list[int])`` method.
        formation_energy: Target formation energy; anything ``float()`` accepts.
        band_gap: Target band gap; anything ``float()`` accepts.
        max_length, temperature, do_sample, top_k, top_p: Forwarded unchanged
            to ``quantized_model.generate``.

    Returns:
        The decoded string for the first generated sequence.
    """
    targets = [float(formation_energy), float(band_gap)]
    prop = torch.tensor([targets], dtype=torch.float32)

    start_id = tokenizer.stoi['>']
    prompt = torch.tensor([[start_id]], dtype=torch.long)

    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
    }
    with torch.no_grad():
        token_ids = quantized_model.generate(prompt, prop=prop, **sampling_options)

    first_sequence = token_ids[0].tolist()
    return tokenizer.decode(first_sequence)
def generate_slices(formation_energy, band_gap, *, temperature=1.2, do_sample=True, top_k=0, top_p=0.9):
    """Generate one SLICES string for the target properties using the app's model.

    Uses the module-level ``quantized_model`` and ``tokenizer``. The generation
    length is set to the model's ``config.block_size``. The sampling settings
    used to be hard-coded positional constants. They are now keyword-only
    parameters with the same defaults, so existing two-argument calls behave
    exactly as before.

    Args:
        formation_energy: Target formation energy.
        band_gap: Target band gap.
        temperature: Sampling temperature forwarded to ``generate``.
        do_sample: Whether ``generate`` samples rather than decoding greedily.
        top_k: Forwarded to ``generate``. NOTE(review): 0 presumably disables
            top-k filtering; confirm against ``MatterGPTWrapper.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The decoded SLICES string.
    """
    return generate_slices_quantized(
        quantized_model,
        tokenizer,
        formation_energy,
        band_gap,
        max_length=quantized_model.config.block_size,
        temperature=temperature,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
    )
def wrap_structure(structure, tol=1e-8):
    """Wrap all atoms back into the unit cell, in place.

    Every site's fractional coordinates are mapped into ``[0, 1)`` and written
    back with ``structure.replace``. Species and site properties are preserved.

    Args:
        structure: A pymatgen ``Structure``; it is modified in place.
        tol: Fractional coordinates within ``tol`` of 1.0 are snapped to 0.0.
            These are periodic images of the same position.

    Returns:
        The same ``structure`` object, for call chaining.
    """
    for i, site in enumerate(structure):
        frac_coords = site.frac_coords % 1.0
        # ``x % 1.0`` returns exactly 1.0 for tiny negative x (e.g. -1e-17),
        # because 1.0 - 1e-17 rounds to 1.0. That would leave the atom on the
        # far cell face instead of inside [0, 1); snap such values (and any
        # within ``tol`` of 1.0) to the equivalent 0.0.
        frac_coords[frac_coords >= 1.0 - tol] = 0.0
        # Pass properties explicitly: Structure.replace builds a new site and
        # would otherwise drop site properties (e.g. magnetic moments).
        structure.replace(
            i,
            species=site.species,
            coords=frac_coords,
            coords_are_cartesian=False,
            properties=site.properties,
        )
    return structure
def _new_temp_path(suffix):
    """Create an empty temporary file with ``suffix`` and return its path, handle closed.

    ``delete=False`` keeps the file on disk after the handle is closed so it can be
    served later. Closing the handle immediately avoids leaking a file
    descriptor per request and lets other writers reopen the path on any
    platform.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _structure_summary(structure):
    """Return a multi-line text summary of a structure's composition and cell."""
    lattice = structure.lattice
    lines = [
        f"Formula: {structure.composition.reduced_formula}",
        f"Number of sites: {len(structure)}",
        f"Lattice parameters: a={lattice.a:.3f}, b={lattice.b:.3f}, c={lattice.c:.3f}",
        f"Angles: alpha={lattice.alpha:.2f}, beta={lattice.beta:.2f}, gamma={lattice.gamma:.2f}",
        # The unit was previously mis-encoded as "ų"; the volume is in Å³.
        f"Volume: {structure.volume:.3f} Å³",
        f"Density: {structure.density:.3f} g/cm³",
    ]
    return "\n".join(lines)


def convert_and_visualize(slices_string):
    """Decode a SLICES string into a crystal and produce downloadable artifacts.

    The string is decoded with the module-level ``backend.SLICES2structure``,
    which returns a structure and an energy. Sites are wrapped into the unit
    cell, a CIF file and a PNG rendering are written to temporary files, and a
    text summary is built.

    Args:
        slices_string: SLICES string to decode.

    Returns:
        A 5-tuple ``(cif_path, png_path, summary, status, success)``.

        On success, the CIF and PNG files are left on disk for the caller to
        serve. On any exception, the path and summary entries are ``""``,
        ``success`` is ``False`` and ``status`` carries the error message. Any
        temporary files already created are removed.
    """
    created_paths = []
    try:
        structure, energy = backend.SLICES2structure(slices_string)
        # Wrap atoms back into the unit cell
        structure = wrap_structure(structure)

        cif_path = _new_temp_path('.cif')
        created_paths.append(cif_path)
        CifWriter(structure).write_file(cif_path)

        summary = _structure_summary(structure)

        # Render the structure with ASE to a PNG file.
        atoms = AseAtomsAdaptor.get_atoms(structure)
        image_path = _new_temp_path('.png')
        created_paths.append(image_path)
        ase_write(image_path, atoms, format='png', rotation='10x,10y,10z')

        return cif_path, image_path, summary, f"Conversion successful. Energy: {energy:.4f} eV/atom", True
    except Exception as e:
        # UI boundary: report the failure instead of raising. Best-effort
        # cleanup so failed attempts do not accumulate files in the temp dir.
        for path in created_paths:
            try:
                os.remove(path)
            except OSError:
                pass
        return "", "", "", f"Conversion failed. Error: {str(e)}", False
def generate_and_convert(formation_energy, band_gap, *, max_attempts=5, max_time=300):
    """Generate SLICES strings until one decodes into a crystal structure.

    Each attempt samples a new string with ``generate_slices`` and decodes it
    with ``convert_and_visualize``. The first successful attempt is returned.
    Otherwise the last attempt's SLICES string and error status are reported.

    Args:
        formation_energy: Target formation energy (eV/atom).
        band_gap: Target band gap (eV).
        max_attempts: Maximum number of generate-and-decode attempts (keyword-only).
        max_time: Time budget in seconds (keyword-only). It is checked before
            each attempt, so an attempt already running is not interrupted.

    Returns:
        A 5-tuple ``(slices_string, cif_path, image_path, summary, status)``. The
        path and summary entries are ``""`` when no attempt succeeded.
    """
    if max_attempts < 1:
        return "Failed to generate valid SLICES string", "", "", "", "Generation failed"

    # monotonic() measures elapsed time reliably even if the system clock is
    # adjusted, unlike time.time().
    start_time = time.monotonic()
    for attempt in range(1, max_attempts + 1):
        if time.monotonic() - start_time > max_time:
            return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"
        slices_string = generate_slices(formation_energy, band_gap)
        cif_file, image_file, summary, status, success = convert_and_visualize(slices_string)
        if success:
            return slices_string, cif_file, image_file, summary, f"Successful on attempt {attempt}: {status}"

    # Every attempt failed: report the last generated string and its error.
    return slices_string, "", "", "", f"Failed after {max_attempts} attempts: {status}"
# Create the Gradio interface.
# Layout: a header row with an overview figure and usage notes, then a row with
# the property inputs (left column) and the generated outputs (right column).
with gr.Blocks() as iface:
    gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")
    with gr.Row():
        with gr.Column():
            gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
            gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
            gr.Markdown("**Allow 1-2 minutes for completion using 2 CPUs.**")
    with gr.Row():
        # Inputs: target properties and the trigger button.
        with gr.Column(scale=2):
            band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
            formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
            generate_button = gr.Button("Generate")
        # Outputs, filled from generate_and_convert's return tuple.
        with gr.Column(scale=3):
            slices_output = gr.Textbox(label="Generated SLICES String")
            cif_output = gr.File(label="Download CIF", file_types=[".cif"])
            structure_image = gr.Image(label="Structure Visualization")
            structure_summary = gr.Textbox(label="Structure Summary", lines=6)
            conversion_status = gr.Textbox(label="Conversion Status")
    # Inputs are passed positionally in this order (formation energy first,
    # although band gap is displayed first). The five outputs receive the five
    # returned values in order.
    generate_button.click(
        generate_and_convert,
        inputs=[formation_energy, band_gap],
        outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
    )
# Start the web server only when this file is executed as a script. Importing
# the module (e.g. to reuse its functions) then does not block in a server loop.
# NOTE(review): share=True requests a public tunnel link when run locally;
# hosted platforms such as Hugging Face Spaces may ignore it — confirm intent.
if __name__ == "__main__":
    iface.launch(share=True)