import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import AutoTokenizer, EsmModel
from sklearn.metrics import f1_score, matthews_corrcoef, accuracy_score
from argparse import ArgumentParser

from modules.bindevaluator_modules import *  # provides RepeatedModule3, MultiHeadAttentionSequence, FFN, etc.

def parse_motifs(motif: str) -> torch.Tensor:
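    """Parse a motif string such as '3,5-7,10' into 0-indexed residue positions.

    Positions in the input are 1-indexed; ranges are inclusive.
    Illustrative example: '3,5-7' -> tensor([2, 4, 5, 6]).
    """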
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    result = [pos-1 for pos in result]
    print(f'Target Motifs: {result}')
    return torch.tensor(result)


class PeptideModel(pl.LightningModule):
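    """Per-residue binding-site predictor: frozen ESM-2 embeddings of binder and
    target are passed through repeated attention blocks and a final binder-to-target
    attention layer, then projected to one binding logit per target residue."""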
    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super().__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # [binding-site, non-binding-site] weights

    def forward(self, binder_tokens, target_tokens):
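        """Return one binding logit per target-protein token, shape [batch, target_len, 1]."""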
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc


def calculate_score(target_sequence, binder_sequence, model, args):
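    """Run the model on one (target, binder) pair and return the per-residue binding
    probabilities for the target together with the model's classification threshold."""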
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    anchor_tokens = tokenizer(target_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)
    positive_tokens = tokenizer(binder_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)

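    # Zero the attention mask at the first and last positions so the CLS/EOS
    # special tokens added by the ESM tokenizer are ignored downstream.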
    anchor_tokens['attention_mask'][0][0] = 0
    anchor_tokens['attention_mask'][0][-1] = 0
    positive_tokens['attention_mask'][0][0] = 0
    positive_tokens['attention_mask'][0][-1] = 0

    target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device),
                     'attention_mask': anchor_tokens["attention_mask"].to(device)}
    binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device),
                     'attention_mask': positive_tokens['attention_mask'].to(device)}

    model.eval()

    with torch.no_grad():
        # Drop the CLS/EOS positions so the output aligns with the target residues.
        prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1]
        prediction = torch.sigmoid(prediction)

    return prediction, model.classification_threshold


def compute_metrics(true_residues, predicted_residues, length):
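    """Score predicted binding residues against ground truth over a length-`length` target.

    Both inputs are lists of 0-indexed residue positions; each is expanded into a
    binary per-residue label vector before computing accuracy, F1, and MCC.
    """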
    # Initialize the true and predicted lists with 0
    true_list = [0] * length
    predicted_list = [0] * length

    # Set the values to 1 based on the provided lists
    for index in true_residues:
        true_list[index] = 1
    for index in predicted_residues:
        predicted_list[index] = 1

    # Compute the metrics
    accuracy = accuracy_score(true_list, predicted_list)
    f1 = f1_score(true_list, predicted_list)
    mcc = matthews_corrcoef(true_list, predicted_list)

    return accuracy, f1, mcc


def main():
    parser = ArgumentParser()
    parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt',
                        help="File containing initial params", type=str)
    parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-d_hidden", type=int, default=128, help="Dimension of CNN block")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    parser.add_argument("-target", type=str)
    parser.add_argument("-binder", type=str)
    parser.add_argument("-gt", type=str, default=None)
    parser.add_argument("-motifs", type=str, default=None)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = PeptideModel.load_from_checkpoint(args.sm,
                                              n_layers=args.n_layers,
                                              d_model=args.d_model,
                                              d_hidden=args.d_hidden,
                                              n_head=args.n_head,
                                              d_k=64,
                                              d_v=128,
                                              d_inner=64).to(device)

    prediction, _ = calculate_score(args.target, args.binder, model, args)

    prediction = prediction.detach().cpu().numpy()

    # Residues whose predicted binding probability clears 0.5 are reported as binding sites.
    binding_site = [i for i in range(len(prediction)) if prediction[i] >= 0.5]

    print("Prediction: ", binding_site)
    np.set_printoptions(precision=2, suppress=True)
    print(prediction)

    if args.motifs is not None:
        motifs = parse_motifs(args.motifs).tolist()
        print(f"Motif Score: {prediction[motifs].sum() / len(motifs)}")

    if args.gt is not None:
        L = len(args.target)
        gt = parse_motifs(args.gt).tolist()
        print("Ground Truth: ", gt)

        acc, f1, mcc = compute_metrics(gt, binding_site, L)
        print(f"Accuracy={acc}\tF1={f1}\tMCC={mcc}")

    # print("Prediction Logits: ", prediction[binding_site])


if __name__ == "__main__":
    main()
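
# Example invocation (a sketch: the script name, checkpoint path, and sequences
# below are illustrative placeholders, not values from the repository):
#   python this_script.py \
#       -sm checkpoints/model-epoch=14-val_loss=0.40.ckpt \
#       -target MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ \
#       -binder GSHMGS \
#       -gt 5-9,17 \
#       -motifs 5-9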