File size: 7,581 Bytes
bc21134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#!/usr/bin/env python3
"""
Decode all 80 generated sequences and test them with HMD-AMP.
"""

import torch
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
import os
from datetime import datetime
from tqdm import tqdm
import sys

# Import the decoder
from final_sequence_decoder import EmbeddingToSequenceConverter

# Import HMD-AMP components
sys.path.append('/home/edwardsun/flow/HMD-AMP')
from sklearn.utils import shuffle
import esm
from deepforest import CascadeForestClassifier
from src.utils import *

def load_generated_embeddings(base_path='/data2/edwardsun/generated_samples',
                              date='20250829'):
    """Load the generated embeddings for every CFG configuration.

    For each configuration name in ``no_cfg``, ``weak_cfg``, ``strong_cfg``
    and ``very_strong_cfg`` the file
    ``generated_amps_best_model_<cfg>_<date>.pt`` inside ``base_path`` is
    loaded onto the CPU and split along its first dimension into individual
    embeddings.

    Args:
        base_path: Directory that holds the ``.pt`` files. Defaults to the
            location the script was originally written for.
        date: Date stamp (as it appears in the file names) selecting which
            generation run to load. Defaults to the originally hard-coded
            run, ``'20250829'``.

    Returns:
        A tuple ``(embeddings, labels)`` of two equally long lists. Each
        label is ``"<cfg>_<k>"`` where ``k`` is the 1-based position of the
        embedding within its file. Files that do not exist contribute
        nothing (a warning is printed), so both lists may be empty.
    """
    cfg_types = ('no_cfg', 'weak_cfg', 'strong_cfg', 'very_strong_cfg')

    all_embeddings = []
    all_labels = []

    for cfg_type in cfg_types:
        # Building the name from the config type means the label never has to
        # be recovered from the file name by substring matching.
        file = f'generated_amps_best_model_{cfg_type}_{date}.pt'
        file_path = os.path.join(base_path, file)
        if not os.path.exists(file_path):
            # Make a missing run visible instead of silently loading fewer
            # embeddings than expected.
            print(f"Warning: {file_path} not found, skipping")
            continue

        print(f"Loading {file}...")
        # NOTE(review): torch.load unpickles the file; only point this at
        # files from a trusted source.
        embeddings = torch.load(file_path, map_location='cpu')

        # One entry per item along the first dimension of the loaded object.
        for position, embedding in enumerate(embeddings, start=1):
            all_embeddings.append(embedding)
            all_labels.append(f"{cfg_type}_{position}")

    print(f"βœ“ Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels

def decode_embeddings_to_sequences(embeddings, labels):
    """Decode each embedding into a sequence and build a matching ID.

    Args:
        embeddings: Iterable of embeddings accepted by
            ``EmbeddingToSequenceConverter.embedding_to_sequence``; must
            support ``len()`` for the progress bar.
        labels: Labels paired positionally with ``embeddings``.

    Returns:
        A tuple ``(sequences, sequence_ids)``. ``sequences`` holds whatever
        the decoder returns for each embedding, in input order, and each ID
        is ``"generated_seq_<k>_<label>"`` with ``k`` the 1-based position.
    """
    print("Initializing sequence decoder...")
    # Use the GPU when present, otherwise fall back to the CPU so the script
    # does not fail on CUDA-less machines (same rule the HMD-AMP test step in
    # this file applies).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    decoder = EmbeddingToSequenceConverter(device=device)

    sequences = []
    sequence_ids = []

    print("Decoding embeddings to sequences...")
    progress = tqdm(zip(embeddings, labels), total=len(embeddings))
    for position, (embedding, label) in enumerate(progress, start=1):
        # The 'diverse' method with temperature 0.8 is the decoder setting
        # this script uses; their semantics are defined by the decoder module.
        sequence = decoder.embedding_to_sequence(
            embedding,
            method='diverse',
            temperature=0.8
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{position}_{label}")

    return sequences, sequence_ids

def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Write the sequences to ``filename`` in FASTA format.

    IDs and sequences are paired positionally; every record is written with
    an empty description. A confirmation line is printed afterwards.
    """
    fasta_records = [
        SeqRecord(Seq(residues), id=identifier, description="")
        for identifier, residues in zip(sequence_ids, sequences)
    ]
    SeqIO.write(fasta_records, filename, "fasta")

    print(f"βœ“ Saved {len(sequences)} sequences to {filename}")

def _cfg_type_from_id(seq_id):
    """Return the CFG configuration name embedded in a sequence ID.

    IDs produced by this script look like ``generated_seq_<n>_<cfg>_<m>``
    where ``<cfg>`` itself contains underscores (``no_cfg``,
    ``very_strong_cfg``, ...), so picking a single ``_``-separated token
    cannot recover it. For other ID shapes the trailing ``_<m>`` token is
    dropped, and an ID without underscores is returned unchanged.
    """
    parts = seq_id.split('_')
    if len(parts) >= 5 and parts[:2] == ['generated', 'seq']:
        # Drop "generated", "seq", <n> at the front and <m> at the end.
        return '_'.join(parts[3:-1])
    if len(parts) >= 2:
        return '_'.join(parts[:-1])
    return seq_id


def test_with_hmd_amp(sequences, sequence_ids):
    """Classify sequences with HMD-AMP and summarise the result per CFG type.

    The sequences are written to a temporary FASTA file, run through
    ``amp_feature_extraction`` and then through the saved
    ``CascadeForestClassifier``. A prediction of ``1`` is treated as "AMP".

    Side effects: prints a report and writes ``hmd_amp_detailed_results.csv``
    and ``hmd_amp_cfg_analysis.csv`` to the current working directory.

    Args:
        sequences: Sequences (strings) to classify.
        sequence_ids: IDs paired positionally with ``sequences``.

    Returns:
        A tuple ``(results_df, cfg_analysis)``: the per-sequence DataFrame
        (columns ``ID``, ``Sequence``, ``AMP_Prediction``, ``CFG_Type``) and
        the per-CFG summary (columns ``Total``, ``Predicted_AMPs``,
        ``AMP_Rate``).
    """
    # Local import so this block brings its own dependency into scope.
    import tempfile

    print("\n🧬 Testing sequences with HMD-AMP classifier...")

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load models
    ftmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/ft_parts.pth'
    clsmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/clsmodel'

    # A unique temporary file avoids overwriting (and later deleting) an
    # unrelated 'temp_sequences.fasta' in the working directory and keeps
    # concurrent runs from colliding.
    fd, temp_fasta = tempfile.mkstemp(prefix='temp_sequences_', suffix='.fasta')
    os.close(fd)

    try:
        save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)

        # Generate sequence features using HMD-AMP's feature extraction
        seq_embeddings, _, _ = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        # Load classifier
        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)

        # NOTE(review): predictions are combined positionally with the input
        # IDs below, i.e. this assumes one prediction per input sequence, in
        # input order - confirm against amp_feature_extraction.
        binary_pred = cls_model.predict(seq_embeddings)

        total = len(sequences)
        predicted_amps = np.sum(binary_pred)
        predicted_non_amps = total - predicted_amps

        print(f"πŸ“Š HMD-AMP Results:")
        print(f"Total sequences: {total}")
        print(f"Predicted AMPs: {predicted_amps} ({predicted_amps/total*100:.1f}%)")
        print(f"Predicted non-AMPs: {predicted_non_amps} ({predicted_non_amps/total*100:.1f}%)")

        # Analyze results by CFG type
        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            'CFG_Type': [_cfg_type_from_id(seq_id) for seq_id in sequence_ids]
        })

        # Group by CFG type
        cfg_analysis = results_df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'sum', 'mean']).round(3)
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']

        print(f"\nπŸ“‹ Results by CFG Configuration:")
        print(cfg_analysis)

        # Show predicted AMPs
        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\nπŸ† Sequences predicted as AMPs ({len(amp_results)}):")
            for _, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                # Residue-count estimate: K, R, H count +1; D, E count -1.
                net_charge = cationic + seq.count('H') - seq.count('D') - seq.count('E')
                print(f"  {row['ID']}: {seq}")
                print(f"    Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print(f"\n❌ No sequences predicted as AMPs")

        # Save detailed results
        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')

        print(f"\nπŸ’Ύ Results saved:")
        print(f"  - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print(f"  - hmd_amp_cfg_analysis.csv (summary by CFG type)")

        return results_df, cfg_analysis

    finally:
        # Clean up temporary file
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)

def main():
    """Run the whole pipeline: load, decode, save as FASTA, classify, report.

    Exits early with a message when no embeddings could be loaded, so the
    decoder and the HMD-AMP classifier are not started on empty input and
    the final percentage is never computed with a zero denominator.
    """
    print("πŸš€ Starting sequence decoding and HMD-AMP testing...")

    # Load embeddings
    embeddings, labels = load_generated_embeddings()
    if not embeddings:
        print("No embeddings were loaded; nothing to decode or test.")
        return

    # Decode to sequences
    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    # Save sequences as FASTA (timestamped so repeated runs do not overwrite)
    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    # Test with HMD-AMP
    results_df, cfg_analysis = test_with_hmd_amp(sequences, sequence_ids)

    print(f"\nβœ… Complete! Generated and tested {len(sequences)} sequences")
    print(f"πŸ“ Sequences saved as: {fasta_filename}")

    # Final summary
    total_amps = results_df['AMP_Prediction'].sum()
    print(f"\nπŸ“Š FINAL SUMMARY:")
    print(f"Generated sequences: {len(sequences)}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{len(sequences)} ({total_amps/len(sequences)*100:.1f}%)")

    if total_amps > 0:
        print(f"✨ Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print(f"πŸ” No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()