---
title: TAGNN++
keywords: fastai
sidebar: home_sidebar
summary: "Mitheran et al. Improved Representation Learning for Session-based Recommendation. arXiv, 2021."
description: "Mitheran et al. Improved Representation Learning for Session-based Recommendation. arXiv, 2021."
nb_path: "nbs/models/models.tagnn_pp.ipynb"
---
import datetime

import numpy as np
import torch
from tqdm.notebook import tqdm

class Args():
    dataset = 'sample'
    batchSize = 100      # input batch size
    hiddenSize = 100     # hidden state size
    epoch = 5            # the number of epochs to train for
    lr = 0.001           # learning rate
    lr_dc = 0.1          # learning rate decay rate
    lr_dc_step = 3       # the number of steps after which the learning rate decays
    l2 = 1e-5            # l2 penalty
    step = 1             # gnn propagation steps
    patience = 10        # the number of epochs to wait before early stopping
    nonhybrid = True     # only use the global preference to predict
    validation = True    # hold out part of the training set for validation
    valid_portion = 0.1  # portion of the training set to use as the validation set
    n_node = 310         # number of item nodes (size of the item embedding table)

args = Args()
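The defaults above match the bundled `sample` dataset. Adapting the run to another dataset would only mean overriding attributes on a fresh instance; the values below are illustrative placeholders, not settings from the original notebook.

# Hypothetical override for a different dataset (placeholder values only):
custom_args = Args()
custom_args.dataset = 'my_dataset'  # placeholder name
custom_args.n_node = 50000          # must match that dataset's item-vocabulary size
custom_args.epoch = 30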
def to_cuda(input_variable):
    if torch.cuda.is_available():
        return input_variable.cuda()
    else:
        return input_variable

def to_cpu(input_variable):
    if torch.cuda.is_available():
        return input_variable.cpu()
    else:
        return input_variable
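These helpers predate the `.to(device)` idiom; an equivalent device-agnostic sketch (not from the original notebook) looks like this:

# Sketch of the equivalent modern device-handling idiom:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.zeros(3)
x = x.to(device)  # plays the role of to_cuda(x)
x = x.cpu()       # plays the role of to_cpu(x)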
def forward(model, i, data):
    # get_slice returns, per session in the batch: alias indices into the
    # unique-node list, the adjacency matrix A, the padded unique nodes,
    # the padding mask, and the target (next-clicked) items
    alias_inputs, A, items, mask, targets = data.get_slice(i)
    alias_inputs = to_cuda(torch.Tensor(alias_inputs).long())
    items = to_cuda(torch.Tensor(items).long())
    A = to_cuda(torch.Tensor(A).float())
    mask = to_cuda(torch.Tensor(mask).long())
    hidden = model(items, A)

    # re-order each session's node embeddings back into click order
    def get(i):
        return hidden[i][alias_inputs[i]]
    seq_hidden = torch.stack([get(i)
                              for i in torch.arange(len(alias_inputs)).long()])
    return targets, model.compute_scores(seq_hidden, mask)
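The `get`/`torch.stack` step re-orders each session's node embeddings back into click order via the alias indices. A minimal standalone sketch of that gather, with made-up shapes rather than real model output:

# Toy illustration of the alias gather (made-up shapes):
hidden = torch.arange(2 * 3 * 4, dtype=torch.float).view(2, 3, 4)  # (batch, n_unique_nodes, d)
alias_inputs = torch.tensor([[0, 1, 1, 2],
                             [2, 2, 0, 1]])                        # click order per session
seq_hidden = torch.stack([hidden[b][alias_inputs[b]]
                          for b in range(len(alias_inputs))])
assert seq_hidden.shape == (2, 4, 4)  # (batch, max_session_len, d)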
def train_test(model, train_data, test_data):
    # stepping the scheduler at the start of each epoch follows the original
    # reference code; PyTorch >= 1.1 recommends stepping after optimizer.step()
    model.scheduler.step()
    print('Start training: ', datetime.datetime.now())
    model.train()
    total_loss = 0.0
    slices = train_data.generate_batch(model.batch_size)
    for i, j in tqdm(zip(slices, np.arange(len(slices))), total=len(slices)):
        model.optimizer.zero_grad()
        targets, scores = forward(model, i, train_data)
        targets = to_cuda(torch.Tensor(targets).long())
        loss = model.loss_function(scores, targets - 1)  # item ids are 1-indexed
        loss.backward()
        model.optimizer.step()
        total_loss += loss.item()
        if j % int(len(slices) / 5 + 1) == 0:
            print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item()))
    print('\tLoss Value:\t%.3f' % total_loss)

    print('Start Prediction: ', datetime.datetime.now())
    model.eval()
    hit, mrr = [], []
    slices = test_data.generate_batch(model.batch_size)
    for i in slices:
        targets, scores = forward(model, i, test_data)
        sub_scores = scores.topk(20)[1]  # indices of the 20 highest-scored items
        sub_scores = to_cpu(sub_scores).detach().numpy()
        for score, target, mask in zip(sub_scores, targets, test_data.mask):
            hit.append(np.isin(target - 1, score))
            if len(np.where(score == target - 1)[0]) == 0:
                mrr.append(0)
            else:
                mrr.append(1 / (np.where(score == target - 1)[0][0] + 1))
    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    return hit, mrr
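The evaluation computes Recall@20 and MRR@20 from the top-20 score indices; because item ids are 1-indexed while score columns are 0-indexed, `target - 1` appears throughout. A self-contained sketch of the metric logic on fake scores:

# Metric logic on fake data (2 sessions, 50 candidate items):
scores = torch.randn(2, 50)
targets = np.array([7, 31])          # 1-indexed item ids
top20 = scores.topk(20)[1].numpy()
hits, mrrs = [], []
for row, target in zip(top20, targets):
    hits.append(np.isin(target - 1, row))
    rank = np.where(row == target - 1)[0]
    mrrs.append(0 if len(rank) == 0 else 1 / (rank[0] + 1))
print(np.mean(hits) * 100, np.mean(mrrs) * 100)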
def get_pos(seq_len):
    return torch.arange(seq_len).unsqueeze(0)
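`get_pos` builds the batch-shaped row of positional indices used by the attention layers; for instance:

get_pos(5)   # tensor([[0, 1, 2, 3, 4]]) with shape (1, seq_len)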
def str2bool(v):
    # note: the original `in ('true')` (no comma) tested substring membership,
    # not tuple membership, so e.g. 'rue' also mapped to True
    return v.lower() == 'true'
def split_validation(train_set, valid_portion):
    train_set_x, train_set_y = train_set
    n_samples = len(train_set_x)
    sidx = np.arange(n_samples, dtype='int32')
    np.random.shuffle(sidx)
    n_train = int(np.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
    train_set_x = [train_set_x[s] for s in sidx[:n_train]]
    train_set_y = [train_set_y[s] for s in sidx[:n_train]]
    return (train_set_x, train_set_y), (valid_set_x, valid_set_y)
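A quick sanity check of the split on toy data (the content is arbitrary, only the sizes matter):

# Toy check: 10 samples with a 10% validation portion -> 9 train / 1 valid
toy_set = (list(range(10)), list(range(10)))
(tr_x, tr_y), (va_x, va_y) = split_validation(toy_set, valid_portion=0.1)
assert len(tr_x) == 9 and len(va_x) == 1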
import os
import pickle
import time
import warnings

from torch.utils.tensorboard import SummaryWriter

from recohut.datasets.session import SampleSessionDataset, GraphData

warnings.filterwarnings('ignore')

model_save_dir = 'saved/'
log_dir = 'saved/logs'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)  # torch.save below needs this directory
writer = SummaryWriter(log_dir=log_dir)
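The `SummaryWriter` records the per-epoch Recall@20 and MRR@20 written in the training loop below; the logs can be inspected with standard TensorBoard tooling (not part of the original notebook):

# Standard TensorBoard magics for viewing the logs inside a notebook:
# %load_ext tensorboard
# %tensorboard --logdir saved/logs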
_ = SampleSessionDataset('./session_ds')
train_data = pickle.load(open('./session_ds/processed/train.txt', 'rb'))
if args.validation:
    train_data, valid_data = split_validation(train_data, args.valid_portion)
    test_data = valid_data
else:
    test_data = pickle.load(open('./session_ds/processed/test.txt', 'rb'))
train_data = GraphData(train_data, shuffle=True)
test_data = GraphData(test_data, shuffle=False)
# TAGNN_PP is defined earlier in this notebook (see nb_path above)
model = to_cuda(TAGNN_PP(args))

start = time.time()
best_result = [0, 0]
best_epoch = [0, 0]
bad_counter = 0

for epoch in range(args.epoch):
    print('-' * 50)
    print('Epoch: ', epoch)
    hit, mrr = train_test(model, train_data, test_data)

    # Logging
    writer.add_scalar('epoch/recall', hit, epoch)
    writer.add_scalar('epoch/mrr', mrr, epoch)

    flag = 0
    if hit >= best_result[0]:
        best_result[0] = hit
        best_epoch[0] = epoch
        flag = 1
        torch.save(model, model_save_dir + 'epoch_' +
                   str(epoch) + '_recall_' + str(hit) + '_.pt')
    if mrr >= best_result[1]:
        best_result[1] = mrr
        best_epoch[1] = epoch
        flag = 1
        torch.save(model, model_save_dir + 'epoch_' +
                   str(epoch) + '_mrr_' + str(mrr) + '_.pt')
    print('Best Result:')
    print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f\tEpoch:\t%d,\t%d' %
          (best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
    bad_counter += 1 - flag
    if bad_counter >= args.patience:
        break

print('-' * 50)
end = time.time()
print("Running time: %f seconds" % (end - start))