Spaces:
Runtime error
Runtime error
| import os, time, sys | |
import subprocess

# Handle of the background parameter download (None if the file already existed).
_params_download = None
if not os.path.isfile("RF2_apr23.pt"):
    # send param download into background so it overlaps with the installs below;
    # it is joined before the parameters are needed
    _params_download = subprocess.Popen(
        "apt-get install -y aria2; "
        "aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt",
        shell=True,
    )

if not os.path.isdir("RoseTTAFold2"):
    print("install RoseTTAFold2")
    os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
    print(os.listdir("RoseTTAFold2"))
    os.system(
        "cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install ."
    )
    os.system(
        "wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py"
    )
    # install hhsuite
    print("install hhsuite")
    os.makedirs("hhsuite", exist_ok=True)
    os.system(
        "curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/"
    )
    print(os.listdir("hhsuite"))

# Join the download before the parameters are used. Polling for the aria2 control
# file alone is racy: it does not exist yet while apt-get is still running, and it
# never appears if aria2c fails to start.
if _params_download is not None:
    print("downloading RoseTTAFold2 params")
    _params_download.wait()
elif os.path.isfile("RF2_apr23.pt.aria2"):
    # a download started by an earlier run (no handle here) is still in progress
    print("downloading RoseTTAFold2 params")
    while os.path.isfile("RF2_apr23.pt.aria2"):
        time.sleep(5)
# aria2c leaves the .aria2 control file behind when a download did not complete
if not os.path.isfile("RF2_apr23.pt") or os.path.isfile("RF2_apr23.pt.aria2"):
    raise RuntimeError(
        "RoseTTAFold2 parameters (RF2_apr23.pt) are missing or incomplete; "
        "the aria2c download did not finish"
    )

os.environ["DGLBACKEND"] = "pytorch"
sys.path.append("RoseTTAFold2/network")
if "hhsuite" not in os.environ["PATH"]:
    os.environ["PATH"] += ":hhsuite/bin:hhsuite/scripts"
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from parsers import parse_a3m | |
| from api import run_mmseqs2 | |
| import torch | |
| from string import ascii_uppercase, ascii_lowercase | |
| import hashlib, re, os | |
| import random | |
| from Bio.PDB import * | |
def get_hash(x):
    """Return the hexadecimal SHA-1 digest of ``x`` encoded as UTF-8."""
    digest = hashlib.sha1(x.encode())
    return digest.hexdigest()


# Single-letter labels: "A".."Z" followed by "a".."z" (52 entries).
alphabet_list = [*ascii_uppercase, *ascii_lowercase]
| from collections import OrderedDict, Counter | |
| import gradio as gr | |
# Build the RoseTTAFold2 predictor once. The dir() check skips construction when
# the module namespace already defines ``pred`` (presumably so that re-running
# this code, e.g. in a notebook, does not reload the weights).
if "pred" not in dir():
    from predict import Predictor

    print("compile RoseTTAFold2")
    model_params = "RF2_apr23.pt"
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        print("WARNING: using CPU")
        device = torch.device("cpu")
    pred = Predictor(model_params, device)
def get_unique_sequences(seq_list):
    """Return the distinct items of ``seq_list``, keeping first-seen order."""
    # dict keys are unique and preserve insertion order
    return list(dict.fromkeys(seq_list))
def get_msa(seq, jobname, cov=50, id=90, max_msa=2048, mode="unpaired_paired"):
    """Query MMseqs2 for an MSA of one or more chains and write ``{jobname}/msa.a3m``.

    Identical chains are collapsed before searching and re-expanded in every row,
    so each MSA row spans the whole complex with chains joined by "/".

    Args:
        seq: a single sequence (str) or a list of chain sequences.
        jobname: job directory; intermediate files are written to ``{jobname}/msa``.
        cov: minimum coverage, passed to ``hhfilter -cov``.
        id: maximum pairwise sequence identity, passed to ``hhfilter -id``.
        max_msa: unpaired rows are added only while the MSA has fewer rows than
            this; paired rows are not capped.
        mode: "unpaired", "paired" or "unpaired_paired". Paired search is only
            done when there is more than one distinct chain.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    import subprocess

    if mode not in ("unpaired", "paired", "unpaired_paired"):
        raise ValueError(f"unknown MSA mode {mode!r}")

    def _hhfilter(src, dst):
        # Argument list instead of a shell string: the paths contain the
        # user-supplied job name, which must never be interpreted by a shell.
        # Like the former os.system call, a non-zero exit is not checked here;
        # a missing output file surfaces when it is opened below.
        subprocess.run(
            ["hhfilter", "-i", src, "-id", str(id), "-cov", str(cov), "-o", dst],
            check=False,
        )

    seqs = [seq] if isinstance(seq, str) else seq
    # collapse homooligomeric sequences
    counts = Counter(seqs)
    u_seqs = list(counts.keys())
    u_nums = list(counts.values())
    # first row is the query itself, with homooligomer copies expanded
    first_seq = "/".join(x for x, n in zip(u_seqs, u_nums) for _ in range(n))
    msa = [first_seq]

    path = os.path.join(jobname, "msa")
    os.makedirs(path, exist_ok=True)

    if mode in ("paired", "unpaired_paired") and len(u_seqs) > 1:
        print("getting paired MSA")
        out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
        # sequences[k] collects the k-th hit's aligned sequence from each chain's a3m
        headers, sequences = [], []
        for a3m_lines in out_paired:
            n = -1
            for line in a3m_lines.split("\n"):
                if not line:
                    continue
                if line.startswith(">"):
                    n += 1
                    if len(headers) < n + 1:
                        headers.append([])
                        sequences.append([])
                    headers[n].append(line)
                else:
                    sequences[n].append(line)
        # filter MSA
        with open(f"{path}/paired_in.a3m", "w") as handle:
            for n, sequence in enumerate(sequences):
                handle.write(f">n{n}\n{''.join(sequence)}\n")
        _hhfilter(f"{path}/paired_in.a3m", f"{path}/paired_out.a3m")
        with open(f"{path}/paired_out.a3m", "r") as handle:
            for line in handle:
                if line.startswith(">"):
                    # headers were written as ">n{index}" above
                    n = int(line[2:])
                    # expand homooligomeric sequences
                    xs = ["/".join([x] * num) for x, num in zip(sequences[n], u_nums)]
                    msa.append("/".join(xs))

    if len(msa) < max_msa and (
        mode in ("unpaired", "unpaired_paired") or len(u_seqs) == 1
    ):
        print("getting unpaired MSA")
        out = run_mmseqs2(u_seqs, f"{path}/")
        lengths = [len(s) for s in u_seqs]
        # sub_msa[n] holds the filtered rows found for chain n
        sub_msa = []
        for n, a3m_lines in enumerate(out):
            rows = []
            with open(f"{path}/in_{n}.a3m", "w") as handle:
                handle.write(a3m_lines)
            # filter
            _hhfilter(f"{path}/in_{n}.a3m", f"{path}/out_{n}.a3m")
            with open(f"{path}/out_{n}.a3m", "r") as handle:
                for line in handle:
                    if not line.startswith(">"):
                        # gap-fill every other chain so the row spans the complex
                        xs = ["-" * length for length in lengths]
                        xs[n] = line.rstrip()
                        # expand homooligomeric sequences
                        xs = ["/".join([x] * num) for x, num in zip(xs, u_nums)]
                        rows.append("/".join(xs))
            sub_msa.append(rows)

        # Round-robin across chains so each contributes equally until max_msa
        # rows are reached or every chain's hits are used up.
        cursors = [0] * len(sub_msa)
        remaining = sum(len(rows) for rows in sub_msa)
        while len(msa) < max_msa and remaining > 0:
            for n, rows in enumerate(sub_msa):
                if cursors[n] < len(rows):
                    msa.append(rows[cursors[n]])
                    cursors[n] += 1
                    remaining -= 1
                if len(msa) >= max_msa:
                    break

    with open(f"{jobname}/msa.a3m", "w") as handle:
        for n, sequence in enumerate(msa):
            handle.write(f">n{n}\n{sequence}\n")
| from Bio.PDB.PDBExceptions import PDBConstructionWarning | |
| import warnings | |
| from Bio.PDB import * | |
| import numpy as np | |
def add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname):
    """Convert the best RF2 model to mmCIF and append ModelCIF pLDDT records.

    The PDBe Mol* viewer (``alphafoldView``) colours by pLDDT only when the
    scores are present as ``ma_qa_metric`` tables in an mmCIF file.

    Args:
        best_plddts: per-residue lDDT values of the best model, in model residue
            order across all chains; assumed to be on a 0-1 scale (they are
            multiplied by 100 here) — TODO confirm against the RF2 .npz output.
        best_plddt: mean of ``best_plddts`` (same 0-1 scale).
        best_seed: seed of the best model; selects ``rf2_seed{seed}_00_pred.pdb``.
        jobname: job directory holding the model files.

    Writes ``{jobname}/rf2_seed{best_seed}_00_pred.cif``; returns None.

    Raises:
        ValueError: if ``best_plddts`` is empty.
    """
    pdb_path = f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
    cif_path = f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
    n_scores = len(best_plddts)
    if n_scores == 0:
        raise ValueError("best_plddts is empty")

    # Silence Biopython's construction warnings only while parsing this file
    # instead of changing the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", PDBConstructionWarning)
        structure = PDBParser().get_structure("pdb", pdb_path)

    cif_io = MMCIFIO()
    cif_io.set_structure(structure)
    cif_io.save(cif_path)

    # ModelCIF type "pLDDT" is on a 0-100 scale, so the global value is scaled
    # the same way as the local values below.
    plddt_cif = f"""#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT
2 local pLDDT 1 pLDDT
#
_ma_qa_metric_global.metric_id 1
_ma_qa_metric_global.metric_value {best_plddt * 100:.2f}
_ma_qa_metric_global.model_id 1
_ma_qa_metric_global.ordinal_id 1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id"""

    rows = []
    ordinal = 0
    for chain in structure[0]:
        for residue in chain:
            # Scores are indexed over all residues of the model, not per chain
            # (a per-chain index gave every later chain the first chain's scores).
            # If there are fewer scores than residues — presumably only the
            # asymmetric unit is scored in symmetric runs; TODO confirm — they
            # are tiled across the copies.
            score = best_plddts[ordinal % n_scores] * 100
            ordinal += 1
            rows.append(
                f"{chain.id} {residue.resname} {residue.id[1]} 2 {score:.2f} 1 {ordinal}"
            )
    plddt_cif += "".join("\n" + row for row in rows) + "\n#"

    with open(cif_path, "a") as f:
        f.write(plddt_cif)
def predict(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
    mode="web",
):
    """Run RoseTTAFold2 on ``sequence`` and return the path of the best model.

    Chains are separated by ":" or "/". ``sym``/``order`` select the symmetry
    (X/C: ``order`` copies, D: 2*``order`` copies, T/O/I: 12/24/60 copies). One
    model is predicted per seed in ``[random_seed, random_seed + num_models)``;
    the model with the highest mean lDDT is kept.

    Returns:
        ``{job}/rf2_seed{seed}_00_pred.cif`` when ``mode == "web"`` (the CIF is
        written here with pLDDT records for Mol*), otherwise the ``.pdb`` path.

    Raises:
        gr.Error: for inputs that cannot be run (too long on Spaces, no
            residues, unknown symmetry, order < 1, num_models < 1, unknown
            MSA method).
    """
    if os.path.exists("/home/user/app"):  # crude check if on spaces
        if len(sequence) > 600:
            raise gr.Error(
                f"Your sequence is too long ({len(sequence)}). "
                "Please use the full version of RoseTTAfold2 directly from GitHub."
            )
    # Gradio delivers these as strings (dropdowns/textboxes).
    random_seed = int(random_seed)
    num_models = int(num_models)
    max_msa = int(max_msa)
    num_recycles = int(num_recycles)
    order = int(order)
    if num_models < 1:
        raise gr.Error("num_models must be at least 1.")
    # max_msa rows are passed as nseqs; 8x as many are fetched for nseqs_full.
    max_extra_msa = max_msa * 8

    print("sequence", sequence)
    # Keep only residue letters and chain breaks, collapse repeated breaks and
    # trim breaks at either end.
    sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
    sequence = re.sub(":+", ":", sequence).strip(":")
    print("sequence", sequence)
    if not sequence:
        raise gr.Error("The sequence contains no amino acids.")

    if sym in ("X", "C"):
        copies = order
    elif sym == "D":
        copies = order * 2
    elif sym in ("T", "O", "I"):
        copies = {"T": 12, "O": 24, "I": 60}[sym]
        order = ""
    else:
        raise gr.Error(f"Unknown symmetry {sym!r}; use one of X, C, D, T, O, I.")
    if copies < 1:
        raise gr.Error("order must be at least 1.")
    symm = sym + str(order)

    # The job name comes from the web form and ends up in file paths, tool
    # arguments and the viewer HTML; restrict it to a conservative character set.
    jobname = re.sub(r"[^A-Za-z0-9_.-]", "_", str(jobname))

    sequences = sequence.split(":")
    u_sequences = get_unique_sequences(sequences) if collapse_identical else sequences
    sequences = u_sequences * copies
    lengths = [len(s) for s in sequences]
    # TODO: cropping heuristic — complexes over 1400 residues use subcrop=1000
    subcrop = 1000 if sum(lengths) > 1400 else -1
    sequence = "/".join(sequences)
    jobname = jobname + "_" + symm + "_" + get_hash(sequence)[:5]
    print(f"jobname: {jobname}")
    print(f"lengths: {lengths}")
    print("final_sequence", u_sequences)

    os.makedirs(jobname, exist_ok=True)
    if msa_method == "mmseqs2":
        get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
    elif msa_method == "single_sequence":
        u_sequence = "/".join(u_sequences)
        with open(f"{jobname}/msa.a3m", "w") as a3m:
            a3m.write(f">{jobname}\n{u_sequence}\n")
    else:
        # uploading a custom a3m is not supported in this app
        raise gr.Error(f"Unknown msa_method {msa_method!r}.")

    mlm = 0.15 if use_mlm else 0
    best_plddt = best_plddts = best_seed = None
    for seed in range(random_seed, random_seed + num_models):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        print("MLM", mlm, use_mlm)
        pred.predict(
            inputs=[f"{jobname}/msa.a3m"],
            out_prefix=f"{jobname}/rf2_seed{seed}",
            symm=symm,
            ffdb=None,  # TODO (templates)
            n_recycles=num_recycles,
            msa_mask=mlm,
            msa_concat_mode=msa_concat_mode,
            nseqs=max_msa,
            nseqs_full=max_extra_msa,
            subcrop=subcrop,
            is_training=use_dropout,
        )
        # Load the scores once and close the archive.
        with np.load(f"{jobname}/rf2_seed{seed}_00.npz") as npz:
            plddts = npz["lddt"]
        plddt = plddts.mean()
        if best_plddt is None or plddt > best_plddt:
            best_plddt, best_plddts, best_seed = plddt, plddts, seed

    if mode == "web":
        # Mol* only displays AlphaFold plDDT if they are in a cif.
        add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname)
        return f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
    # for api just return a pdb file
    return f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
def predict_api(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """API endpoint: run ``predict`` in "api" mode and return the PDB file's text.

    Takes the same arguments as ``predict`` (without ``mode``).
    """
    settings = (
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
    )
    pdb_path = predict(*settings, mode="api")
    with open(f"{pdb_path}") as pdb_file:
        pdb_text = pdb_file.read()
    return pdb_text
def molecule(input_pdb, public_link):
    """Return an ``<iframe>`` that shows ``input_pdb`` in the PDBe Mol* viewer.

    Args:
        input_pdb: path of the mmCIF file to display, served by the app under
            ``{public_link}/file=<path>``. A download button also links the
            sibling file with the ``.pdb`` extension.
        public_link: base URL of the running app.

    The path is percent-encoded before it goes into the page, and the page is
    HTML-escaped for the ``srcdoc`` attribute, so characters in the path cannot
    break out of the markup or the JavaScript string.
    """
    import html
    from urllib.parse import quote

    print(input_pdb)
    link = public_link + "/file=" + quote(input_pdb, safe="/")
    print(link)
    # Swap only the extension (a plain replace would also rewrite ".cif"
    # appearing elsewhere in the path).
    if link.endswith(".cif"):
        pdb_link = link[: -len(".cif")] + ".pdb"
    else:
        pdb_link = link

    page = (
        """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
<title>PDBe Molstar - Helper functions</title>
<!-- Molstar CSS & JS -->
<link rel="stylesheet" type="text/css" href="https://www.ebi.ac.uk/pdbe/pdb-component-library/css/pdbe-molstar-light-3.1.0.css">
<script type="text/javascript" src="https://www.ebi.ac.uk/pdbe/pdb-component-library/js/pdbe-molstar-plugin-3.1.0.js"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.msp-plugin ::-webkit-scrollbar-thumb {
background-color: #474748 !important;
}
.viewerSection {
margin: 120px 0 0 0px;
}
#myViewer{
float:left;
width:100%;
height: 800px;
position:relative;
}
.btn{
font-family: "Open Sans", sans-serif;
display: inline-block;
outline: none;
cursor: pointer;
font-weight: 600;
border-radius: 3px;
padding: 12px 24px;
border: 0;
margin:0 10px;
line-height: 1.15;
font-size: 16px;
text-decoration: none;
}
.btn-orange{
background: #ff5000;
color: #fff;
}
.btn-gray{
color: #3a4149;
background: #e7ebee;
}
.btn:hover{
transition: all .1s ease;
box-shadow: 0 0 0 0 #fff, 0 0 0 3px #ddd;}
.text-center{
display: flex;
align-items: center;
justify-content: center;
padding: 20px 0;
}
.flex{
padding: 10px;
display: flex;
align-items: center;
justify-content: center;
width:fit-content;
}
.flex svg{
margin-right: 10px;
width:16px;
height:16px;
}
.flex a{
margin:0 10px;
}
</style>
</head>
<body>
<div class="text-center">
<a class="btn btn-orange flex" href=\""""
        + link
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>CIF File</span></a>
<a class="btn btn-gray flex" href=\""""
        + pdb_link
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>PDB File</span></a>
</div>
<div class="viewerSection">
<!-- Molstar container -->
<div id="myViewer"></div>
</div>
<script>
//Create plugin instance
var viewerInstance = new PDBeMolstarPlugin();
//Set options (Checkout available options list in the documentation)
var options = {
customData: {
url: \""""
        + link
        + """\",
format: "cif"
},
alphafoldView: true,
bgColor: {r:255, g:255, b:255},
//hideCanvasControls: ["selection", "animation", "controlToggle", "controlInfo"]
}
//Get element from HTML/Template to place the viewer
var viewerContainer = document.getElementById("myViewer");
//Call render method to display the 3D view
viewerInstance.render(viewerContainer, options);
</script>
</body>
</html>"""
    )
    # srcdoc is a single-quoted attribute; the browser decodes the character
    # references back into the original page before rendering it.
    srcdoc = html.escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 1000px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict_web(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """Handler for the "Run" button: predict, then embed the model in Mol*.

    Takes the same arguments as ``predict`` (without ``mode``) and returns the
    HTML built by ``molecule`` for the best model's CIF file.
    """
    if os.path.exists("/home/user/app"):  # crude check if on spaces
        # Hugging Face sets SPACE_HOST to this Space's own host name, so a
        # duplicated Space links to its own files; fall back to the original.
        host = os.environ.get("SPACE_HOST", "simonduerr-rosettafold2.hf.space")
        public_link = "https://" + host
    else:
        # Gradio's launch() honours GRADIO_SERVER_PORT; 7860 is its default.
        port = os.environ.get("GRADIO_SERVER_PORT", "7860")
        public_link = "http://localhost:" + port
    filename = predict(
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
        mode="web",
    )
    return molecule(filename, public_link)
# Gradio UI: a hidden button exposes predict_api as the "rosettafold2" API
# endpoint; the visible "Run" button renders the result in the Mol* viewer.
with gr.Blocks() as rosettafold:
    gr.Markdown("# RoseTTAFold2")
    gr.Markdown(
        """If using please cite: [manuscript](https://www.biorxiv.org/content/10.1101/2023.05.24.542179v1)
<br> Heavily based on [RoseTTAFold2 ColabFold notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold2.ipynb)"""
    )
    with gr.Accordion("How to use in PyMol", open=False):
        gr.HTML(
            """<code>os.system('wget https://huggingface.co/spaces/simonduerr/rosettafold2/raw/main/rosettafold_pymol.py') <br>
run rosettafold_pymol.py <br>
rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models] <br>
color_plddt jobname</code>
"""
        )

    sequence = gr.Textbox(
        label="sequence",
        value="PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK",
    )
    jobname = gr.Textbox(label="jobname", value="test")

    with gr.Accordion("Additional settings", open=False):
        sym = gr.Textbox(label="sym", value="X")
        order = gr.Slider(label="order", value=1, step=1, minimum=1, maximum=12)
        msa_concat_mode = gr.Dropdown(
            label="msa_concat_mode",
            value="default",
            choices=["diag", "repeat", "default"],
        )
        # custom a3m upload is not offered for now
        msa_method = gr.Dropdown(
            label="msa_method",
            value="single_sequence",
            choices=["mmseqs2", "single_sequence"],
        )
        pair_mode = gr.Dropdown(
            label="pair_mode",
            value="unpaired_paired",
            choices=["unpaired_paired", "paired", "unpaired"],
        )
        num_recycles = gr.Dropdown(
            label="num_recycles", value="6", choices=["0", "1", "3", "6", "12", "24"]
        )
        use_mlm = gr.Checkbox(label="use_mlm", value=False)
        use_dropout = gr.Checkbox(label="use_dropout", value=False)
        collapse_identical = gr.Checkbox(label="collapse_identical", value=False)
        max_msa = gr.Dropdown(
            choices=["16", "32", "64", "128", "256", "512"],
            value="16",
            label="max_msa",
        )
        random_seed = gr.Textbox(label="random_seed", value=0)
        num_models = gr.Dropdown(
            label="num_models", value="1", choices=["1", "2", "4", "8", "16", "32"]
        )

    btn = gr.Button("Run", visible=False)
    btn_web = gr.Button("Run")
    output_plain = gr.HTML()
    output = gr.HTML()

    # Both handlers take the same inputs, in predict()'s parameter order.
    prediction_inputs = [
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
    ]
    btn.click(
        fn=predict_api,
        inputs=prediction_inputs,
        outputs=output_plain,
        api_name="rosettafold2",
    )
    btn_web.click(fn=predict_web, inputs=prediction_inputs, outputs=output)

rosettafold.launch()