"""
Train a new model.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import IterableDataset, DataLoader
from sklearn.metrics import average_precision_score as average_precision
import sys
import argparse
import h5py
import subprocess as sp
import numpy as np
import pandas as pd
import gzip as gz
from tqdm import tqdm
from ..alphabets import Uniprot21
from ..utils import PairedDataset, collate_paired_sequences
from ..models.embedding import FullyConnectedEmbed
from ..models.contact import ContactCNN
from ..models.interaction import ModelInteraction
def add_args(parser):
    """
    Create parser for command line utility.

    Registers the data, model, training and output options on ``parser``,
    organised into argument groups, and returns the same parser object.

    :param parser: Parser to add the training arguments to
    :type parser: argparse.ArgumentParser
    :return: The parser that was passed in, with arguments added
    :rtype: argparse.ArgumentParser

    :meta private:
    """

    data_grp = parser.add_argument_group("Data")
    proj_grp = parser.add_argument_group("Projection Module")
    contact_grp = parser.add_argument_group("Contact Module")
    inter_grp = parser.add_argument_group("Interaction Module")
    train_grp = parser.add_argument_group("Training")
    misc_grp = parser.add_argument_group("Output and Device")

    # Data
    data_grp.add_argument(
        "--train", required=True, help="list of training pairs"
    )
    data_grp.add_argument(
        "--test", required=True, help="list of validation/testing pairs"
    )
    data_grp.add_argument(
        "--embedding", required=True, help="h5py path containing embedded sequences"
    )
    data_grp.add_argument(
        "--no-augment",
        action="store_true",
        help="data is automatically augmented by adding (B A) for all pairs (A B). Set this flag to not augment data",
    )
    # type=int so a value given on the command line has the same type as the
    # integer default (without it argparse would hand back a string).
    data_grp.add_argument(
        "--protein-size",
        type=int,
        default=800,
        help="maximum protein size to use in training data (default: 800)",
    )

    # Embedding model
    proj_grp.add_argument(
        "--input-dim",
        type=int,
        default=6165,
        help="dimension of input language model embedding (per amino acid) (default: 6165)",
    )
    proj_grp.add_argument(
        "--projection-dim",
        type=int,
        default=100,
        help="dimension of embedding projection layer (default: 100)",
    )
    proj_grp.add_argument(
        "--dropout-p",
        type=float,
        default=0.5,
        help="parameter p for embedding dropout layer (default: 0.5)",
    )

    # Contact model
    contact_grp.add_argument(
        "--hidden-dim",
        type=int,
        default=50,
        help="number of hidden units for comparison layer in contact prediction (default: 50)",
    )
    contact_grp.add_argument(
        "--kernel-width",
        type=int,
        default=7,
        help="width of convolutional filter for contact prediction (default: 7)",
    )

    # Interaction Model
    inter_grp.add_argument(
        "--no-w",
        action="store_true",
        help="don't use weight matrix in interaction prediction model",
    )
    inter_grp.add_argument(
        "--no-sigmoid",
        action="store_true",
        help="don't use sigmoid activation at end of interaction model",
    )
    inter_grp.add_argument(
        "--do-pool",
        action="store_true",
        help="use max pool layer in interaction prediction model",
    )
    inter_grp.add_argument(
        "--pool-width",
        type=int,
        default=9,
        help="size of max-pool in interaction model (default: 9)",
    )

    # Training
    train_grp.add_argument(
        "--num-epochs", type=int, default=10, help="number of epochs (default: 10)"
    )
    train_grp.add_argument(
        "--batch-size", type=int, default=25, help="minibatch size (default: 25)"
    )
    train_grp.add_argument(
        "--weight-decay", type=float, default=0, help="L2 regularization (default: 0)"
    )
    train_grp.add_argument(
        "--lr", type=float, default=0.001, help="learning rate (default: 0.001)"
    )
    # "lambda" is a Python keyword, hence the explicit dest.
    train_grp.add_argument(
        "--lambda",
        dest="lambda_",
        type=float,
        default=0.35,
        help="weight on the similarity objective (default: 0.35)",
    )

    # Output
    misc_grp.add_argument("-o", "--output", help="output file path (default: stdout)")
    misc_grp.add_argument("--save-prefix", help="path prefix for saving models")
    misc_grp.add_argument(
        "-d", "--device", type=int, default=-1, help="compute device to use"
    )
    misc_grp.add_argument(
        "--checkpoint", help="checkpoint model to start training from"
    )

    return parser
def predict_interaction(model, n0, n1, tensors, use_cuda):
    """
    Predict whether a list of protein pairs will interact.

    Each pair ``(n0[i], n1[i])`` is looked up in ``tensors`` and scored
    individually with ``model.predict``; the per-pair results are stacked
    along a new leading dimension.

    :param model: Model used for prediction
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :return: Stacked predictions, one entry per pair
    :rtype: torch.Tensor
    """

    def _embedding(name):
        # Look up the embedding and move it to the GPU only when requested.
        emb = tensors[name]
        return emb.cuda() if use_cuda else emb

    scores = [
        model.predict(_embedding(first), _embedding(n1[idx]))
        for idx, first in enumerate(n0)
    ]
    return torch.stack(scores, 0)
def predict_cmap_interaction(model, n0, n1, tensors, use_cuda):
    """
    Predict whether a list of protein pairs will interact, as well as their contact map.

    For every pair ``(n0[i], n1[i])`` the model's ``map_predict`` is called;
    the interaction prediction is kept as-is and the contact map is reduced
    to its mean value.

    :param model: Model used for prediction
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :return: (Mean contact map value per pair, interaction prediction per pair)
    :rtype: (torch.Tensor, torch.Tensor)
    """
    interaction_scores = []
    cmap_means = []

    for idx, first in enumerate(n0):
        emb_a = tensors[first]
        emb_b = tensors[n1[idx]]
        if use_cuda:
            emb_a, emb_b = emb_a.cuda(), emb_b.cuda()

        contact_map, score = model.map_predict(emb_a, emb_b)
        interaction_scores.append(score)
        cmap_means.append(torch.mean(contact_map))

    return torch.stack(cmap_means, 0), torch.stack(interaction_scores, 0)
def interaction_grad(model, n0, n1, y, tensors, use_cuda, weight=0.35):
    """
    Compute gradient and backpropagate loss for a batch.

    The loss is ``weight * BCE(p_hat, y) + (1 - weight) * mean(contact map
    magnitude)``. ``loss.backward()`` is called here; stepping and zeroing
    the optimizer is left to the caller.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param y: Interaction labels
    :type y: torch.Tensor
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param weight: Weight on the interaction (BCE) objective. The contact map magnitude objective is weighted by :math:`1 - \\text{weight}`.
    :type weight: float

    :return: (Loss, number correct, mean square error, batch size)
    :rtype: (torch.Tensor, int, float, int)
    """

    c_map_mag, p_hat = predict_cmap_interaction(model, n0, n1, tensors, use_cuda)
    if use_cuda:
        y = y.cuda()
    y = y.float()
    p_hat = p_hat.float()

    bce_loss = F.binary_cross_entropy(p_hat, y)
    cmap_loss = torch.mean(c_map_mag)
    loss = (weight * bce_loss) + ((1 - weight) * cmap_loss)
    b = len(p_hat)

    # backprop loss
    loss.backward()

    # Summary statistics are computed on CPU copies and never need gradients.
    with torch.no_grad():
        y_cpu = y.cpu()
        p_cpu = p_hat.detach().cpu()
        guess_cutoff = 0.5
        p_guess = (p_cpu > guess_cutoff).float()
        correct = torch.sum(p_guess == y_cpu).item()
        mse = torch.mean((y_cpu - p_cpu) ** 2).item()

    return loss, correct, mse, b
def interaction_eval(model, test_iterator, tensors, use_cuda):
    """
    Evaluate test data set performance.

    Predictions for every batch are collected, concatenated and moved to the
    CPU, where all metrics are computed. Precision and recall are "soft":
    true positives are the sum of ``y * p_hat`` (predicted probabilities, not
    thresholded guesses). A metric whose denominator is zero is reported as
    ``0.0`` instead of raising ``ZeroDivisionError``.

    :param model: Model to be evaluated
    :type model: dscript.models.interaction.ModelInteraction
    :param test_iterator: Test data iterator
    :type test_iterator: torch.utils.data.DataLoader
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool

    :return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR)
    :rtype: (float, int, float, float, float, float, float)
    """
    p_hat = []
    true_y = []

    for n0, n1, y in test_iterator:
        p_hat.append(predict_interaction(model, n0, n1, tensors, use_cuda))
        true_y.append(y)

    # Predictions may live on the GPU; bring everything to the CPU once so the
    # metric arithmetic and the numpy conversion below work on either device.
    y = torch.cat(true_y, 0).float().cpu()
    p_hat = torch.cat(p_hat, 0).detach().float().cpu()

    loss = F.binary_cross_entropy(p_hat, y).item()

    with torch.no_grad():
        guess_cutoff = 0.5
        p_guess = (p_hat > guess_cutoff).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

        tp = torch.sum(y * p_hat).item()
        predicted_total = torch.sum(p_hat).item()
        positive_total = torch.sum(y).item()

        # Guard every ratio: an all-negative label set or all-zero predictions
        # would otherwise abort evaluation with ZeroDivisionError.
        pr = tp / predicted_total if predicted_total > 0 else 0.0
        re = tp / positive_total if positive_total > 0 else 0.0
        f1 = 2 * pr * re / (pr + re) if (pr + re) > 0 else 0.0

    aupr = average_precision(y.numpy(), p_hat.numpy())

    return loss, correct, mse, pr, re, f1, aupr
def train_model(args, output):
    """
    Build (or load) an interaction model and train it.

    Reads tab-separated training and test pair files (protein 1, protein 2,
    label; no header), loads the embeddings of every referenced protein from
    the HDF5 file, then trains with Adam for ``args.num_epochs`` epochs.
    Progress and per-epoch evaluation metrics are printed to ``output``. If
    ``args.save_prefix`` is set, the model is saved after every epoch and once
    more at the end.

    :param args: Parsed command line arguments (see ``add_args``)
    :type args: argparse.Namespace
    :param output: Writable text stream for log messages
    :type output: file-like object
    """
    ## Create data sets
    batch_size = args.batch_size
    use_cuda = (args.device > -1) and torch.cuda.is_available()
    train_fi = args.train
    test_fi = args.test
    no_augment = args.no_augment
    embedding_h5 = args.embedding

    print(f"# Loading training pairs from {train_fi}...", file=output)
    output.flush()

    train_df = pd.read_csv(train_fi, sep="\t", header=None)
    train_df.columns = ["prot1", "prot2", "label"]

    if no_augment:
        train_p1 = train_df["prot1"]
        train_p2 = train_df["prot2"]
        train_y = torch.from_numpy(train_df["label"].values)
    else:
        # Augment by adding the swapped pair (B, A) for every pair (A, B).
        train_p1 = pd.concat(
            (train_df["prot1"], train_df["prot2"]), axis=0
        ).reset_index(drop=True)
        train_p2 = pd.concat(
            (train_df["prot2"], train_df["prot1"]), axis=0
        ).reset_index(drop=True)
        train_y = torch.from_numpy(
            pd.concat((train_df["label"], train_df["label"])).values
        )

    train_dataset = PairedDataset(train_p1, train_p2, train_y)
    train_iterator = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=True,
    )

    print(f"# Loaded {len(train_p1)} training pairs", file=output)
    print(f"# Loading testing pairs from {test_fi}...", file=output)
    output.flush()

    test_df = pd.read_csv(test_fi, sep="\t", header=None)
    test_df.columns = ["prot1", "prot2", "label"]
    test_p1 = test_df["prot1"]
    test_p2 = test_df["prot2"]
    test_y = torch.from_numpy(test_df["label"].values)

    test_dataset = PairedDataset(test_p1, test_p2, test_y)
    test_iterator = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=False,
    )

    print(f"# Loaded {len(test_p1)} test pairs", file=output)
    print(f"# Loading embeddings", file=output)
    output.flush()

    embeddings = {}
    all_proteins = (
        set(train_p1)
        .union(train_p2)
        .union(test_p1)
        .union(test_p2)
    )
    # The slice copies each dataset into memory, so the file can be closed as
    # soon as every embedding has been read.
    with h5py.File(embedding_h5, "r") as h5fi:
        for prot_name in tqdm(all_proteins):
            embeddings[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])

    if args.checkpoint is None:

        # Create embedding model
        input_dim = args.input_dim
        projection_dim = args.projection_dim
        dropout_p = args.dropout_p
        embedding_model = FullyConnectedEmbed(
            input_dim, projection_dim, dropout=dropout_p
        )
        print("# Initializing embedding model with:", file=output)
        print(f"\tprojection_dim: {projection_dim}", file=output)
        print(f"\tdropout_p: {dropout_p}", file=output)

        # Create contact model
        hidden_dim = args.hidden_dim
        kernel_width = args.kernel_width
        print("# Initializing contact model with:", file=output)
        print(f"\thidden_dim: {hidden_dim}", file=output)
        print(f"\tkernel_width: {kernel_width}", file=output)

        contact_model = ContactCNN(projection_dim, hidden_dim, kernel_width)

        # Create the full model
        do_w = not args.no_w
        do_pool = args.do_pool
        pool_width = args.pool_width
        do_sigmoid = not args.no_sigmoid
        print("# Initializing interaction model with:", file=output)
        # "do_poool" (sic) is kept as-is so existing log output is unchanged.
        print(f"\tdo_poool: {do_pool}", file=output)
        print(f"\tpool_width: {pool_width}", file=output)
        print(f"\tdo_w: {do_w}", file=output)
        print(f"\tdo_sigmoid: {do_sigmoid}", file=output)
        model = ModelInteraction(
            embedding_model,
            contact_model,
            use_cuda,
            do_w=do_w,
            pool_size=pool_width,
            do_pool=do_pool,
            do_sigmoid=do_sigmoid,
        )

        print(model, file=output)

    else:
        print("# Loading model from checkpoint {}".format(args.checkpoint), file=output)
        # NOTE(review): torch.load unpickles the whole model object; only load
        # checkpoints from a trusted source.
        model = torch.load(args.checkpoint)
        model.use_cuda = use_cuda

    if use_cuda:
        model.cuda()

    # Train the model
    lr = args.lr
    wd = args.weight_decay
    num_epochs = args.num_epochs
    inter_weight = args.lambda_
    cmap_weight = 1 - inter_weight
    # Width used to zero-pad epoch numbers in checkpoint file names.
    digits = len(str(num_epochs))
    save_prefix = args.save_prefix

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=wd)

    print(f'# Using save prefix "{save_prefix}"', file=output)
    print(f"# Training with Adam: lr={lr}, weight_decay={wd}", file=output)
    print(f"\tnum_epochs: {num_epochs}", file=output)
    print(f"\tbatch_size: {batch_size}", file=output)
    print(f"\tinteraction weight: {inter_weight}", file=output)
    print(f"\tcontact map weight: {cmap_weight}", file=output)
    output.flush()

    batch_report_fmt = (
        "# [{}/{}] training {:.1%}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}"
    )
    epoch_report_fmt = "# Finished Epoch {}/{}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}, Precision={:.6}, Recall={:.6}, F1={:.6}, AUPR={:.6}"

    # Use the real number of pairs: len(iterator) * batch_size overcounts
    # whenever the last batch is smaller than batch_size.
    N = len(train_p1)
    num_test = len(test_p1)

    for epoch in range(num_epochs):

        model.train()

        n = 0
        loss_accum = 0
        acc_accum = 0
        mse_accum = 0

        # Train batches
        for (z0, z1, y) in train_iterator:

            loss, correct, mse, b = interaction_grad(
                model,
                z0,
                z1,
                y,
                embeddings,
                use_cuda,
                weight=inter_weight,
            )
            # Gradients were already backpropagated inside interaction_grad;
            # keep only the number so the running mean does not hold on to the
            # autograd graph of every batch.
            loss = float(loss)

            # Running (sample-weighted) means over the epoch so far.
            n += b
            loss_accum += b * (loss - loss_accum) / n
            acc_accum += (correct - b * acc_accum) / n
            mse_accum += b * (mse - mse_accum) / n

            # Report whenever the sample count crosses a multiple of 100.
            report = (n - b) // 100 < n // 100

            optimizer.step()
            optimizer.zero_grad()
            model.clip()

            if report:
                tokens = [
                    epoch + 1,
                    num_epochs,
                    n / N,
                    loss_accum,
                    acc_accum,
                    mse_accum,
                ]
                print(batch_report_fmt.format(*tokens), file=output)
                output.flush()

        model.eval()

        with torch.no_grad():

            (
                inter_loss,
                inter_correct,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ) = interaction_eval(model, test_iterator, embeddings, use_cuda)
            tokens = [
                epoch + 1,
                num_epochs,
                inter_loss,
                inter_correct / num_test,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ]
            print(epoch_report_fmt.format(*tokens), file=output)
            output.flush()

            # Save the model
            if save_prefix is not None:
                save_path = (
                    save_prefix + "_epoch" + str(epoch + 1).zfill(digits) + ".sav"
                )
                print(f"# Saving model to {save_path}", file=output)
                # Save from the CPU, then move back for the next epoch.
                model.cpu()
                torch.save(model, save_path)
                if use_cuda:
                    model.cuda()

        output.flush()

    if save_prefix is not None:
        save_path = save_prefix + "_final.sav"
        print(f"# Saving final model to {save_path}", file=output)
        model.cpu()
        torch.save(model, save_path)
        if use_cuda:
            model.cuda()
def main(args):
    """
    Run training from arguments.

    Opens the log destination (``args.output`` or stdout), selects the
    compute device and calls ``train_model``. A log file opened here is
    always closed, even if training raises; stdout is never closed.

    :param args: Parsed command line arguments (see ``add_args``)
    :type args: argparse.Namespace

    :meta private:
    """

    log_to_stdout = args.output is None
    output = sys.stdout if log_to_stdout else open(args.output, "w")

    try:
        print(f'# Called as: {" ".join(sys.argv)}', file=output)
        if output is not sys.stdout:
            # Echo the invocation on the console when logging to a file.
            print(f'Called as: {" ".join(sys.argv)}')

        ## Set the device
        device = args.device
        use_cuda = (device > -1) and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(device)
            print(
                f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
                file=output,
            )
        else:
            print("# Using CPU", file=output)

        train_model(args, output)
    finally:
        if log_to_stdout:
            output.flush()
        else:
            output.close()
if __name__ == "__main__":
    # Script entry point: build the CLI (add_args returns the parser it was
    # given), parse the command line and start training.
    cli_parser = add_args(argparse.ArgumentParser(description=__doc__))
    main(cli_parser.parse_args())