File size: 1,739 Bytes
370f342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
"""
FlowAMP Usage Example
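This script instantiates the FlowAMP model and runs a single forward pass on
random input to demonstrate the model architecture and its call interface.
Note: This is a demonstration version. The model here is untrained, so for
actual AMP generation you'll need to train the model first.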
"""

import torch
from final_flow_model import AMPFlowMatcherCFGConcat

def main():
    """Build a FlowAMP model and run one forward pass on random input.

    Prints the selected device, the model's parameter count, the shapes of
    the dummy input and of the model output, and pointers to the training,
    generation and evaluation scripts. No checkpoint is loaded here, so the
    output only demonstrates the call interface, not generation quality.
    """
    print("=== FlowAMP Usage Example ===")
    print("This demonstrates the model architecture and usage.")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using CUDA" if use_cuda else "Using CPU")

    # Dimensions shared by the model configuration and the dummy input;
    # defined once so the two cannot drift apart.
    batch_size = 2
    seq_len = 25
    compressed_dim = 80

    # Initialize model
    model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=compressed_dim,
        n_layers=4,
        n_heads=8,
        dim_ff=1920,
        dropout=0.1,
        max_seq_len=seq_len,
        use_cfg=True,
    ).to(device)

    print("Model initialized successfully!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # torch.no_grad() only disables gradient tracking; eval() is what puts
    # mode-dependent layers (e.g. any dropout configured via dropout=0.1)
    # into inference behaviour.
    model.eval()

    # Create dummy input directly on the target device (no host-to-device copy)
    x = torch.randn(batch_size, seq_len, compressed_dim, device=device)
    time_steps = torch.rand(batch_size, 1, device=device)

    # Forward pass
    with torch.no_grad():
        output = model(x, time_steps)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print("✓ Model forward pass successful!")

    print("\nTo use this model for AMP generation:")
    print("1. Train the model using the provided training scripts")
    print("2. Use generate_amps.py for peptide generation")
    print("3. Use test_generated_peptides.py for evaluation")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()