cafierom commited on
Commit
426afd8
·
verified ·
1 Parent(s): b9805ff

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai.chat_models import ChatOpenAI
2
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
3
+ from google.colab import userdata
4
+ from langchain_core.tools import tool
5
+ from langgraph.graph import START, StateGraph
6
+ from langgraph.graph.message import add_messages
7
+ from langgraph.prebuilt import ToolNode, tools_condition
8
+ import gradio as gr
9
+ import spaces
10
+
11
+ from PIL import Image
12
+ from collections import Counter
13
+
14
+ from typing import Annotated, TypedDict
15
+ import time, sys, os
16
+
17
+ sys.path.append('code')
18
+ from modrag_molecule_functions import *
19
+ from modrag_property_functions import *
20
+ from modrag_protein_functions import *
21
+
22
+ openai_key = os.getenv("OPENAI_API_KEY")
23
+
24
+ tools = [name_node, smiles_node, related_node, structure_node,
25
+ substitution_node, lipinski_node, pharmfeature_node,
26
+ uniprot_node, listbioactives_node, getbioactives_node,
27
+ predict_node, gpt_node, pdb_node, find_node, docking_node,
28
+ target_node]
29
+
30
+ model = ChatOpenAI(model_name="gpt-5.2", api_key=openai_key).bind_tools(tools)
31
+
32
+ class State(TypedDict):
33
+ messages: Annotated[list, add_messages]
34
+
35
+ def model_node(state: State) -> State:
36
+ res = model.invoke(state['messages'])
37
+ return {'messages': res}
38
+
39
+ builder = StateGraph(State)
40
+ builder.add_node('model', model_node)
41
+ builder.add_node('tools', ToolNode(tools))
42
+ builder.add_edge(START, 'model')
43
+ builder.add_conditional_edges('model', tools_condition)
44
+ builder.add_edge('tools', 'model')
45
+
46
+ graph = builder.compile()
47
+ sys_message = SystemMessage(content="You are a helpful cat who says nyan and meow a lot.")
48
+ global messages
49
+ messages = [sys_message]
50
+
51
+ def start_chat():
52
+ '''
53
+ '''
54
+ global chat_history, messages, reasoning
55
+ chat_history = []
56
+ reasoning = []
57
+ messages = [sys_message]
58
+
59
+ @spaces.GPU
60
+ def chat_turn(prompt: str):
61
+ '''
62
+ '''
63
+ human_message = HumanMessage(content=prompt)
64
+ messages.append(human_message)
65
+ global chat_history
66
+ local_history = [prompt]
67
+
68
+ input = {
69
+ 'messages' : messages
70
+ }
71
+
72
+ for c in graph.stream(input):
73
+ try:
74
+ ai_mes = c['model']['messages'].content
75
+ messages.append(AIMessage(ai_mes))
76
+ if ai_mes != '':
77
+ print(f'message is {ai_mes}')
78
+ local_history.append(ai_mes)
79
+ except:
80
+ pass
81
+ try:
82
+ if os.path.exists('current_image.png'):
83
+ if os.path.getmtime('current_image.png') > time.time() - 30:
84
+ img = Image.open('current_image.png')
85
+ else:
86
+ img = None
87
+ else:
88
+ img = None
89
+ except:
90
+ img = None
91
+ try:
92
+ reasoning.append(c['tools']['messages'][0].content)
93
+ except:
94
+ pass
95
+
96
+ if len(local_history) != 2:
97
+ local_history.append('no message')
98
+ chat_history.append(local_history)
99
+ return '', img, chat_history
100
+
101
+ def send_reasoning():
102
+ global reasoning
103
+ return reasoning
104
+
105
+ start_chat()
106
+
107
+ with gr.Blocks(fill_height=True) as OpenAIMoDrAg:
108
+ gr.Markdown('''
109
+ # MoDrAg Chatbot using ChatGPT 5.2
110
+ - The *MOdular DRug design AGent*!
111
+ - This chatbot can answer questions about molecules, proteins, and their interactions.
112
+ It can also perform tasks such as predicting properties, finding similar molecules, and docking. Try it out!
113
+ - See the tool log box at the bottom for direct tool outputs.
114
+ ''')
115
+
116
+
117
+ chat = gr.Chatbot()
118
+ with gr.Row(equal_height = True):
119
+ msg = gr.Textbox(label = 'query', scale = 8)
120
+ sub_button = gr.Button("Submit", scale = 2)
121
+ clear = gr.ClearButton([msg, chat])
122
+ img_box = gr.Image()
123
+ reasoning_box = gr.Textbox(label="Tool logs", lines = 20)
124
+
125
+ msg.submit(chat_turn, [msg], [msg, img_box, chat]).then(send_reasoning, [], [reasoning_box])
126
+ sub_button.click(chat_turn, [msg], [msg, img_box, chat])
127
+ clear.click(start_chat, [], [])
128
+
129
+ OpenAIMoDrAg.launch(mcp_server = True)
finetune_gpt.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import deepchem as dc
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import random
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ from rdkit.Chem import Draw
8
+ import os
9
+
10
+ def finetune_gpt(df, chembl_id):
11
+ '''
12
+ accepts a dataframe with SMILES and uses deepchem to tokenize the dataset,
13
+ then uses tensorflow and a pre-trained model to fine tune the model on the dataset.
14
+ The pretrained model was trained on 305K molecules from the ZN15 dataset, including at least
15
+ 50K that are bioactive.
16
+
17
+ Returns:
18
+ out_text: the generated molecules
19
+ img: the image of the generated molecules
20
+
21
+ requires files:
22
+ vocab.txt
23
+ vocab_305K.txt
24
+ GPT_ZN305_50epochs.weights.h5
25
+ layer_store_GPT_ZN305_50epochs.txt
26
+ ZN305K_smiles.csv
27
+
28
+ '''
29
+ # check to see if f"gen_smiles_{chembl_id}.csv" exists
30
+ if os.path.exists(f"gen_smiles_{chembl_id}.csv"):
31
+ df = pd.read_csv(f"gen_smiles_{chembl_id}.csv")
32
+ final_smiles = df["SMILES"].to_list()
33
+ final_mols = [Chem.MolFromSmiles(smile) for smile in final_smiles]
34
+ else:
35
+
36
+ # Prepare dataset from chembl ==========================================
37
+
38
+ if len(df) > 2000:
39
+ df = df.sample(n=2000, random_state=42)
40
+
41
+ smiles_list = df["SMILES"].to_list()
42
+
43
+ Xa = []
44
+ for smiles in smiles_list:
45
+ smiles = smiles.replace("[Na+].","").replace("[Cl-].","").replace(".[Cl-]","").replace(".[Na+]","")
46
+ smiles = smiles.replace("[K+].","").replace("[Br-].","").replace(".[K+]","").replace(".[Br-]","")
47
+ smiles = smiles.replace("[I-].","").replace(".[I-]","").replace("[Ca2+].","").replace(".[Ca2+]","")
48
+ Xa.append(smiles)
49
+
50
+ tokenizer=dc.feat.SmilesTokenizer(vocab_file="vocab.txt")
51
+ featname="SMILES Tokenizer"
52
+
53
+ fl = list(map(lambda x: tokenizer.encode(x),Xa))
54
+
55
+ biggest = 1
56
+ smallest = 200
57
+ for i in range(len(fl)):
58
+ temp = len(fl[i])
59
+ if temp > biggest:
60
+ biggest = temp
61
+ if temp < smallest:
62
+ smallest = temp
63
+
64
+ print(biggest, smallest)
65
+
66
+ string_length = smallest - 1
67
+ max_length = biggest
68
+
69
+ fl2 = list(map(lambda x: tokenizer.add_padding_tokens(x,max_length),fl))
70
+
71
+ fl2set=set()
72
+ for sublist in fl2:
73
+ fl2set.update(sublist)
74
+ new_vocab_size = len(fl2set)
75
+ print("New vocabulary size: ",new_vocab_size)
76
+
77
+ f = open("vocab_305K.txt", "r")
78
+ raw_lines = f.readlines()
79
+ f.close()
80
+ VOCAB_SIZE = len(raw_lines)
81
+ print("Vocabulary size for standard dataset: ",VOCAB_SIZE)
82
+
83
+ lines = []
84
+ for line in raw_lines:
85
+ lines.append(line.replace("\n",""))
86
+
87
+ novel_items = []
88
+ for item in fl2set:
89
+ item = tokenizer.decode([item])
90
+ item = tokenizer.convert_tokens_to_string(item)
91
+ item = item.replace(" ","")
92
+
93
+ if item not in lines:
94
+ print(f"{item} not in standard vocabulary")
95
+ novel_items.append(item)
96
+
97
+ if(len(novel_items) > 0):
98
+ print("This dataset is not compatible with the Foundation model vocabulary")
99
+ else:
100
+ print("This dataset is compatible with the Foundation model vocabulary")
101
+
102
+ if max_length > 166:
103
+ print("This dataset's context window is not compatible with the Foundation model.")
104
+ else:
105
+ print("This dataset's context window is compatible with the Foundation model")
106
+
107
+ smiles_removed_tokens = []
108
+ for i,smiles in enumerate(Xa):
109
+ bad_list = [True if (token in smiles) else False for token in novel_items]
110
+ if not any(bad_list):
111
+ smiles_removed_tokens.append(smiles)
112
+
113
+ smiles_no_long = []
114
+ for i,smiles in enumerate(smiles_removed_tokens):
115
+ if len(smiles) <= 166:
116
+ smiles_no_long.append(smiles)
117
+
118
+ print(f"Removed {len(Xa) - len(smiles_no_long)} entries from the list!")
119
+
120
+ new_dict = {"SMILES": smiles_no_long}
121
+ new_df = pd.DataFrame(new_dict)
122
+
123
+ Xa = []
124
+ for smiles in new_df['SMILES']:
125
+ Xa.append(smiles)
126
+
127
+ tokenizer=dc.feat.SmilesTokenizer(vocab_file="vocab_305K.txt")
128
+ featname="SMILES Tokenizer"
129
+
130
+ fl = list(map(lambda x: tokenizer.encode(x),Xa))
131
+
132
+ biggest = 1
133
+ smallest = 200
134
+ for i in range(len(fl)):
135
+ temp = len(fl[i])
136
+ if temp > biggest:
137
+ biggest = temp
138
+ if temp < smallest:
139
+ smallest = temp
140
+
141
+ print(biggest, smallest)
142
+
143
+ string_length = smallest - 1
144
+ max_length = biggest
145
+
146
+ fl2 = list(map(lambda x: tokenizer.add_padding_tokens(x,max_length),fl))
147
+
148
+ f = open("vocab_305K.txt", "r")
149
+ lines = f.readlines()
150
+ f.close()
151
+ VOCAB_SIZE = len(lines)
152
+ print("Vocabulary size for this dataset: ",VOCAB_SIZE)
153
+
154
+ x = []
155
+ y = []
156
+ i=0
157
+ for string in fl2:
158
+ x.append(string[0:max_length-1]) #string_length
159
+ y.append(string[1:max_length]) #string_length+1
160
+
161
+ fx = np.array(x)
162
+ fy = np.array(y)
163
+ print("Number of features and datapoints, targets: ",fx.shape,fy.shape)
164
+
165
+ # Load foundation model ==================================================
166
+
167
+ VOCAB_SIZE = 100
168
+ max_length = 166
169
+ num_new_blocks = 2
170
+ EMBEDDING_DIM = 256
171
+ N_HEADS = 4
172
+ KEY_DIM = 256
173
+ FEED_FORWARD_DIM = 256
174
+
175
+ inputs = tf.keras.layers.Input(shape=(None,),dtype=tf.int32)
176
+ x = TokenAndPositionEmbedding(max_length,VOCAB_SIZE,EMBEDDING_DIM)(inputs)
177
+ for i in range(num_new_blocks+2):
178
+ x, attentions_scores = TransformerBlock(N_HEADS,KEY_DIM,EMBEDDING_DIM,FEED_FORWARD_DIM)(x)
179
+ outputs = tf.keras.layers.Dense(VOCAB_SIZE,activation="softmax")(x)
180
+
181
+ gpt_ft = tf.keras.models.Model(inputs = inputs, outputs =[outputs, attentions_scores])
182
+
183
+ f = open("layer_store_GPT_ZN305_50epochs.txt", "r")
184
+ layer_name_store_raw = f.readlines()
185
+ f.close()
186
+
187
+ print("Reading in layers:")
188
+ layer_name_store = []
189
+ for line in layer_name_store_raw:
190
+ line = line.replace("\n","")
191
+ layer_name_store.append(line)
192
+ print(line)
193
+ print("===========================================")
194
+
195
+ new_layers = num_new_blocks + 1
196
+ for i,layer in enumerate(gpt_ft.layers[:-new_layers]):
197
+ layer.name = layer_name_store[i]
198
+ print(f"{layer.name} has been named!")
199
+
200
+ for i,layer in enumerate(gpt_ft.layers[-new_layers:-1]):
201
+ layer.name = f"transformer_block_X_{i+1}"
202
+ print(f"{layer.name} has been named!")
203
+
204
+ gpt_ft.layers[-1].name = "dense_X"
205
+
206
+ gpt_ft.load_weights("GPT_ZN305_50epochs.weights.h5", skip_mismatch=True)
207
+
208
+ for layer in gpt_ft.layers[0:-new_layers]: #make old layers freeze and only train new layers
209
+ layer.trainable=False
210
+ print(f"setting layer {layer.name} untrainable.")
211
+
212
+ for layer in gpt_ft.layers[-new_layers:]:
213
+ layer.trainable=True
214
+ print(f"setting layer {layer.name} trainable.")
215
+
216
+ # train new layers =======================================================
217
+
218
+ batch_size = 512
219
+ gpt_ft.compile("adam",loss=[tf.keras.losses.SparseCategoricalCrossentropy(),None])
220
+ gpt_ft.fit(fx,fy,epochs = 50, batch_size = batch_size)
221
+
222
+ # train all together =====================================================
223
+ for layer in gpt_ft.layers:
224
+ layer.trainable=True
225
+ print(f"setting layer {layer.name} trainable.")
226
+
227
+ gpt_ft.compile("adam",loss=[tf.keras.losses.SparseCategoricalCrossentropy(),None])
228
+ gpt_ft.fit(fx,fy,epochs = 25, batch_size = batch_size)
229
+
230
+ # make prompts ============================================================
231
+
232
+ df_prompts = pd.read_csv("ZN305K_smiles.csv")
233
+
234
+ Xap = []
235
+ for smiles in df_prompts["SMILES"]:
236
+ smiles = smiles.replace("[Na+].","").replace("[Cl-].","").replace(".[Cl-]","").replace(".[Na+]","")
237
+ smiles = smiles.replace("[K+].","").replace("[Br-].","").replace(".[K+]","").replace(".[Br-]","")
238
+ smiles = smiles.replace("[I-].","").replace(".[I-]","").replace("[Ca2+].","").replace(".[Ca2+]","")
239
+ Xap.append(smiles)
240
+
241
+ raw_prompts = random.choices(Xap,k=50)
242
+
243
+ test_string = []
244
+ for smile in raw_prompts:
245
+ test_string.append(smile[:2])
246
+
247
+ # inference ================================================================
248
+
249
+ tf.random.set_seed(42)
250
+
251
+ batch_length = len(test_string)
252
+ prompt_length = len(test_string[0])
253
+ test_xlist = np.empty([batch_length,prompt_length], dtype=int)
254
+
255
+ test_tokenized = list(map(lambda x: tokenizer.encode(x),test_string))
256
+ for i in range(batch_length):
257
+ test_xlist[i][:] = test_tokenized[i][:prompt_length]
258
+ test_array = np.array(test_xlist)
259
+
260
+ proba = np.empty([batch_length,VOCAB_SIZE])
261
+ rescaled_logits = np.empty([batch_length,VOCAB_SIZE])
262
+ preds = np.empty([batch_length])
263
+ gen_molecules = np.empty([batch_length])
264
+
265
+ c_final = 60 - prompt_length
266
+ sig_start = 0.10
267
+ TEMP = 1.5
268
+
269
+ for c in range(0,c_final,1):
270
+
271
+ c_o = int(c_final*sig_start)
272
+
273
+ T_int = TEMP*(1/(1+np.exp(-(c-c_o))))
274
+
275
+ results, _ = gpt_ft.predict(test_array)
276
+
277
+ if T_int < 0.015:
278
+ print(f"using zero temp generation with {T_int}.")
279
+ for j in range(batch_length):
280
+ preds[j] = tf.argmax(results[j][-1])
281
+ preds = list(map(lambda x: int(x),preds))
282
+ else:
283
+ print(f"using variable temp generation with {T_int}.")
284
+ for j in range(batch_length):
285
+ proba[j] = (results[j][-1:]) ** (1/T_int)
286
+ rescaled_logits[j] = ( proba[j][:] ) / np.sum(proba[j][:])
287
+ preds[j] = np.random.choice(len(rescaled_logits[j][:]),
288
+ p=rescaled_logits[j][:])
289
+ preds = list(map(lambda x: int(x),preds))
290
+ test_array = np.c_[test_array,preds]
291
+ print(test_array.shape)
292
+
293
+ gen_molecules = list(map(lambda x: tokenizer.decode(x),test_array))
294
+ gen_molecules = list(map(lambda x: tokenizer.convert_tokens_to_string(x),
295
+ gen_molecules))
296
+ gen_molecules = list(map(lambda x: strip_smiles(x),gen_molecules))
297
+
298
+ mols, smiles = mols_from_smiles(gen_molecules)
299
+
300
+ final_smiles = []
301
+ final_mols = []
302
+ for smile, mol in zip(smiles,mols):
303
+ if smile not in final_smiles:
304
+ final_smiles.append(smile)
305
+ final_mols.append(mol)
306
+
307
+ final_dict = {"SMILES": final_smiles}
308
+ final_df = pd.DataFrame.from_dict(final_dict)
309
+ final_df.to_csv(f"gen_smiles_{chembl_id}.csv", index = False)
310
+
311
+ print(f"Generated {len(final_smiles)} unique molecules.")
312
+
313
+ img = Draw.MolsToGridImage(final_mols,molsPerRow=3,legends=final_smiles)
314
+ #img.save("Substitution_image.png")
315
+
316
+ out_text = f'The novel molecules generated by a GPT trained on {chembl_id} are: \n'
317
+ for smile in final_smiles:
318
+ out_text += f'{smile}\n'
319
+
320
+ return final_smiles, out_text, img
321
+
322
+ def casual_attention_mask(batch_size,n_dest,n_src,dtype):
323
+ '''
324
+ Make a causal attention mask
325
+ '''
326
+ i = tf.range(n_dest)[:,None]
327
+ j = tf.range(n_src)
328
+ m = i >= j - n_src + n_dest
329
+ mask = tf.cast(m,dtype)
330
+ mask = tf.reshape(mask,[1,n_dest,n_src])
331
+ mult = tf.concat([tf.expand_dims(batch_size,-1),tf.constant([1,1],dtype=tf.int32)],0)
332
+ return tf.tile(mask,mult)
333
+
334
+ class TransformerBlock(tf.keras.layers.Layer):
335
+ '''
336
+ Transformer block with multi-head attention.
337
+ '''
338
+ def __init__(self,num_heads,key_dim,embed_dim,ff_dim,dropout_rate=0.1):
339
+ super(TransformerBlock,self).__init__()
340
+ self.num_heads = num_heads
341
+ self.key_dim = key_dim
342
+ self.embed_dim = embed_dim
343
+ self.ff_dim = ff_dim
344
+ self.dropout_rate = dropout_rate
345
+ self.attn = tf.keras.layers.MultiHeadAttention(self.num_heads,self.key_dim,
346
+ output_shape=self.embed_dim)
347
+ self.dropout_1 = tf.keras.layers.Dropout(self.dropout_rate)
348
+ self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=0.000001)
349
+ self.ffn_1 = tf.keras.layers.Dense(self.ff_dim,activation="relu")
350
+ self.ffn_2 = tf.keras.layers.Dense(self.embed_dim)
351
+ self.dropout_2 = tf.keras.layers.Dropout(self.dropout_rate)
352
+ self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=0.000001)
353
+
354
+ def call(self,inputs):
355
+ input_shape = tf.shape(inputs)
356
+ batch_size2 = input_shape[0]
357
+ seq_len = input_shape[1]
358
+ casual_mask = casual_attention_mask(batch_size2,seq_len,seq_len,tf.bool)
359
+ attention_output, attention_scores = self.attn(inputs,inputs,
360
+ attention_mask=casual_mask,
361
+ return_attention_scores=True)
362
+ attention_output = self.dropout_1(attention_output)
363
+ out1 = self.ln_1(inputs + attention_output)
364
+ ffn_1 = self.ffn_1(out1)
365
+ ffn_2 = self.ffn_2(ffn_1)
366
+ ffn_output = self.dropout_2(ffn_2)
367
+ return (self.ln_2(out1+ffn_output),attention_scores)
368
+
369
+ def get_config(self):
370
+ config = super().get_config()
371
+ config.update({"key_dim": self.key_dim, "embed_dim": self.embed_dim,
372
+ "num_heads": self.num_heads,"ff_dim": self.ff_dim,
373
+ "dropout_rate": self.dropout_rate})
374
+ return config
375
+
376
+ class TokenAndPositionEmbedding(tf.keras.layers.Layer):
377
+ '''
378
+ Embeds tokens and positions.
379
+ '''
380
+ def __init__(self,max_len,vocab_size,embed_dim):
381
+ super(TokenAndPositionEmbedding,self).__init__()
382
+ self.max_len = max_len
383
+ self.vocab_size = vocab_size
384
+ self.embed_dim = embed_dim
385
+ self.token_emb = tf.keras.layers.Embedding(input_dim=vocab_size,
386
+ output_dim = embed_dim)
387
+ self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len,output_dim=embed_dim)
388
+
389
+ def call(self,x):
390
+ maxlen = tf.shape(x)[-1]
391
+ positions = tf.range(start=0,limit=maxlen,delta=1)
392
+ positions = self.pos_emb(positions)
393
+ x = self.token_emb(x)
394
+ return x + positions
395
+
396
+ def get_config(self):
397
+ config = super().get_config()
398
+ config.update({"max_len": self.max_len, "vocab_size": self.vocab_size,
399
+ "embed_dim": self.embed_dim})
400
+ return config
401
+
402
+ def strip_smiles(input_string):
403
+ '''
404
+ Cleans un-needed tokens from the SMILES string.
405
+
406
+ Args:
407
+ input_string: SMILES string
408
+ Returns:
409
+ output_string: cleaned SMILES string
410
+ '''
411
+ output_string = input_string.replace(" ","").replace("[CLS]","").replace("[SEP]","").replace("[PAD]","")
412
+ output_string = output_string.replace("[Na+].","").replace(".[Na+]","")
413
+ return output_string
414
+
415
+ def mols_from_smiles(input_smiles_list):
416
+ '''
417
+ Converts a list of SMILES strings to a list of RDKit molecules.
418
+
419
+ Args:
420
+ input_smiles_list: list of SMILES strings
421
+ Returns:
422
+ valid_mols: list of RDKit molecules
423
+ valid_smiles: list of SMILES strings
424
+ '''
425
+ valid_mols = []
426
+ valid_smiles = []
427
+
428
+ good_count = 0
429
+ for ti, smile in enumerate(input_smiles_list):
430
+ temp_mol = Chem.MolFromSmiles(smile)
431
+ if temp_mol != None:
432
+ valid_mols.append(temp_mol)
433
+ valid_smiles.append(smile)
434
+ good_count += 1
435
+ else:
436
+ print(f"SMILES {ti} was not valid!")
437
+
438
+ if len(valid_mols) == len(valid_smiles) == good_count:
439
+ print(f"Generated a total of {good_count} mol objects")
440
+ else:
441
+ print("mismatch!")
442
+ return valid_mols, valid_smiles
modrag_molecule_functions.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+ from rdkit import Chem
4
+ from rdkit.Chem import AllChem, QED
5
+ from rdkit.Chem import Draw
6
+ from rdkit.Chem.Draw import MolsToGridImage
7
+ from rdkit import rdBase
8
+ from rdkit.Chem import rdMolAlign
9
+ import os, re
10
+ from rdkit import RDConfig
11
+ import pubchempy as pcp
12
+ from PIL import Image
13
+ from collections import Counter
14
+ from langchain_core.tools import tool
15
+
16
+ @tool
17
+ def name_node(smiles_list: list[str]) -> (list[str], str):
18
+ '''
19
+ Queries Pubchem for the name of the molecule based on the smiles string.
20
+ Args:
21
+ smiles_list: the list of input smiles strings
22
+ Returns:
23
+ names_list: the list of names of the molecules
24
+ name_string: a string of the tool results
25
+ '''
26
+ print("name tool")
27
+ print('===================================================')
28
+
29
+ names = []
30
+ name_string = ''
31
+ for smiles in smiles_list:
32
+ try:
33
+ res = pcp.get_compounds(smiles, "smiles")
34
+ name = res[0].iupac_name
35
+ names.append(name)
36
+ name_string += f'{smiles}: IUPAC molecule name: {name}\n'
37
+ print(smiles, name)
38
+ syn_list = pcp.get_synonyms(res[0].cid)
39
+ for alt_name in syn_list[0]['Synonym'][:5]:
40
+ name_string += f'{smiles}: alternative or common name: {alt_name}\n'
41
+ except:
42
+ name = "unknown"
43
+ name_string += f'{smiles}: Fail\n'
44
+
45
+ return names, name_string, None
46
+
47
+ @tool
48
+ def smiles_node(names_list: list[str]) -> (list[str], str):
49
+ '''
50
+ Queries Pubchem for the smiles string of the molecule based on the name.
51
+ Args:
52
+ names_list: the list of molecule names
53
+ Returns:
54
+ smiles_list: the list of smiles strings of the molecules
55
+ smiles_string: a string of the tool results
56
+ '''
57
+ print("smiles tool")
58
+ print('===================================================')
59
+
60
+ smiles_list = []
61
+ smiles_string = ''
62
+ for name in names_list:
63
+ try:
64
+ res = pcp.get_compounds(name, "name")
65
+ smiles = res[0].smiles
66
+ #smiles = smiles.replace('#','~')
67
+ smiles_list.append(smiles)
68
+ smiles_string += f'{name}: The SMILES string for the molecule is: {smiles}\n'
69
+ except:
70
+ smiles = "unknown"
71
+ smiles_string += f'{name}: Fail\n'
72
+
73
+ return smiles_list, smiles_string, None
74
+
75
+ @tool
76
+ def related_node(smiles_list: list[str]) -> (list[list[str]], str, list):
77
+ '''
78
+ Queries Pubchem for similar molecules based on the smiles string or name
79
+ Args:
80
+ smiles: the input smiles string, OR
81
+ name: the molecule name
82
+ Returns:
83
+ total_similar_list: a list of lists of similar molecules
84
+ related_string: a string of the tool results
85
+ all_images: a list of images of the similar molecules
86
+ '''
87
+ print("related tool")
88
+ print('===================================================')
89
+
90
+
91
+ total_similar_list = []
92
+ all_images = []
93
+ related_string = ''
94
+ for smiles in smiles_list:
95
+ try:
96
+ res = pcp.get_compounds(smiles, "smiles", searchtype="similarity",listkey_count=50)
97
+ related_string += f'The following molecules are similar to {smiles}: \n'
98
+ print('got related molecules with smiles')
99
+
100
+ sub_smiles = []
101
+
102
+ i = 0
103
+ for compound in res:
104
+ if i == 0:
105
+ print(compound.iupac_name)
106
+ i+=1
107
+ sub_smiles.append(compound.smiles)
108
+ related_string += f'Name: {compound.iupac_name}\n'
109
+ related_string += f'SMILES: {compound.smiles}\n'
110
+ related_string += f'Molecular Weight: {compound.molecular_weight}\n'
111
+ related_string += f'LogP: {compound.xlogp}\n'
112
+ related_string += '===================\n'
113
+
114
+ sub_mols = [Chem.MolFromSmiles(smile) for smile in sub_smiles]
115
+ legend = [str(compound.smiles) for compound in res]
116
+
117
+ total_similar_list.append(sub_smiles)
118
+ img = Draw.MolsToGridImage(sub_mols, legends=legend, molsPerRow=4, subImgSize=(250, 250))
119
+ #pic = img.data
120
+ all_images.append(img)
121
+ except:
122
+ related_string += f'{smiles}: Fail\n'
123
+ total_similar_list.append([])
124
+ all_images.append(None)
125
+
126
+ pic = img.data
127
+ with open('current_image.png', 'wb') as f:
128
+ f.write(pic)
129
+ img = Image.open('current_image.png')
130
+
131
+ return total_similar_list, related_string, img
132
+
133
+ @tool
134
+ def structure_node(smiles_list: list[str]) -> (list[str], str, list):
135
+ '''
136
+ Generates the 3D structure of the molecule based on the smiles string.
137
+ Args:
138
+ smiles: the input smiles string
139
+ Returns:
140
+ all_structures: a list of strings of the 3D structure of the molecule
141
+ output_string: a string of the chemical formulae.
142
+ all_images: a list of images of the 3D structure of the molecule
143
+ '''
144
+ print("structure tool")
145
+
146
+ all_mols = []
147
+ all_structures = []
148
+ output_string = ''
149
+
150
+ for smile in smiles_list:
151
+ mol = Chem.MolFromSmiles(smile)
152
+ molH = Chem.AddHs(mol)
153
+ AllChem.EmbedMolecule(molH)
154
+ AllChem.MMFFOptimizeMolecule(molH)
155
+
156
+ structure_string = ""
157
+ all_symbols = []
158
+ for atom in molH.GetAtoms():
159
+ symbol = atom.GetSymbol()
160
+ all_symbols.append(symbol)
161
+ pos = molH.GetConformer().GetAtomPosition(atom.GetIdx())
162
+ structure_string += f'{symbol} {pos[0]} {pos[1]} {pos[2]}\n'
163
+
164
+ atom_freqs = Counter(all_symbols)
165
+ formula = ''.join([f'{atom}{count}' for atom, count in atom_freqs.items()])
166
+
167
+ output_string += f'For {smile}: Formula is: {formula}\n'
168
+ all_structures.append(structure_string)
169
+ all_mols.append(molH)
170
+
171
+ img = Draw.MolsToGridImage(all_mols, molsPerRow=3, subImgSize=(250, 250))
172
+
173
+ #save the image as current_image.png
174
+ pic = img.data
175
+ with open('current_image.png', 'wb') as f:
176
+ f.write(pic)
177
+ img = Image.open('current_image.png')
178
+ return all_structures, output_string, img
modrag_property_functions.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ from rdkit.Chem import AllChem, QED
3
+ from rdkit.Chem import Draw
4
+ from rdkit import rdBase
5
+ from rdkit.Chem import rdMolAlign
6
+ import os, re
7
+ from rdkit import RDConfig
8
+ from rdkit.Chem.Features.ShowFeats import _featColors as featColors
9
+ from rdkit.Chem.FeatMaps import FeatMaps
10
+ from PIL import Image
11
+ from langchain_core.tools import tool
12
+
13
+ fdef = AllChem.BuildFeatureFactory(os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef'))
14
+
15
+ fmParams = {}
16
+ for k in fdef.GetFeatureFamilies():
17
+ fparams = FeatMaps.FeatMapParams()
18
+ fmParams[k] = fparams
19
+
20
+ @tool
21
+ def substitution_node(smiles_list: list[str]) -> (list[str], str, list):
22
+ '''
23
+ A simple substitution routine that looks for a substituent on a phenyl ring and
24
+ substitutes different fragments in that location. Returns a list of novel molecules and their
25
+ QED score (1 is most drug-like, 0 is least drug-like).
26
+
27
+ Args:
28
+ smiles: the input smiles string
29
+ Returns:
30
+ new_smiles_list: a list of novel molecules and their QED scores.
31
+ new_smiles_string: a string of the tool results
32
+ '''
33
+ print("substitution tool")
34
+ print('===================================================')
35
+
36
+ new_fragments = ["c(Cl)c", "c(F)c", "c(O)c", "c(C)c", "c(OC)c", "c([NH3+])c",
37
+ "c(Br)c", "c(C(F)(F)(F))c"]
38
+
39
+ total_sub_smiles_list = []
40
+ total_sub_smiles_string = ''
41
+ total_sub_images = []
42
+
43
+ for smiles in smiles_list:
44
+ try:
45
+ new_smiles = []
46
+ for fragment in new_fragments:
47
+ m = re.findall(r"c(\D\D*)c", smiles)
48
+ if len(m) != 0:
49
+ for group in m:
50
+ #print(group)
51
+ if fragment not in group:
52
+ new_smile = smiles.replace(group[1:], fragment)
53
+ new_smiles.append(new_smile)
54
+
55
+ qeds = []
56
+ for new_smile in new_smiles:
57
+ qeds.append(get_qed(new_smile))
58
+ original_qed = get_qed(smiles)
59
+
60
+ total_sub_smiles_string += "Substitution or Analogue creation tool results: \n"
61
+ total_sub_smiles_string += f"The original molecule SMILES was {smiles} with QED {original_qed}.\n"
62
+ total_sub_smiles_string += "Novel Molecules or Analogues and QED values: \n"
63
+ for i in range(len(new_smiles)):
64
+ total_sub_smiles_string += f"SMILES: {new_smiles[i]}, QED: {qeds[i]:.3f}\n"
65
+ total_sub_smiles_list.append(new_smiles)
66
+
67
+ mols = [Chem.MolFromSmiles(smile) for smile in new_smiles]
68
+ img = Draw.MolsToGridImage(mols,legends=new_smiles, molsPerRow=4, subImgSize=(250, 250))
69
+ total_sub_images.append(img)
70
+ except:
71
+ total_sub_smiles_list.append([])
72
+ total_sub_smiles_string += f"SMILES: {smiles}, Fail\n"
73
+ total_sub_images.append(None)
74
+
75
+ pic = img.data
76
+ with open('current_image.png', 'wb') as f:
77
+ f.write(pic)
78
+ img = Image.open('current_image.png')
79
+
80
+ return total_sub_smiles_list, total_sub_smiles_string, img
81
+
82
+ def get_qed(smiles):
83
+ '''
84
+ Helper function to compute QED for a given molecule.
85
+ Args:
86
+ smiles: the input smiles string
87
+ Returns:
88
+ qed: the QED score of the molecule.
89
+ '''
90
+ mol = Chem.MolFromSmiles(smiles)
91
+ qed = Chem.QED.default(mol)
92
+ return qed
93
+
94
+ @tool
95
+ def lipinski_node(smiles_list: list[str]) -> (list[float], str):
96
+ '''
97
+ A tool to calculate QED and other lipinski properties of a molecule.
98
+ Args:
99
+ smiles: the input smiles string
100
+ Returns:
101
+ total_lipinski_list: a list of the QED and other lipinski properties of the molecules,
102
+ including Molecular Weight, LogP, HBA, HBD, Polar Surface Area,
103
+ Rotatable Bonds, Aromatic Rings and Undesireable Moieties.
104
+ total_lipinski_string: a string of the tool results
105
+ '''
106
+ print("lipinski tool")
107
+ print('===================================================')
108
+
109
+ total_lipinski_list = []
110
+ total_lipinski_string = ''
111
+
112
+ for smiles in smiles_list:
113
+ for ion in ['.[Na+]', '.[K+]', '.[Cl-]', '.[Br-]', '[Na+].', '[K+].', '[Cl-].', '[Br-].']:
114
+ smiles = smiles.replace(ion, '')
115
+ lipinski_list = []
116
+ try:
117
+ mol = Chem.MolFromSmiles(smiles)
118
+ qed = Chem.QED.default(mol)
119
+
120
+ p = Chem.QED.properties(mol)
121
+ mw = p[0]
122
+ logP = p[1]
123
+ hba = p[2]
124
+ hbd = p[3]
125
+ psa = p[4]
126
+ rb = p[5]
127
+ ar = p[6]
128
+ um = p[7]
129
+
130
+ lipinski_list.append(qed)
131
+ lipinski_list.append(mw)
132
+ lipinski_list.append(logP)
133
+ lipinski_list.append(hba)
134
+ lipinski_list.append(hbd)
135
+ lipinski_list.append(psa)
136
+ lipinski_list.append(rb)
137
+ lipinski_list.append(ar)
138
+ lipinski_list.append(um)
139
+
140
+ total_lipinski_string += f"Properties of SMILES: {smiles}: QED: {qed:.3f}\n"
141
+ total_lipinski_string += f"Molecular Weight: {mw:.3f}, LogP: {logP:.3f}\n"
142
+ total_lipinski_string += f"Hydrogen bond acceptors: {hba}, Hydrogen bond donors: {hbd}\n"
143
+ total_lipinski_string += f"Polar Surface Area: {psa:.3f}, Rotatable Bonds: {rb}\n"
144
+ total_lipinski_string += f"Aromatic Rings: {ar}, Undesireable moieties: {um}\n"
145
+ total_lipinski_string += "===================================================\n"
146
+ total_lipinski_list.append(lipinski_list)
147
+ except:
148
+ total_lipinski_list.append([])
149
+ total_lipinski_string += f"SMILES: {smiles}, Could not get properties\n"
150
+ return total_lipinski_list, total_lipinski_string, None
151
+
152
+ @tool
153
+ def pharmfeature_node(known_smiles: str, test_smiles: list[str]) -> (list[float], str):
154
+ '''
155
+ A tool to compare the pharmacophore features of a query molecule against
156
+ a those of a reference molecule and report the pharmacophore features of both and the feature
157
+ score of the query molecule.
158
+
159
+ Args:
160
+ known_smiles: the reference smiles string
161
+ test_smiles: the query smiles string
162
+ Returns:
163
+ total_pharmfeature_scores: a list of the pharmacophore feature scores of the query molecules.
164
+ total_pharmfeature_string: a string of the tool results
165
+ '''
166
+ print("pharmfeature tool")
167
+ print('===================================================')
168
+
169
+ keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable', 'ZnBinder', 'Aromatic', 'LumpedHydrophobe')
170
+ feat_hash = {'Donor': 'Hydrogen bond donors', 'Acceptor': 'Hydrogen bond acceptors',
171
+ 'NegIonizable': 'Negatively ionizable groups', 'PosIonizable': 'Positively ionizable groups',
172
+ 'ZnBinder': 'Zinc Binders', 'Aromatic': 'Aromatic rings', 'LumpedHydrophobe': 'Hydrophobic/non-polar groups' }
173
+
174
+
175
+ smiles = [known_smiles, *test_smiles]
176
+ mols = [Chem.MolFromSmiles(x) for x in smiles]
177
+
178
+ mols = [Chem.AddHs(m) for m in mols]
179
+ ps = AllChem.ETKDGv3()
180
+
181
+ for m in mols:
182
+ AllChem.EmbedMolecule(m,ps)
183
+
184
+ total_pharmfeature_scores = []
185
+ total_pharmfeature_string = ''
186
+
187
+ #i = 1
188
+ for i in range(1, len(mols)):
189
+ o3d = rdMolAlign.GetO3A(mols[i],mols[0])
190
+ o3d.Align()
191
+
192
+ feat_vectors = []
193
+ for m in [mols[0], mols[i]]:
194
+ rawFeats = fdef.GetFeaturesForMol(m)
195
+ feat_vectors.append([f for f in rawFeats if f.GetFamily() in keep])
196
+
197
+ feat_maps = [FeatMaps.FeatMap(feats = x,weights=[1]*len(x),params=fmParams) for x in feat_vectors]
198
+ test_score = feat_maps[0].ScoreFeats(feat_maps[1].GetFeatures())/(feat_maps[0].GetNumFeatures())
199
+
200
+ feats_known = {}
201
+ feats_test = {}
202
+ for feat in feat_vectors[0]:
203
+ if feat.GetFamily() not in feats_known.keys():
204
+ feats_known[feat.GetFamily()] = 1
205
+ else:
206
+ feats_known[feat.GetFamily()] += 1
207
+
208
+ for feat in feat_vectors[1]:
209
+ if feat.GetFamily() not in feats_test.keys():
210
+ feats_test[feat.GetFamily()] = 1
211
+ else:
212
+ feats_test[feat.GetFamily()] += 1
213
+
214
+ total_pharmfeature_string += f"PharmFeature tool results for SMILES: {smiles[i]}: \n"
215
+ total_pharmfeature_string += f"The Pharmacophore Feature Overlap Score of the test molecule \
216
+ versus the reference molecule is {test_score:.3f}. \n\n"
217
+ total_pharmfeature_scores.append(test_score)
218
+
219
+ for feat in feats_known.keys():
220
+ total_pharmfeature_string += f"There are {feats_known[feat]} {feat_hash[feat]} in the reference molecule. \n"
221
+
222
+ for feat in feats_test.keys():
223
+ total_pharmfeature_string += f"There are {feats_test[feat]} {feat_hash[feat]} in the test molecule. \n"
224
+ #i += 1
225
+ total_pharmfeature_string += "===================================================\n"
226
+
227
+ return total_pharmfeature_scores, total_pharmfeature_string, None
modrag_protein_functions.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ from rdkit.Chem import AllChem, QED
3
+ from rdkit.Chem import Draw
4
+ from rdkit.Chem.Draw import MolsToGridImage
5
+ from rdkit import rdBase
6
+ from rdkit.Chem import rdMolAlign
7
+ import os, re
8
+ from rdkit import RDConfig
9
+ from PIL import Image
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from chembl_webresource_client.new_client import new_client
14
+ from tqdm.auto import tqdm
15
+ import requests, json
16
+ from rcsbapi.search import TextQuery
17
+ import itertools
18
+
19
+ import lightgbm as lgb
20
+ from lightgbm import LGBMRegressor
21
+ import deepchem as dc
22
+ from sklearn.model_selection import train_test_split, GridSearchCV
23
+ from sklearn.preprocessing import StandardScaler
24
+ import tensorflow as tf
25
+ import random
26
+ from finetune_gpt import *
27
+ from dockstring import load_target
28
+ from langchain_core.tools import tool
29
+
30
+
31
+ @tool
32
+ def uniprot_node(protein_names: list[str], human_flag: bool = False) -> (list[str], str):
33
+ '''
34
+ This tool takes in the user requested protein and searches UNIPROT for matches.
35
+ It returns a string scontaining the protein ID, gene name, organism, and protein name.
36
+ Args:
37
+ query_protein: the name of the protein to search for.
38
+
39
+ Returns:
40
+ total_ids: a list of UNIPROT IDs for the given protein names.
41
+ protein_string: a string containing the protein ID, gene name, organism, and protein name.
42
+
43
+ '''
44
+ print("UNIPROT tool")
45
+ print('===================================================')
46
+
47
+ total_ids = []
48
+ protein_string = ''
49
+
50
+ for protein_name in protein_names:
51
+ try:
52
+ url = f'https://rest.uniprot.org/uniprotkb/search?query={protein_name}&format=tsv'
53
+ response = requests.get(url).text
54
+
55
+ f = open(f"{protein_name}_uniprot_ids.tsv", "w")
56
+ f.write(response)
57
+ f.close()
58
+
59
+ prot_df_raw = pd.read_csv(f'{protein_name}_uniprot_ids.tsv', sep='\t')
60
+ if human_flag:
61
+ prot_df = prot_df_raw[prot_df_raw['Organism'] == "Homo sapiens (Human)"]
62
+ print(f"Found {len(prot_df)} Human proteins out of {len(prot_df_raw)} total proteins")
63
+ else:
64
+ prot_df = prot_df_raw
65
+
66
+ prot_ids = prot_df['Entry'].tolist()
67
+ genes = prot_df['Gene Names'].tolist()
68
+ organisms = prot_df['Organism'].tolist()
69
+ names = prot_df['Protein names'].tolist()
70
+
71
+ sub_ids = []
72
+ for id, gene, organism, name in zip(prot_ids, genes, organisms, names):
73
+ protein_string += f'Protein {protein_name}, ID: {id}, Gene: {gene}, Organism: {organism}, Name: {name}\n'
74
+ sub_ids.append(id)
75
+
76
+ protein_string += '==========================================================================================\n'
77
+ total_ids.append(sub_ids)
78
+ except:
79
+ protein_string += f'No proteins found for {protein_name}'
80
+ protein_string += '==========================================================================================\n'
81
+ total_ids.append([])
82
+
83
+ return total_ids, protein_string, None
84
+
85
+ def get_qed(smiles):
86
+ '''
87
+ Helper function to compute QED for a given molecule.
88
+ Args:
89
+ smiles: the input smiles string
90
+ Returns:
91
+ qed: the QED score of the molecule.
92
+ '''
93
+ mol = Chem.MolFromSmiles(smiles)
94
+ qed = Chem.QED.default(mol)
95
+ return qed
96
+
97
+ @tool
98
+ def listbioactives_node(up_ids_list: list[str]) -> (list[int], list[str], str):
99
+ '''
100
+ Accepts a UNIPROT ID and searches for bioactive molecules
101
+ Args:
102
+ up_ids_list: the UNIPROT IDs of the proteins to search for.
103
+ Returns:
104
+ total_bioacts_list: a list of the number of bioactive molecules for each protein
105
+ total_chembl_ids_list: a list of the ChEMBL IDs for each protein
106
+ bioact_string: a string containing the results of the search.
107
+ '''
108
+ print("List bioactives tool")
109
+ print('===================================================')
110
+
111
+ total_bioacts_list = []
112
+ total_chembl_ids_list = []
113
+ bioact_string = ''
114
+
115
+ for up_id in up_ids_list:
116
+
117
+ targets = new_client.target
118
+ bioact = new_client.activity
119
+
120
+ try:
121
+ target_info = targets.get(target_components__accession=up_id).only("target_chembl_id","organism", "pref_name", "target_type")
122
+ target_info = pd.DataFrame.from_records(target_info)
123
+ print(target_info)
124
+ if len(target_info) > 0:
125
+ print(f"Found info for Uniprot ID: {up_id}")
126
+
127
+ chembl_ids = target_info['target_chembl_id'].tolist()
128
+
129
+ chembl_ids = list(set(chembl_ids))
130
+ print(f"Found {len(chembl_ids)} unique ChEMBL IDs")
131
+
132
+ len_all_bioacts = []
133
+ for chembl_id in chembl_ids:
134
+ bioact_chosen = bioact.filter(target_chembl_id=chembl_id, type="IC50", relation="=").only(
135
+ "molecule_chembl_id",
136
+ "type",
137
+ "standard_units",
138
+ "relation",
139
+ "standard_value",
140
+ )
141
+ len_this_bioacts = len(bioact_chosen)
142
+ len_all_bioacts.append(len_this_bioacts)
143
+ bioact_string += f"For Uniprot {up_id}: length of Bioactivities for ChEMBL ID {chembl_id}: {len_this_bioacts}\n"
144
+
145
+ bioact_string += f'================================================================================================\n'
146
+ total_chembl_ids_list.append(chembl_ids)
147
+ total_bioacts_list.append(len_all_bioacts)
148
+
149
+ except:
150
+ bioact_string += f'No bioactives found for Uniprot {up_id}\n'
151
+ bioact_string += f'================================================================================================\n'
152
+ total_chembl_ids_list.append([])
153
+ total_bioacts_list.append([])
154
+ return total_bioacts_list, bioact_string, None
155
+
156
+ @tool
157
+ def getbioactives_node(chembl_ids_list: list[str]) -> (list[str], str):
158
+ '''
159
+ Accepts a Chembl ID and get all bioactives molecule SMILES and IC50s for that ID
160
+ Args:
161
+ chembl_id: the chembl ID to query
162
+ Returns:
163
+ bioactives_list: a list of the bioactive molecules for each chembl ID
164
+ bioactives_string: a string containing the results of the search.
165
+ bioactives_images: a list of images for each bioactive molecule.
166
+ '''
167
+ print("Get bioactives tool")
168
+ print('===================================================')
169
+
170
+ bioactives_list = []
171
+ bioactives_images = []
172
+ bioactives_string = ''
173
+
174
+ for chembl_id in chembl_ids_list:
175
+ try:
176
+ #check if f'{chembl_id}_bioactives.csv' exists
177
+ chembl_id = chembl_id.upper()
178
+ if os.path.exists(f'{chembl_id}_bioactives.csv'):
179
+ print(f'Found {chembl_id}_bioactives.csv')
180
+ total_bioact_df = pd.read_csv(f'{chembl_id}_bioactives.csv')
181
+ print(f"number of records: {len(total_bioact_df)}")
182
+ else:
183
+
184
+ compounds = new_client.molecule
185
+ bioact = new_client.activity
186
+
187
+ bioact_chosen = bioact.filter(target_chembl_id=chembl_id, type="IC50", relation="=").only(
188
+ "molecule_chembl_id",
189
+ "type",
190
+ "standard_units",
191
+ "relation",
192
+ "standard_value",
193
+ )
194
+
195
+ chembl_ids = []
196
+ ic50s = []
197
+ for record in bioact_chosen:
198
+ if record["standard_units"] == 'nM':
199
+ chembl_ids.append(record["molecule_chembl_id"])
200
+ ic50s.append(float(record["standard_value"]))
201
+
202
+ bioact_dict = {'chembl_ids' : chembl_ids, 'IC50s': ic50s}
203
+ bioact_df = pd.DataFrame.from_dict(bioact_dict)
204
+ bioact_df.drop_duplicates(subset=["chembl_ids"], keep= "last")
205
+ print(f"Number of records: {len(bioact_df)}")
206
+ print(bioact_df.shape)
207
+
208
+ compounds_provider = compounds.filter(molecule_chembl_id__in=bioact_df["chembl_ids"].to_list()).only(
209
+ "molecule_chembl_id",
210
+ "molecule_structures"
211
+ )
212
+
213
+ cids_list = []
214
+ smiles_list = []
215
+
216
+ for record in compounds_provider:
217
+ cid = record['molecule_chembl_id']
218
+ cids_list.append(cid)
219
+
220
+ if record['molecule_structures']:
221
+ if record['molecule_structures']['canonical_smiles']:
222
+ smile = record['molecule_structures']['canonical_smiles']
223
+ else:
224
+ print("No canonical smiles")
225
+ smile = None
226
+ else:
227
+ print('no structures')
228
+ smile = None
229
+ smiles_list.append(smile)
230
+
231
+ new_dict = {'SMILES': smiles_list, 'chembl_ids_2': cids_list}
232
+ new_df = pd.DataFrame.from_dict(new_dict)
233
+
234
+ total_bioact_df = pd.merge(bioact_df, new_df, left_on='chembl_ids', right_on='chembl_ids_2')
235
+ print(f"number of records: {len(total_bioact_df)}")
236
+
237
+ total_bioact_df.drop_duplicates(subset=["chembl_ids"], keep= "last")
238
+ print(f"number of records after removing duplicates: {len(total_bioact_df)}")
239
+
240
+ total_bioact_df.dropna(axis=0, how='any', inplace=True)
241
+ total_bioact_df.drop(["chembl_ids_2"],axis=1,inplace=True)
242
+ print(f"number of records after dropping Null values: {len(total_bioact_df)}")
243
+
244
+ total_bioact_df.sort_values(by=["IC50s"],inplace=True)
245
+
246
+ if len(total_bioact_df) > 0:
247
+ total_bioact_df.to_csv(f'{chembl_id}_bioactives.csv')
248
+
249
+ limit = 50
250
+ if len(total_bioact_df) > limit:
251
+ total_bioact_df = total_bioact_df.iloc[:limit]
252
+
253
+ bioact_tuple_list = []
254
+ bioactives_string += f'Results for top bioactivity (IC50 value) for molecules in ChEMBL ID: {chembl_id}. \n'
255
+ for smile, ic50 in zip(total_bioact_df['SMILES'], total_bioact_df['IC50s']):
256
+ bioactives_string += f'Molecule SMILES: {smile}, IC50 (nM): {ic50}\n'
257
+ bioact_tuple_list.append((smile, ic50))
258
+ bioactives_string += f'=========================================================================================\n'
259
+
260
+ mols = [Chem.MolFromSmiles(smile) for smile in total_bioact_df['SMILES'].to_list()]
261
+ legends = [f'IC50: {ic50}' for ic50 in total_bioact_df['IC50s'].to_list()]
262
+ img = MolsToGridImage(mols, molsPerRow=5, legends=legends, subImgSize=(200,200))
263
+ bioactives_images.append(img)
264
+ bioactives_list.append(bioact_tuple_list)
265
+ except:
266
+ bioactives_list.append([])
267
+ bioactives_string += f'No bioactives found for ChEMBL ID: {chembl_id}\n'
268
+ bioactives_string += f'=========================================================================================\n'
269
+ bioactives_images.append(None)
270
+
271
+ try:
272
+ pic = img.data
273
+ with open('current_image.png', 'wb') as f:
274
+ f.write(pic)
275
+ img = Image.open('current_image.png')
276
+
277
+ except Exception as e:
278
+ print(f"Error occurred while processing image: {e}")
279
+ img = None
280
+
281
+ return bioactives_list, bioactives_string, img
282
+
283
+ @tool
284
+ def predict_node(smiles_list_in: list[str], chembl_id: str) -> (list[float],str):
285
+ '''
286
+ uses the current_bioactives.csv file from the get_bioactives node to fit the
287
+ Light GBM model and predict the IC50 for the current smiles.
288
+ Args:
289
+ smiles_list: the SMILES strings of the molecules to predict
290
+ chembl_id: the chembl ID to query
291
+ Returns:
292
+ preds: a list of predicted IC50 values for the input SMILES
293
+ preds_string: a string containing the predicted IC50 values for the input SMILES
294
+ '''
295
+ print("Predict Tool")
296
+ print('===================================================')
297
+
298
+ # if f'{chembl_id}_bioactives.csv' does not exist, call the bioactives node
299
+ if not os.path.exists(f'{chembl_id}_bioactives.csv'):
300
+ _, _, _ = getbioactives_node([chembl_id])
301
+
302
+ try:
303
+ chembl_id = chembl_id.upper()
304
+ df = pd.read_csv(f'{chembl_id}_bioactives.csv')
305
+ #if length of the dataframe is over 2000, take a random sample of 2000 points
306
+ if len(df) > 2000:
307
+ df = df.sample(n=2000, random_state=42)
308
+
309
+ y_raw = df["IC50s"].to_list()
310
+ smiles_list = df["SMILES"].to_list()
311
+ ions_to_clean = ["[Na+].",".[Na+]","[Cl-].",".[Cl-]","[K+].",".[K+]"]
312
+ Xa = []
313
+ y = []
314
+ for smile, value in zip(smiles_list, y_raw):
315
+ for ion in ions_to_clean:
316
+ smile = smile.replace(ion,"")
317
+ y.append(np.log10(value))
318
+ Xa.append(smile)
319
+
320
+ mols = [Chem.MolFromSmiles(smile) for smile in Xa]
321
+ print(f"Number of molecules: {len(mols)}")
322
+
323
+ featurizer=dc.feat.RDKitDescriptors()
324
+ featname="RDKitDescriptors"
325
+ f = featurizer.featurize(mols)
326
+
327
+ nan_indicies = np.isnan(f)
328
+ bad_rows = []
329
+ for i, row in enumerate(nan_indicies):
330
+ for item in row:
331
+ if item == True:
332
+ if i not in bad_rows:
333
+ print(f"Row {i} has a NaN.")
334
+ bad_rows.append(i)
335
+
336
+ print(f"Old dimensions are: {f.shape}.")
337
+
338
+ for j,i in enumerate(bad_rows):
339
+ k=i-j
340
+ f = np.delete(f,k,axis=0)
341
+ y = np.delete(y,k,axis=0)
342
+ Xa = np.delete(Xa,k,axis=0)
343
+ print(f"Deleting row {k} from arrays.")
344
+
345
+ print(f"New dimensions are: {f.shape}")
346
+ if f.shape[0] != len(y) or f.shape[0] != len(Xa):
347
+ raise ValueError("Number of rows in X and y do not match.")
348
+
349
+ X_train, X_test, y_train, y_test = train_test_split(f, y, test_size=0.2, random_state=42)
350
+ scaler = StandardScaler()
351
+ X_train = scaler.fit_transform(X_train)
352
+ X_test = scaler.transform(X_test)
353
+
354
+ model = LGBMRegressor(metric='rmse', max_depth = 50, verbose = -1, num_leaves = 31,
355
+ feature_fraction = 0.8, min_data_in_leaf = 20)
356
+ modelname = "LightGBM Regressor"
357
+ model.fit(X_train, y_train)
358
+
359
+ train_score = model.score(X_train,y_train)
360
+ print(f"score for training set: {train_score:.3f}")
361
+
362
+ valid_score = model.score(X_test, y_test)
363
+ print(f"score for validation set: {valid_score:.3f}")
364
+ except:
365
+ return [], 'Model training failed, unable to predict.', None
366
+
367
+ preds = []
368
+ preds_string = ''
369
+
370
+ for smiles in smiles_list_in:
371
+ print(f"in predict node, smiles: {smiles}")
372
+ try:
373
+ for ion in ions_to_clean:
374
+ smiles = smiles.replace(ion,"")
375
+ test_mol = Chem.MolFromSmiles(smiles)
376
+ test_feat = featurizer.featurize([test_mol])
377
+ test_feat = scaler.transform(test_feat)
378
+ prediction = model.predict(test_feat)
379
+ test_ic50 = 10**(prediction[0])
380
+ print(f"Predicted IC50 for {smiles}: {test_ic50}")
381
+ preds_string += f"The predicted IC50 value for {smiles} is : {test_ic50:.3f} nM.\n"
382
+
383
+ preds.append(test_ic50)
384
+ except:
385
+ preds.append(None)
386
+ preds_string += f"The prediction for {smiles} failed.\n"
387
+
388
+ preds_string += f"The Bioactive data was fitted with the LightGMB model, using RDKit descriptors. The training score \
389
+ was {train_score:.3f} and the testing score was {valid_score:.3f}. "
390
+ return preds, preds_string, None
391
+
392
+ @tool
393
+ def gpt_node(chembl_id: str) -> (list[str], str, Image.Image):
394
+ '''
395
+ Uses a Chembl dataset, previously stored in a CSV file by the get_bioactives node, to
396
+ to finetune a GPT model to generate novel molecules for the target protein.
397
+
398
+ Args:
399
+ chembl_id: the ChEMBL ID to query
400
+ returns:
401
+ smiles_list: a list of generated SMILES strings
402
+ gpt_string: a string containing the results of the GPT finetuning and generation.
403
+ img: an image containing the generated molecules.
404
+ '''
405
+ print("GPT node")
406
+ print('===================================================')
407
+
408
+ # if f'{chembl_id}_bioactives.csv' does not exist, call the bioactives node
409
+ chembl_id = chembl_id.upper()
410
+ if not os.path.exists(f'{chembl_id}_bioactives.csv'):
411
+ _, _, _ = getbioactives_node([chembl_id])
412
+
413
+ try:
414
+ df = pd.read_csv(f'{chembl_id}_bioactives.csv')
415
+ smiles_list, gpt_string, img = finetune_gpt(df, chembl_id)
416
+
417
+ except:
418
+ gpt_string = ''
419
+ smiles_list = []
420
+ img = None
421
+
422
+ return smiles_list, gpt_string, img
423
+
424
+ def get_protein_from_pdb(pdb_id):
425
+ '''
426
+ Helper function to get the protein information from the PDB database.
427
+ Args:
428
+ pdb_id: the PDB ID of the protein
429
+ Returns:
430
+ r.text: the PDB information as a string
431
+ '''
432
+ url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
433
+ r = requests.get(url)
434
+ return r.text
435
+
436
+ def one_to_three(one_seq):
437
+ '''
438
+ Converts a one-letter amino acid sequence to a three-letter sequence.
439
+ Args:
440
+ one_seq: the one-letter amino acid sequence
441
+ Returns:
442
+ three_seq: the three-letter amino acid sequence
443
+ '''
444
+ rev_aa_hash = {
445
+ 'A': 'ALA',
446
+ 'R': 'ARG',
447
+ 'N': 'ASN',
448
+ 'D': 'ASP',
449
+ 'C': 'CYS',
450
+ 'Q': 'GLN',
451
+ 'E': 'GLU',
452
+ 'G': 'GLY',
453
+ 'H': 'HIS',
454
+ 'I': 'ILE',
455
+ 'L': 'LEU',
456
+ 'K': 'LYS',
457
+ 'M': 'MET',
458
+ 'F': 'PHE',
459
+ 'P': 'PRO',
460
+ 'S': 'SER',
461
+ 'T': 'THR',
462
+ 'W': 'TRP',
463
+ 'Y': 'TYR',
464
+ 'V': 'VAL'
465
+ }
466
+
467
+ try:
468
+ three_seq = rev_aa_hash[one_seq]
469
+ except:
470
+ three_seq = 'X'
471
+
472
+ return three_seq
473
+
474
+ def three_to_one(three_seq):
475
+ '''
476
+ Converts a three-letter amino acid sequence to a one-letter sequence.
477
+ Args:
478
+ three_seq: the three-letter amino acid sequence
479
+ Returns:
480
+ one_seq: the one-letter amino acid sequence
481
+ '''
482
+ aa_hash = {
483
+ 'ALA': 'A',
484
+ 'ARG': 'R',
485
+ 'ASN': 'N',
486
+ 'ASP': 'D',
487
+ 'CYS': 'C',
488
+ 'GLN': 'Q',
489
+ 'GLU': 'E',
490
+ 'GLY': 'G',
491
+ 'HIS': 'H',
492
+ 'ILE': 'I',
493
+ 'LEU': 'L',
494
+ 'LYS': 'K',
495
+ 'MET': 'M',
496
+ 'PHE': 'F',
497
+ 'PRO': 'P',
498
+ 'SER': 'S',
499
+ 'THR': 'T',
500
+ 'TRP': 'W',
501
+ 'TYR': 'Y',
502
+ 'VAL': 'V'
503
+ }
504
+
505
+ one_seq = []
506
+ for residue in three_seq:
507
+ try:
508
+ one_seq.append(aa_hash[residue])
509
+ except:
510
+ one_seq.append('X')
511
+ return one_seq
512
+
513
+ @tool
514
+ def pdb_node(test_pdb_list: list[str]) -> (list[str], str):
515
+ '''
516
+ Accepts a PDB ID and queires the protein databank for the sequence of the protein, as well as other
517
+ information such as ligands.
518
+ Args:
519
+ test_pdb_list: the PDB IDs to query
520
+ Returns:
521
+ all_seqs: a list of the sequences for each PDB ID
522
+ total_pdb_string: a string containing the results of the PDB query.
523
+ (collects all ligands but does not return them currently)
524
+ '''
525
+
526
+ print(f"pdb toolS")
527
+ print('===================================================')
528
+
529
+ total_pdb_string = ''
530
+ all_seqs = []
531
+ all_ligands = []
532
+
533
+ for test_pdb in test_pdb_list:
534
+ try:
535
+ pdb_str = get_protein_from_pdb(test_pdb)
536
+ chains = {}
537
+ other_molecules = {}
538
+
539
+ #print(pdb_str.split('\n')[0])
540
+ for line in pdb_str.split('\n'):
541
+ parts = line.split()
542
+ try:
543
+ if parts[0] == 'SEQRES':
544
+ if parts[2] not in chains:
545
+ chains[parts[2]] = []
546
+ chains[parts[2]].extend(parts[4:])
547
+ if parts[0] == 'HETNAM':
548
+ j = 1
549
+ if parts[1].strip() in ['2','3','4','5','6','7','8','9']:
550
+ j = 2
551
+ print(parts[j])
552
+ if parts[j] not in other_molecules:
553
+ other_molecules[parts[j]] = []
554
+ other_molecules[parts[j]].extend(parts[2:])
555
+ except:
556
+ print('Blank line')
557
+
558
+ chains_ol = {}
559
+ for chain in chains:
560
+ chains_ol[chain] = three_to_one(chains[chain])
561
+
562
+ sub_seqs = []
563
+ sub_ligands = []
564
+ total_pdb_string += f"Chains in PDB ID {test_pdb}: {', '.join(chains.keys())} \n"
565
+ for chain in chains_ol:
566
+ total_pdb_string += f"Chain {chain}: {''.join(chains_ol[chain])} \n"
567
+ sub_seqs.append(''.join(chains_ol[chain]))
568
+ print(f"Chain {chain}: {''.join(chains_ol[chain])}")
569
+ total_pdb_string += f"Ligands in PDB ID {test_pdb}.\n"
570
+ for mol in other_molecules:
571
+ total_pdb_string += f"Molecule {mol}: {''.join(other_molecules[mol])} \n"
572
+ sub_ligands.append(''.join(other_molecules[mol]))
573
+ total_pdb_string += f'=========================================================================================\n'
574
+
575
+ all_seqs.append(sub_seqs)
576
+ all_ligands.append(sub_ligands)
577
+ except:
578
+ total_pdb_string += f'Failed to get data for PDB ID {test_pdb}\n'
579
+ total_pdb_string += f'=========================================================================================\n'
580
+ all_seqs.append([])
581
+ all_ligands.append([])
582
+ return all_seqs, total_pdb_string, None
583
+
584
+ @tool
585
+ def find_node(test_protein_list: list[str]) -> (list[str], str):
586
+ '''
587
+ Accepts a protein name and searches the protein databack for PDB IDs that match along with the entry titles.
588
+ Args:
589
+ test_protein_list: the protein names to query
590
+ Returns:
591
+ total_ids: a list of the PDB IDs for each protein name
592
+ pdb_string: a string containing the results of the PDB search.
593
+ '''
594
+
595
+ print(f"PDB search tool")
596
+ print('===================================================')
597
+
598
+ total_ids = []
599
+ pdb_string = ''
600
+ which_pdbs = 0
601
+
602
+ for test_protein in test_protein_list:
603
+ try:
604
+ query = TextQuery(value=test_protein)
605
+ results = query()
606
+
607
+ def pdb_gen():
608
+ for rid in results:
609
+ yield(rid)
610
+
611
+ take10 = itertools.islice(pdb_gen(), which_pdbs, which_pdbs+10, 1)
612
+
613
+ local_ids = []
614
+ pdb_string += f'10 PDBs that match the protein {test_protein} are: \n'
615
+ for pdb in take10:
616
+ data = requests.get(f"https://data.rcsb.org/rest/v1/core/entry/{pdb}").json()
617
+ title = data['struct']['title']
618
+ pdb_string += f'PDB ID: {pdb}, with title: {title} \n'
619
+ local_ids.append(pdb)
620
+ total_ids.append(local_ids)
621
+ except:
622
+ pdb_string += f'Failed to get PDB IDs for protein {test_protein}\n'
623
+ total_ids.append([])
624
+ return total_ids, pdb_string, None
625
+
626
+ @tool
627
+ def docking_node(smiles_list: list[str], query_protein: str) -> (list[float], str):
628
+ '''
629
+ Docking tool: uses dockstring to dock the molecule into the protein
630
+ Args:
631
+ smiles_list: the SMILES strings of the molecules to dock
632
+ protein: the protein to dock into
633
+ Returns:
634
+ docking_scores: a list of docking scores for each molecule
635
+ docking_string: a string containing the results of the docking.
636
+ '''
637
+ print("docking tool")
638
+ print('===================================================')
639
+ cpuCount = os.cpu_count()
640
+ print(f"Number of CPUs: {cpuCount}")
641
+
642
+ print(f'query_protein: {query_protein}')
643
+
644
+ scores_list = []
645
+ scores_string = 'Docking below performed with AutoDock Vina on protein structures from the DUDE database.\n'
646
+
647
+ for query_smiles in smiles_list:
648
+ try:
649
+ query_smiles = query_smiles.replace('.[Na+]','').replace('.[Na+]','').replace('.[K+]','').replace('[K+].','').replace('.[Cl-]','').replace('[Cl-].','')
650
+ target = load_target(query_protein)
651
+ print("===============================================")
652
+ print(f"Docking molecule with {cpuCount} cpu cores.")
653
+ score, aux = target.dock(query_smiles, num_cpus = cpuCount)
654
+ scores_list.append(score)
655
+ mol = aux['ligand']
656
+ print(f"Docking score: {score}")
657
+ print("===============================================")
658
+ atoms_list = ""
659
+ template = mol
660
+ molH = Chem.AddHs(mol)
661
+ AllChem.ConstrainedEmbed(molH,template, useTethers=True)
662
+ xyz_string = f"{molH.GetNumAtoms()}\n\n"
663
+ for atom in molH.GetAtoms():
664
+ atoms_list += atom.GetSymbol()
665
+ pos = molH.GetConformer().GetAtomPosition(atom.GetIdx())
666
+ xyz_string += f"{atom.GetSymbol()} {pos[0]} {pos[1]} {pos[2]}\n"
667
+ scores_string += f"Docking score for molecule with SMILES: {query_smiles} is: {score} kcal/mol \n\n"
668
+ scores_string += f"pose XYZ structure for molecule with SMILES: {query_smiles} is: \n"
669
+ lines = xyz_string.split('\n')
670
+ for line in lines[2:]:
671
+ scores_string += f'{line}\n'
672
+ scores_string += f"=========================================================\n"
673
+
674
+ except:
675
+ print(f"Molecule {query_smiles} could not be docked!")
676
+ scores_string = 'Could not dock!'
677
+ scores_list.append(None)
678
+ return scores_list, scores_string, None
679
+
680
+ @tool
681
+ def target_node(search_descriptors: list[str]):
682
+ '''
683
+ Accepts a disease name and searches Open Targets for associated targets
684
+
685
+ Args:
686
+ search_descriptor (str): Disease name
687
+
688
+ Returns:
689
+ targets_list (list): List of targets
690
+ targets_string (str): String of targets
691
+ None
692
+ '''
693
+ base_url = "https://api.platform.opentargets.org/api/v4/graphql"
694
+
695
+ disease_query_string = """
696
+ query searchEntity($queryString: String!) {
697
+ search(queryString: $queryString){
698
+ total
699
+ hits {
700
+ id
701
+ entity
702
+ description
703
+ }
704
+ }
705
+ }
706
+ """
707
+
708
+ target_query_string = """
709
+ query associatedTargets($efo_id: String!) {
710
+ disease(efoId: $efo_id) {
711
+ id
712
+ name
713
+ associatedTargets {
714
+ count
715
+ rows {
716
+ target {
717
+ id
718
+ approvedSymbol
719
+ }
720
+ score
721
+ }
722
+ }
723
+ }
724
+ }
725
+ """
726
+ total_targets_list = []
727
+ total_targets_string = ''
728
+
729
+ for search_descriptor in search_descriptors:
730
+
731
+ variables = {"queryString": search_descriptor}
732
+ r = requests.post(base_url, json={"query": disease_query_string, "variables": variables})
733
+
734
+ disease_list = []
735
+ targets_list = []
736
+
737
+ if r.status_code == 200:
738
+ api_response = json.loads(r.text)
739
+ if len(api_response['data']['search']['hits']) > 0:
740
+ for hit in api_response['data']['search']['hits']:
741
+ if hit['entity'] == 'disease':
742
+ disease_list.append(hit['id'])
743
+ else:
744
+ print('Could not find results.')
745
+
746
+ if len(disease_list) > 0:
747
+ q = requests.post(base_url, json={"query": target_query_string, "variables": {"efo_id": disease_list[0]}})
748
+ if q.status_code == 200:
749
+ api_response = json.loads(q.text)
750
+ for target in api_response['data']['disease']['associatedTargets']['rows']:
751
+ targets_list.append(target['target']['approvedSymbol'])
752
+
753
+ targets_string = f'Possible targets for {search_descriptor} include: \n'
754
+ if len(targets_list) > 0:
755
+ for i, target in enumerate(targets_list):
756
+ targets_string += f'{i+1}. {target}\n'
757
+ else:
758
+ targets_string = f'No targets found for {search_descriptor}'
759
+
760
+ total_targets_list.append(targets_list)
761
+ total_targets_string += targets_string
762
+
763
+ return total_targets_list, total_targets_string, None
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bitsandbytes
2
+ pubchempy
3
+ rdkit
4
+ chembl_webresource_client
5
+ rcsb-api
6
+ deepchem
7
+ dockstring
8
+ openbabel-wheel
9
+ openai
10
+ langchain_core
11
+ langchain_openai
12
+ langgraph
13
+ gradio
14
+ torch
15
+ matplotlib
16
+ pillow
17
+ gradio-client
18
+ transformers
19
+ dockstring
20
+ openbabel-wheel
21
+ numpy
22
+ elevenlabs
23
+ lightgbm
24
+ tf-keras
25
+ tensorflow
26
+ accelerate